From 6b562b1975bfbf6539e0bc67db8326e072588a88 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 26 Apr 2024 21:11:45 +0800 Subject: [PATCH 1/7] init --- be/src/vec/functions/function_floor.h | 245 ++++++++++++++++++ be/src/vec/functions/math.cpp | 7 +- be/src/vec/functions/round.h | 16 +- .../doris/analysis/FunctionCallExpr.java | 4 +- .../functions/ComputePrecisionForRound.java | 3 +- 5 files changed, 262 insertions(+), 13 deletions(-) create mode 100644 be/src/vec/functions/function_floor.h diff --git a/be/src/vec/functions/function_floor.h b/be/src/vec/functions/function_floor.h new file mode 100644 index 00000000000000..d4616ae959b10a --- /dev/null +++ b/be/src/vec/functions/function_floor.h @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include + +#include "common/exception.h" +#include "common/status.h" +#include "olap/olap_common.h" +#include "round.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/call_on_type_index.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" + +namespace doris::vectorized { + +struct FloorFloatOneArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } +}; + +struct FloorFloatTwoArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared()}; + } +}; + +struct FloorDecimalOneArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function + // So we can just register Decimal32 here + return {std::make_shared>(9, 0)}; + } +}; + +struct FloorDecimalTwoArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0), + std::make_shared()}; + } +}; + +template +class FunctionFloor : public FunctionRounding { +public: + static FunctionPtr create() { return std::make_shared(); } + + ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } + // SELECT number, floor(123.345, 1) FROM numbers("number"="10") + // should NOT behave like two column arguments, so we can not use const column default implementation + bool use_default_implementation_for_constants() const override { return false; } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + ColumnPtr res; + + // potential argument types: + // 0. floor(ColumnConst, ColumnConst) + // 1. floor(Column), floor(Column, ColumnConst) + // 2. floor(Column, Column) + // 3. floor(ColumnConst, Column) + + if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && + is_column_const(*block.get_by_position(arguments[1]).column)) { + // floor(ColumnConst, ColumnConst) + auto col_general = + assert_cast(*column_general.column).get_data_column_ptr(); + Int16 scale_arg = 0; + RETURN_IF_ERROR(FunctionFloor::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_vec_const(col_general, + scale_arg); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "floor"); + } + // Important, make sure the result column has the same size as the input column + res = ColumnConst::create(std::move(res), input_rows_count); + } else if (arguments.size() == 1 || + (arguments.size() == 2 && + is_column_const(*block.get_by_position(arguments[1]).column))) { + // floor(Column) or floor(Column, ColumnConst) + Int16 scale_arg = 0; + if (arguments.size() == 2) { + RETURN_IF_ERROR(FunctionFloor::get_scale_arg( + block.get_by_position(arguments[1]), &scale_arg)); + } + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_vec_const(column_general.column.get(), scale_arg); + return true; + } + + return false; + }; +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "floor"); + } + + } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { + // floor(ColumnConst, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + const ColumnConst& const_col_general = + assert_cast(*column_general.column); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_const_vec(&const_col_general, column_scale.column.get()); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "floor"); + } + } else { + // floor(Column, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher:: + apply_vec_vec(column_general.column.get(), column_scale.column.get()); + return true; + } + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), "floor"); + } + } + + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index c0dfe7615764ba..35c567b1a7185d 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -42,6 +42,7 @@ #include "vec/data_types/number_traits.h" #include "vec/functions/function_binary_arithmetic.h" #include "vec/functions/function_const.h" +#include "vec/functions/function_floor.h" #include "vec/functions/function_math_log.h" #include "vec/functions/function_math_unary.h" #include "vec/functions/function_string.h" @@ -396,8 +397,6 @@ void register_function_math(SimpleFunctionFactory& factory) { #define REGISTER_ROUND_FUNCTIONS(IMPL) \ factory.register_function< \ FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ factory.register_function< \ FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ factory.register_function, RoundingMode::Round, \ @@ -448,5 +447,9 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function>(); factory.register_function>(); factory.register_function>(); + factory.register_function>(); + factory.register_function>(); + factory.register_function>(); + factory.register_function>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index a9d1e7a019c0a4..b3021cf18e67c0 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -483,10 +483,10 @@ struct Dispatcher { // NOTE: This function is only tested for truncate // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) { - if constexpr (rounding_mode != RoundingMode::Trunc) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Using column as scale is only supported for function truncate"); - } + // if constexpr (rounding_mode != RoundingMode::Trunc) { + // throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + // "Using column as scale is only supported for function truncate"); + // } const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); const size_t input_row_count = col_scale_i32.size(); @@ -568,10 +568,10 @@ struct Dispatcher { // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! only test for truncate static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale) { - if constexpr (rounding_mode != RoundingMode::Trunc) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Using column as scale is only supported for function truncate"); - } + // if constexpr (rounding_mode != RoundingMode::Trunc) { + // throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + // "Using column as scale is only supported for function truncate"); + // } const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); const size_t input_rows_count = col_scale->size(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 47182ebd6d59f5..34cd8c7b120b7d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -264,10 +264,10 @@ public class FunctionCallExpr extends Expr { PRECISION_INFER_RULE.put("round", roundRule); PRECISION_INFER_RULE.put("round_bankers", roundRule); PRECISION_INFER_RULE.put("ceil", roundRule); - PRECISION_INFER_RULE.put("floor", roundRule); + PRECISION_INFER_RULE.put("floor", truncateRule); PRECISION_INFER_RULE.put("dround", roundRule); PRECISION_INFER_RULE.put("dceil", roundRule); - PRECISION_INFER_RULE.put("dfloor", roundRule); + PRECISION_INFER_RULE.put("dfloor", truncateRule); PRECISION_INFER_RULE.put("truncate", truncateRule); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index 6b6308c516ce58..40bf83c7406861 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -20,6 +20,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor; import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DecimalV3Type; @@ -40,7 +41,7 @@ default FunctionSignature computePrecision(FunctionSignature signature) { Expression floatLength = getArgument(1); int scale; - if (this instanceof Truncate) { + if (this instanceof Truncate || this instanceof Floor) { if (floatLength.isLiteral() || ( floatLength instanceof Cast && floatLength.child(0).isLiteral() && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { From b20d635ab6dd8c3fe08b0af7901090bf4307b21d Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sat, 27 Apr 2024 17:18:49 +0800 Subject: [PATCH 2/7] add be code --- be/src/vec/functions/function_floor.h | 245 ---- be/src/vec/functions/function_truncate.h | 245 ---- be/src/vec/functions/math.cpp | 84 -- be/src/vec/functions/round.cpp | 65 + be/src/vec/functions/round.h | 320 ++++- .../vec/functions/simple_function_factory.h | 2 + be/test/vec/function/function_round_test.cpp | 1146 +++++++++++++++++ .../function_truncate_decimal_test.cpp | 370 ------ 8 files changed, 1487 insertions(+), 990 deletions(-) delete mode 100644 be/src/vec/functions/function_floor.h delete mode 100644 be/src/vec/functions/function_truncate.h create mode 100644 be/src/vec/functions/round.cpp create mode 100644 be/test/vec/function/function_round_test.cpp delete mode 100644 be/test/vec/function/function_truncate_decimal_test.cpp diff --git a/be/src/vec/functions/function_floor.h b/be/src/vec/functions/function_floor.h deleted file mode 100644 index d4616ae959b10a..00000000000000 --- a/be/src/vec/functions/function_floor.h +++ /dev/null @@ -1,245 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include - -#include "common/exception.h" -#include "common/status.h" -#include "olap/olap_common.h" -#include "round.h" -#include "vec/columns/column.h" -#include "vec/columns/column_const.h" -#include "vec/columns/column_decimal.h" -#include "vec/columns/column_vector.h" -#include "vec/common/assert_cast.h" -#include "vec/core/call_on_type_index.h" -#include "vec/core/field.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_number.h" - -namespace doris::vectorized { - -struct FloorFloatOneArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } -}; - -struct FloorFloatTwoArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared(), std::make_shared()}; - } -}; - -struct FloorDecimalOneArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function - // So we can just register Decimal32 here - return {std::make_shared>(9, 0)}; - } -}; - -struct FloorDecimalTwoArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0), - std::make_shared()}; - } -}; - -template -class FunctionFloor : public FunctionRounding { -public: - static FunctionPtr create() { return std::make_shared(); } - - ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } - // SELECT number, floor(123.345, 1) FROM numbers("number"="10") - // should NOT behave like two column arguments, so we can not use const column default implementation - bool use_default_implementation_for_constants() const override { return false; } - - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - size_t result, size_t input_rows_count) const override { - const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); - ColumnPtr res; - - // potential argument types: - // 0. floor(ColumnConst, ColumnConst) - // 1. floor(Column), floor(Column, ColumnConst) - // 2. floor(Column, Column) - // 3. floor(ColumnConst, Column) - - if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && - is_column_const(*block.get_by_position(arguments[1]).column)) { - // floor(ColumnConst, ColumnConst) - auto col_general = - assert_cast(*column_general.column).get_data_column_ptr(); - Int16 scale_arg = 0; - RETURN_IF_ERROR(FunctionFloor::get_scale_arg( - block.get_by_position(arguments[1]), &scale_arg)); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_vec_const(col_general, - scale_arg); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "floor"); - } - // Important, make sure the result column has the same size as the input column - res = ColumnConst::create(std::move(res), input_rows_count); - } else if (arguments.size() == 1 || - (arguments.size() == 2 && - is_column_const(*block.get_by_position(arguments[1]).column))) { - // floor(Column) or floor(Column, ColumnConst) - Int16 scale_arg = 0; - if (arguments.size() == 2) { - RETURN_IF_ERROR(FunctionFloor::get_scale_arg( - block.get_by_position(arguments[1]), &scale_arg)); - } - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_vec_const(column_general.column.get(), scale_arg); - return true; - } - - return false; - }; -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "floor"); - } - - } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { - // floor(ColumnConst, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - const ColumnConst& const_col_general = - assert_cast(*column_general.column); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_const_vec(&const_col_general, column_scale.column.get()); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "floor"); - } - } else { - // floor(Column, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_vec_vec(column_general.column.get(), column_scale.column.get()); - return true; - } - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "floor"); - } - } - - block.replace_by_position(result, std::move(res)); - return Status::OK(); - } -}; - -} // namespace doris::vectorized diff --git a/be/src/vec/functions/function_truncate.h b/be/src/vec/functions/function_truncate.h deleted file mode 100644 index e29bc99c0417dc..00000000000000 --- a/be/src/vec/functions/function_truncate.h +++ /dev/null @@ -1,245 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include - -#include "common/exception.h" -#include "common/status.h" -#include "olap/olap_common.h" -#include "round.h" -#include "vec/columns/column.h" -#include "vec/columns/column_const.h" -#include "vec/columns/column_decimal.h" -#include "vec/columns/column_vector.h" -#include "vec/common/assert_cast.h" -#include "vec/core/call_on_type_index.h" -#include "vec/core/field.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_number.h" - -namespace doris::vectorized { - -struct TruncateFloatOneArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } -}; - -struct TruncateFloatTwoArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared(), std::make_shared()}; - } -}; - -struct TruncateDecimalOneArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function - // So we can just register Decimal32 here - return {std::make_shared>(9, 0)}; - } -}; - -struct TruncateDecimalTwoArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0), - std::make_shared()}; - } -}; - -template -class FunctionTruncate : public FunctionRounding { -public: - static FunctionPtr create() { return std::make_shared(); } - - ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } - // SELECT number, truncate(123.345, 1) FROM number("numbers"="10") - // should NOT behave like two column arguments, so we can not use const column default implementation - bool use_default_implementation_for_constants() const override { return false; } - - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - size_t result, size_t input_rows_count) const override { - const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); - ColumnPtr res; - - // potential argument types: - // 0. truncate(ColumnConst, ColumnConst) - // 1. truncate(Column), truncate(Column, ColumnConst) - // 2. truncate(Column, Column) - // 3. truncate(ColumnConst, Column) - - if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && - is_column_const(*block.get_by_position(arguments[1]).column)) { - // truncate(ColumnConst, ColumnConst) - auto col_general = - assert_cast(*column_general.column).get_data_column_ptr(); - Int16 scale_arg = 0; - RETURN_IF_ERROR(FunctionTruncate::get_scale_arg( - block.get_by_position(arguments[1]), &scale_arg)); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_vec_const(col_general, - scale_arg); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "truncate"); - } - // Important, make sure the result column has the same size as the input column - res = ColumnConst::create(std::move(res), input_rows_count); - } else if (arguments.size() == 1 || - (arguments.size() == 2 && - is_column_const(*block.get_by_position(arguments[1]).column))) { - // truncate(Column) or truncate(Column, ColumnConst) - Int16 scale_arg = 0; - if (arguments.size() == 2) { - RETURN_IF_ERROR(FunctionTruncate::get_scale_arg( - block.get_by_position(arguments[1]), &scale_arg)); - } - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_vec_const(column_general.column.get(), scale_arg); - return true; - } - - return false; - }; -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "truncate"); - } - - } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { - // truncate(ColumnConst, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - const ColumnConst& const_col_general = - assert_cast(*column_general.column); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_const_vec(&const_col_general, column_scale.column.get()); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "truncate"); - } - } else { - // truncate(Column, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher:: - apply_vec_vec(column_general.column.get(), column_scale.column.get()); - return true; - } - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), "truncate"); - } - } - - block.replace_by_position(result, std::move(res)); - return Status::OK(); - } -}; - -} // namespace doris::vectorized diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index 35c567b1a7185d..3596d2ebae6fb0 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -42,14 +42,11 @@ #include "vec/data_types/number_traits.h" #include "vec/functions/function_binary_arithmetic.h" #include "vec/functions/function_const.h" -#include "vec/functions/function_floor.h" #include "vec/functions/function_math_log.h" #include "vec/functions/function_math_unary.h" #include "vec/functions/function_string.h" #include "vec/functions/function_totype.h" -#include "vec/functions/function_truncate.h" #include "vec/functions/function_unary_arithmetic.h" -#include "vec/functions/round.h" #include "vec/functions/simple_function_factory.h" namespace doris { @@ -331,82 +328,9 @@ struct PowName { }; using FunctionPow = FunctionBinaryArithmetic; -struct TruncateName { - static constexpr auto name = "truncate"; -}; - -struct CeilName { - static constexpr auto name = "ceil"; -}; - -struct FloorName { - static constexpr auto name = "floor"; -}; - -struct RoundName { - static constexpr auto name = "round"; -}; - -struct RoundBankersName { - static constexpr auto name = "round_bankers"; -}; - -/// round(double,int32)-->double -/// key_str:roundFloat64Int32 -template -struct DoubleRoundTwoImpl { - static constexpr auto name = Name::name; - - static DataTypes get_variadic_argument_types() { - return {std::make_shared(), - std::make_shared()}; - } -}; - -template -struct DoubleRoundOneImpl { - static constexpr auto name = Name::name; - - static DataTypes get_variadic_argument_types() { - return {std::make_shared()}; - } -}; - -template -struct DecimalRoundTwoImpl { - static constexpr auto name = Name::name; - - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0), - std::make_shared()}; - } -}; - -template -struct DecimalRoundOneImpl { - static constexpr auto name = Name::name; - - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0)}; - } -}; - // TODO: Now math may cause one thread compile time too long, because the function in math // so mush. Split it to speed up compile time in the future void register_function_math(SimpleFunctionFactory& factory) { -#define REGISTER_ROUND_FUNCTIONS(IMPL) \ - factory.register_function< \ - FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ - factory.register_function, RoundingMode::Round, \ - TieBreakingMode::Bankers>>(); - - REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl) - REGISTER_ROUND_FUNCTIONS(DecimalRoundTwoImpl) - REGISTER_ROUND_FUNCTIONS(DoubleRoundOneImpl) - REGISTER_ROUND_FUNCTIONS(DoubleRoundTwoImpl) - factory.register_alias("round", "dround"); factory.register_function(); factory.register_function(); factory.register_function(); @@ -443,13 +367,5 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); - factory.register_function>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/round.cpp b/be/src/vec/functions/round.cpp new file mode 100644 index 00000000000000..6b504839729c63 --- /dev/null +++ b/be/src/vec/functions/round.cpp @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/round.h" + +#include +#include +#include +#include + +#include "common/exception.h" +#include "common/status.h" +#include "olap/olap_common.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" +#include "vec/core/call_on_type_index.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/simple_function_factory.h" + +namespace doris::vectorized { + +// We split round funcs from register_function_math() in math.cpp to here, +// so that to speed up compile time and make code more readable. +void register_function_round(SimpleFunctionFactory& factory) { +#define REGISTER_ROUND_FUNCTIONS(IMPL) \ + factory.register_function, RoundingMode::Trunc, \ + TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ + factory.register_function, RoundingMode::Round, \ + TieBreakingMode::Bankers>>(); + + REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl) + REGISTER_ROUND_FUNCTIONS(DecimalRoundTwoImpl) + REGISTER_ROUND_FUNCTIONS(DoubleRoundOneImpl) + REGISTER_ROUND_FUNCTIONS(DoubleRoundTwoImpl) + factory.register_alias("round", "dround"); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index b3021cf18e67c0..417b88d3e6f9fc 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -184,8 +184,6 @@ class DecimalRoundingImpl { } } - // NOTE: This function is only tested for truncate - // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! static NO_INLINE void apply(const NativeType& in, UInt32 in_scale, NativeType& out, Int16 out_scale) { Int16 scale_arg = in_scale - out_scale; @@ -480,15 +478,8 @@ struct Dispatcher { } } - // NOTE: This function is only tested for truncate - // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn* col_scale) { - // if constexpr (rounding_mode != RoundingMode::Trunc) { - // throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - // "Using column as scale is only supported for function truncate"); - // } - - const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); + const auto& col_scale_i32 = assert_cast(*col_scale); const size_t input_row_count = col_scale_i32.size(); for (size_t i = 0; i < input_row_count; ++i) { const Int32 scale_arg = col_scale_i32.get_data()[i]; @@ -537,7 +528,7 @@ struct Dispatcher { } for (size_t i = 0; i < input_row_count; ++i) { - // For truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // For func(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column // So we need this check to make sure the result have correct digits count // // Case 0: scale_arg <= -(integer part digits count) @@ -564,15 +555,8 @@ struct Dispatcher { } } - // NOTE: This function is only tested for truncate - // DO NOT USE THIS METHOD FOR OTHER ROUNDING BASED FUNCTION UNTIL YOU KNOW EXACTLY WHAT YOU ARE DOING !!! only test for truncate static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale) { - // if constexpr (rounding_mode != RoundingMode::Trunc) { - // throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - // "Using column as scale is only supported for function truncate"); - // } - const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); const size_t input_rows_count = col_scale->size(); @@ -602,7 +586,7 @@ struct Dispatcher { } for (size_t i = 0; i < input_rows_count; ++i) { - // For truncate(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column + // For func(ColumnDecimal, ColumnInt32), we should always have same scale with source Decimal column // So we need this check to make sure the result have correct digits count // // Case 0: scale_arg <= -(integer part digits count) @@ -696,44 +680,170 @@ class FunctionRounding : public IFunction { return Status::OK(); } - ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; } + ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } + // SELECT number, func(123.345, 1) FROM numbers("number"="10") + // should NOT behave like two column arguments, so we can not use const column default implementation + bool use_default_implementation_for_constants() const override { return false; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - size_t result, size_t /*input_rows_count*/) const override { - const ColumnWithTypeAndName& column = block.get_by_position(arguments[0]); - Int16 scale_arg = 0; - if (arguments.size() == 2) { + size_t result, size_t input_rows_count) const override { + const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + ColumnPtr res; + + // potential argument types: + // 0. func(ColumnConst, ColumnConst) + // 1. func(Column), func(ColumnConst), func(Column, ColumnConst) + // 2. func(Column, Column) + // 3. func(ColumnConst, Column) + + if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && + is_column_const(*block.get_by_position(arguments[1]).column)) { + // func(ColumnConst, ColumnConst) + auto col_general = + assert_cast(*column_general.column).get_data_column_ptr(); + Int16 scale_arg = 0; RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); - } - ColumnPtr res; - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_vec_const( - column.column.get(), scale_arg); - return true; + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_vec_const( + col_general, scale_arg); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), get_name()); + } + // Important, make sure the result column has the same size as the input column + res = ColumnConst::create(std::move(res), input_rows_count); + } else if (arguments.size() == 1 || + (arguments.size() == 2 && + is_column_const(*block.get_by_position(arguments[1]).column))) { + // func(Column) or func(ColumnConst) or func(Column, ColumnConst) + Int16 scale_arg = 0; + const auto* col_general = column_general.column.get(); + if (arguments.size() == 2) { + RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); + } else if (is_column_const(*column_general.column)) { + // if we only have one ColumnConst + col_general = assert_cast(*column_general.column) + .get_data_column_ptr(); } - return false; - }; + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_vec_const( + col_general, scale_arg); + return true; + } + + return false; + }; #if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), get_name()); + } + + } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { + // func(ColumnConst, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + const auto& const_col_general = assert_cast(*column_general.column); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_const_vec( + &const_col_general, column_scale.column.get()); + return true; + } + + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } } - } #endif - if (!call_on_index_and_data_type(column.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column.type->get_name(), name); + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), get_name()); + } + } else { + // func(Column, Column) + const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + res = Dispatcher::apply_vec_vec( + column_general.column.get(), column_scale.column.get()); + return true; + } + return false; + }; + +#if !defined(__SSE4_1__) && !defined(__aarch64__) + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); + } + } +#endif + + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), get_name()); + } } block.replace_by_position(result, std::move(res)); @@ -741,4 +851,122 @@ class FunctionRounding : public IFunction { } }; +struct TruncateFloatOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } +}; + +struct TruncateFloatTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared()}; + } +}; + +struct TruncateDecimalOneArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function + // So we can just register Decimal32 here + return {std::make_shared>(9, 0)}; + } +}; + +struct TruncateDecimalTwoArgImpl { + static constexpr auto name = "truncate"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0), + std::make_shared()}; + } +}; + +struct FloorFloatOneArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } +}; + +struct FloorFloatTwoArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared()}; + } +}; + +struct FloorDecimalOneArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function + // So we can just register Decimal32 here + return {std::make_shared>(9, 0)}; + } +}; + +struct FloorDecimalTwoArgImpl { + static constexpr auto name = "floor"; + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0), + std::make_shared()}; + } +}; + +struct TruncateName { + static constexpr auto name = "truncate"; +}; + +struct FloorName { + static constexpr auto name = "floor"; +}; + +struct CeilName { + static constexpr auto name = "ceil"; +}; + +struct RoundName { + static constexpr auto name = "round"; +}; + +struct RoundBankersName { + static constexpr auto name = "round_bankers"; +}; + +/// round(double,int32)-->double +/// key_str:roundFloat64Int32 +template +struct DoubleRoundTwoImpl { + static constexpr auto name = Name::name; + + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), + std::make_shared()}; + } +}; + +template +struct DoubleRoundOneImpl { + static constexpr auto name = Name::name; + + static DataTypes get_variadic_argument_types() { + return {std::make_shared()}; + } +}; + +template +struct DecimalRoundTwoImpl { + static constexpr auto name = Name::name; + + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0), + std::make_shared()}; + } +}; + +template +struct DecimalRoundOneImpl { + static constexpr auto name = Name::name; + + static DataTypes get_variadic_argument_types() { + return {std::make_shared>(9, 0)}; + } +}; + } // namespace doris::vectorized diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index 052e2e89134dc9..889a97436355c9 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -46,6 +46,7 @@ void register_function_int_div(SimpleFunctionFactory& factory); void register_function_bit(SimpleFunctionFactory& factory); void register_function_bit_count(SimpleFunctionFactory& factory); void register_function_bit_shift(SimpleFunctionFactory& factory); +void register_function_round(SimpleFunctionFactory& factory); void register_function_math(SimpleFunctionFactory& factory); void register_function_modulo(SimpleFunctionFactory& factory); void register_function_bitmap(SimpleFunctionFactory& factory); @@ -226,6 +227,7 @@ class SimpleFunctionFactory { register_function_conv(instance); register_function_plus(instance); register_function_minus(instance); + register_function_round(instance); register_function_math(instance); register_function_multiply(instance); register_function_divide(instance); diff --git a/be/test/vec/function/function_round_test.cpp b/be/test/vec/function/function_round_test.cpp new file mode 100644 index 00000000000000..b162000d7baa90 --- /dev/null +++ b/be/test/vec/function/function_round_test.cpp @@ -0,0 +1,1146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "function_test_util.h" +#include "vec/columns/column.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/columns_number.h" +#include "vec/common/assert_cast.h" +#include "vec/core/column_numbers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/round.h" + +namespace doris::vectorized { +// {precision, scale} -> {input, scale_arg, expectation} +using DecimalTestDataSet = + std::map, std::vector>>; + +// input, scale_arg, expectation +using FloatTestDataSet = std::vector>; + +using TruncateFunction = FunctionRounding, RoundingMode::Trunc, + TieBreakingMode::Auto>; + +using FloorFunction = + FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>; + +using CeilFunction = + FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>; + +using RoundFunction = + FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>; + +using RoundBankersFunction = FunctionRounding, + RoundingMode::Round, TieBreakingMode::Bankers>; + +// test cases for truncate and floor function of decimal32 +const static DecimalTestDataSet trunc_floor_decimal32_cases = { + {{1, 0}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{1, 1}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{2, 0}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 10}, + {12, 0, 12}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 1}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 10}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 2}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 0}, + {12, 1, 10}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{9, 0}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 100000000}, + {123456789, -7, 120000000}, {123456789, -6, 123000000}, {123456789, -5, 123400000}, + {123456789, -4, 123450000}, {123456789, -3, 123456000}, {123456789, -2, 123456700}, + {123456789, -1, 123456780}, {123456789, 0, 123456789}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 1}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 100000000}, {123456789, -6, 120000000}, {123456789, -5, 123000000}, + {123456789, -4, 123400000}, {123456789, -3, 123450000}, {123456789, -2, 123456000}, + {123456789, -1, 123456700}, {123456789, 0, 123456780}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 2}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 100000000}, {123456789, -5, 120000000}, + {123456789, -4, 123000000}, {123456789, -3, 123400000}, {123456789, -2, 123450000}, + {123456789, -1, 123456000}, {123456789, 0, 123456700}, {123456789, 1, 123456780}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 3}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 100000000}, + {123456789, -4, 120000000}, {123456789, -3, 123000000}, {123456789, -2, 123400000}, + {123456789, -1, 123450000}, {123456789, 0, 123456000}, {123456789, 1, 123456700}, + {123456789, 2, 123456780}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 4}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 100000000}, {123456789, -3, 120000000}, {123456789, -2, 123000000}, + {123456789, -1, 123400000}, {123456789, 0, 123450000}, {123456789, 1, 123456000}, + {123456789, 2, 123456700}, {123456789, 3, 123456780}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 5}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 100000000}, {123456789, -2, 120000000}, + {123456789, -1, 123000000}, {123456789, 0, 123400000}, {123456789, 1, 123450000}, + {123456789, 2, 123456000}, {123456789, 3, 123456700}, {123456789, 4, 123456780}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 6}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 100000000}, + {123456789, -1, 120000000}, {123456789, 0, 123000000}, {123456789, 1, 123400000}, + {123456789, 2, 123450000}, {123456789, 3, 123456000}, {123456789, 4, 123456700}, + {123456789, 5, 123456780}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 7}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 100000000}, {123456789, 0, 120000000}, {123456789, 1, 123000000}, + {123456789, 2, 123400000}, {123456789, 3, 123450000}, {123456789, 4, 123456000}, + {123456789, 5, 123456700}, {123456789, 6, 123456780}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 8}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 100000000}, {123456789, 1, 120000000}, + {123456789, 2, 123000000}, {123456789, 3, 123400000}, {123456789, 4, 123450000}, + {123456789, 5, 123456000}, {123456789, 6, 123456700}, {123456789, 7, 123456780}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 9}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 0}, {123456789, 1, 100000000}, + {123456789, 2, 120000000}, {123456789, 3, 123000000}, {123456789, 4, 123400000}, + {123456789, 5, 123450000}, {123456789, 6, 123456000}, {123456789, 7, 123456700}, + {123456789, 8, 123456780}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}}; + +// test cases for truncate and floor function of decimal64 +const static DecimalTestDataSet trunc_floor_decimal64_cases = { + {{10, 0}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 1000000000}, + {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, {1234567891, -6, 1234000000}, + {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000}, {1234567891, -3, 1234567000}, + {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890}, {1234567891, 0, 1234567891}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 1}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, {1234567891, -6, 1230000000}, + {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000}, {1234567891, -3, 1234560000}, + {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800}, {1234567891, 0, 1234567890}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891} + + }}, + {{10, 2}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 1000000000}, {1234567891, -6, 1200000000}, + {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000}, {1234567891, -3, 1234500000}, + {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000}, {1234567891, 0, 1234567800}, + {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 9}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 0}, {1234567891, -6, 0}, + {1234567891, -5, 0}, {1234567891, -4, 0}, {1234567891, -3, 0}, + {1234567891, -2, 0}, {1234567891, -1, 0}, {1234567891, 0, 1000000000}, + {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000}, {1234567891, 3, 1234000000}, + {1234567891, 4, 1234500000}, {1234567891, 5, 1234560000}, {1234567891, 6, 1234567000}, + {1234567891, 7, 1234567800}, {1234567891, 8, 1234567890}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{18, 0}, + {{123456789123456789, -19, 0}, + {123456789123456789, -18, 0}, + {123456789123456789, -17, 100000000000000000}, + {123456789123456789, -16, 120000000000000000}, + {123456789123456789, -15, 123000000000000000}, + {123456789123456789, -14, 123400000000000000}, + {123456789123456789, -13, 123450000000000000}, + {123456789123456789, -12, 123456000000000000}, + {123456789123456789, -11, 123456700000000000}, + {123456789123456789, -10, 123456780000000000}, + {123456789123456789, -9, 123456789000000000}, + {123456789123456789, -8, 123456789100000000}, + {123456789123456789, -7, 123456789120000000}, + {123456789123456789, -6, 123456789123000000}, + {123456789123456789, -5, 123456789123400000}, + {123456789123456789, -4, 123456789123450000}, + {123456789123456789, -3, 123456789123456000}, + {123456789123456789, -2, 123456789123456700}, + {123456789123456789, -1, 123456789123456780}, + {123456789123456789, 0, 123456789123456789}, + {123456789123456789, 1, 123456789123456789}, + {123456789123456789, 2, 123456789123456789}, + {123456789123456789, 3, 123456789123456789}, + {123456789123456789, 4, 123456789123456789}, + {123456789123456789, 5, 123456789123456789}, + {123456789123456789, 6, 123456789123456789}, + {123456789123456789, 7, 123456789123456789}, + {123456789123456789, 8, 123456789123456789}, + {123456789123456789, 18, 123456789123456789}}}, + {{18, 18}, + {{123456789123456789, -1, 0}, + {123456789123456789, 0, 0}, + {123456789123456789, 1, 100000000000000000}, + {123456789123456789, 2, 120000000000000000}, + {123456789123456789, 3, 123000000000000000}, + {123456789123456789, 4, 123400000000000000}, + {123456789123456789, 5, 123450000000000000}, + {123456789123456789, 6, 123456000000000000}, + {123456789123456789, 7, 123456700000000000}, + {123456789123456789, 8, 123456780000000000}, + {123456789123456789, 9, 123456789000000000}, + {123456789123456789, 10, 123456789100000000}, + {123456789123456789, 11, 123456789120000000}, + {123456789123456789, 12, 123456789123000000}, + {123456789123456789, 13, 123456789123400000}, + {123456789123456789, 14, 123456789123450000}, + {123456789123456789, 15, 123456789123456000}, + {123456789123456789, 16, 123456789123456700}, + {123456789123456789, 17, 123456789123456780}, + {123456789123456789, 18, 123456789123456789}, + {123456789123456789, 19, 123456789123456789}, + {123456789123456789, 20, 123456789123456789}, + {123456789123456789, 21, 123456789123456789}, + {123456789123456789, 22, 123456789123456789}, + {123456789123456789, 23, 123456789123456789}, + {123456789123456789, 24, 123456789123456789}, + {123456789123456789, 25, 123456789123456789}, + {123456789123456789, 26, 123456789123456789}}}}; + +const static DecimalTestDataSet ceil_decimal32_cases = { + {{1, 0}, + { + {1, -10, 0}, {1, -9, 1000000000}, {1, -8, 100000000}, {1, -7, 10000000}, + {1, -6, 1000000}, {1, -5, 100000}, {1, -4, 10000}, {1, -3, 1000}, + {1, -2, 100}, {1, -1, 10}, {1, 0, 1}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, + {1, 6, 1}, {1, 7, 1}, {1, 8, 1}, {1, 9, 1}, + {1, 10, 1}, + }}, + {{1, 1}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 1000000000}, {1, -7, 100000000}, + {1, -6, 10000000}, {1, -5, 1000000}, {1, -4, 100000}, {1, -3, 10000}, + {1, -2, 1000}, {1, -1, 100}, {1, 0, 10}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, + {1, 6, 1}, {1, 7, 1}, {1, 8, 1}, {1, 9, 1}, + {1, 10, 1}, + }}, + {{2, 0}, + { + {12, -4, 10000}, + {12, -3, 1000}, + {12, -2, 100}, + {12, -1, 20}, + {12, 0, 12}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 1}, + { + {12, -4, 100000}, + {12, -3, 10000}, + {12, -2, 1000}, + {12, -1, 100}, + {12, 0, 20}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 2}, + { + {12, -4, 1000000}, + {12, -3, 100000}, + {12, -2, 10000}, + {12, -1, 1000}, + {12, 0, 100}, + {12, 1, 20}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{9, 0}, + { + {123456789, -10, 0}, {123456789, -9, 1000000000}, + {123456789, -8, 200000000}, {123456789, -7, 130000000}, + {123456789, -6, 124000000}, {123456789, -5, 123500000}, + {123456789, -4, 123460000}, {123456789, -3, 123457000}, + {123456789, -2, 123456800}, {123456789, -1, 123456790}, + {123456789, 0, 123456789}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, + {123456789, 4, 123456789}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 1}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 1000000000}, {123456789, -7, 200000000}, + {123456789, -6, 130000000}, {123456789, -5, 124000000}, + {123456789, -4, 123500000}, {123456789, -3, 123460000}, + {123456789, -2, 123457000}, {123456789, -1, 123456800}, + {123456789, 0, 123456790}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, + {123456789, 4, 123456789}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 2}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 0}, {123456789, -7, 1000000000}, + {123456789, -6, 200000000}, {123456789, -5, 130000000}, + {123456789, -4, 124000000}, {123456789, -3, 123500000}, + {123456789, -2, 123460000}, {123456789, -1, 123457000}, + {123456789, 0, 123456800}, {123456789, 1, 123456790}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, + {123456789, 4, 123456789}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 3}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 0}, {123456789, -7, 0}, + {123456789, -6, 1000000000}, {123456789, -5, 200000000}, + {123456789, -4, 130000000}, {123456789, -3, 124000000}, + {123456789, -2, 123500000}, {123456789, -1, 123460000}, + {123456789, 0, 123457000}, {123456789, 1, 123456800}, + {123456789, 2, 123456790}, {123456789, 3, 123456789}, + {123456789, 4, 123456789}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 4}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 0}, {123456789, -7, 0}, + {123456789, -6, 0}, {123456789, -5, 1000000000}, + {123456789, -4, 200000000}, {123456789, -3, 130000000}, + {123456789, -2, 124000000}, {123456789, -1, 123500000}, + {123456789, 0, 123460000}, {123456789, 1, 123457000}, + {123456789, 2, 123456800}, {123456789, 3, 123456790}, + {123456789, 4, 123456789}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 5}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 0}, {123456789, -7, 0}, + {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 1000000000}, {123456789, -3, 200000000}, + {123456789, -2, 130000000}, {123456789, -1, 124000000}, + {123456789, 0, 123500000}, {123456789, 1, 123460000}, + {123456789, 2, 123457000}, {123456789, 3, 123456800}, + {123456789, 4, 123456790}, {123456789, 5, 123456789}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 6}, + { + {123456789, -10, 0}, {123456789, -9, 0}, + {123456789, -8, 0}, {123456789, -7, 0}, + {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 1000000000}, + {123456789, -2, 200000000}, {123456789, -1, 130000000}, + {123456789, 0, 124000000}, {123456789, 1, 123500000}, + {123456789, 2, 123460000}, {123456789, 3, 123457000}, + {123456789, 4, 123456800}, {123456789, 5, 123456790}, + {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, + {123456789, 10, 123456789}, + }}, + {{9, 7}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 1000000000}, + {123456789, -1, 200000000}, {123456789, 0, 130000000}, {123456789, 1, 124000000}, + {123456789, 2, 123500000}, {123456789, 3, 123460000}, {123456789, 4, 123457000}, + {123456789, 5, 123456800}, {123456789, 6, 123456790}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 8}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 1000000000}, {123456789, 0, 200000000}, {123456789, 1, 130000000}, + {123456789, 2, 124000000}, {123456789, 3, 123500000}, {123456789, 4, 123460000}, + {123456789, 5, 123457000}, {123456789, 6, 123456800}, {123456789, 7, 123456790}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 9}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 1000000000}, {123456789, 1, 200000000}, + {123456789, 2, 130000000}, {123456789, 3, 124000000}, {123456789, 4, 123500000}, + {123456789, 5, 123460000}, {123456789, 6, 123457000}, {123456789, 7, 123456800}, + {123456789, 8, 123456790}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}}; + +// test cases for ceil function of decimal64 +const static DecimalTestDataSet ceil_decimal64_cases = { + {{10, 0}, {{1234567891, -11, 100000000000}, {1234567891, -10, 10000000000}, + {1234567891, -9, 2000000000}, {1234567891, -8, 1300000000}, + {1234567891, -7, 1240000000}, {1234567891, -6, 1235000000}, + {1234567891, -5, 1234600000}, {1234567891, -4, 1234570000}, + {1234567891, -3, 1234568000}, {1234567891, -2, 1234567900}, + {1234567891, -1, 1234567900}, {1234567891, 0, 1234567891}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, + {1234567891, 3, 1234567891}, {1234567891, 4, 1234567891}, + {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, + {1234567891, 9, 1234567891}, {1234567891, 10, 1234567891}, + {1234567891, 11, 1234567891}}}, + {{10, 1}, {{1234567891, -11, 1000000000000}, {1234567891, -10, 100000000000}, + {1234567891, -9, 10000000000}, {1234567891, -8, 2000000000}, + {1234567891, -7, 1300000000}, {1234567891, -6, 1240000000}, + {1234567891, -5, 1235000000}, {1234567891, -4, 1234600000}, + {1234567891, -3, 1234570000}, {1234567891, -2, 1234568000}, + {1234567891, -1, 1234567900}, {1234567891, 0, 1234567900}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, + {1234567891, 3, 1234567891}, {1234567891, 4, 1234567891}, + {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, + {1234567891, 9, 1234567891}, {1234567891, 10, 1234567891}, + {1234567891, 11, 1234567891}}}, + {{10, 2}, {{1234567891, -11, 10000000000000}, {1234567891, -10, 1000000000000}, + {1234567891, -9, 100000000000}, {1234567891, -8, 10000000000}, + {1234567891, -7, 2000000000}, {1234567891, -6, 1300000000}, + {1234567891, -5, 1240000000}, {1234567891, -4, 1235000000}, + {1234567891, -3, 1234600000}, {1234567891, -2, 1234570000}, + {1234567891, -1, 1234568000}, {1234567891, 0, 1234567900}, + {1234567891, 1, 1234567900}, {1234567891, 2, 1234567891}, + {1234567891, 3, 1234567891}, {1234567891, 4, 1234567891}, + {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, + {1234567891, 9, 1234567891}, {1234567891, 10, 1234567891}, + {1234567891, 11, 1234567891}}}, + {{10, 9}, + {{1234567891, -11, 0}, + {1234567891, -10, 0}, + {1234567891, -9, 1000000000000000000}, + {1234567891, -8, 100000000000000000}, + {1234567891, -7, 10000000000000000}, + {1234567891, -6, 1000000000000000}, + {1234567891, -5, 100000000000000}, + {1234567891, -4, 10000000000000}, + {1234567891, -3, 1000000000000}, + {1234567891, -2, 100000000000}, + {1234567891, -1, 10000000000}, + {1234567891, 0, 2000000000}, + {1234567891, 1, 1300000000}, + {1234567891, 2, 1240000000}, + {1234567891, 3, 1235000000}, + {1234567891, 4, 1234600000}, + {1234567891, 5, 1234570000}, + {1234567891, 6, 1234568000}, + {1234567891, 7, 1234567900}, + {1234567891, 8, 1234567900}, + {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, + {1234567891, 11, 1234567891}}}, + {{18, 0}, + {{123456789123456789, -19, 0}, + {123456789123456789, -18, 1000000000000000000}, + {123456789123456789, -17, 200000000000000000}, + {123456789123456789, -16, 130000000000000000}, + {123456789123456789, -15, 124000000000000000}, + {123456789123456789, -14, 123500000000000000}, + {123456789123456789, -13, 123460000000000000}, + {123456789123456789, -12, 123457000000000000}, + {123456789123456789, -11, 123456800000000000}, + {123456789123456789, -10, 123456790000000000}, + {123456789123456789, -9, 123456790000000000}, + {123456789123456789, -8, 123456789200000000}, + {123456789123456789, -7, 123456789130000000}, + {123456789123456789, -6, 123456789124000000}, + {123456789123456789, -5, 123456789123500000}, + {123456789123456789, -4, 123456789123460000}, + {123456789123456789, -3, 123456789123457000}, + {123456789123456789, -2, 123456789123456800}, + {123456789123456789, -1, 123456789123456790}, + {123456789123456789, 0, 123456789123456789}, + {123456789123456789, 1, 123456789123456789}, + {123456789123456789, 2, 123456789123456789}, + {123456789123456789, 3, 123456789123456789}, + {123456789123456789, 4, 123456789123456789}, + {123456789123456789, 5, 123456789123456789}, + {123456789123456789, 6, 123456789123456789}, + {123456789123456789, 7, 123456789123456789}, + {123456789123456789, 8, 123456789123456789}, + {123456789123456789, 18, 123456789123456789}}}, + {{18, 18}, + {{123456789123456789, -1, 0}, + {123456789123456789, 0, 1000000000000000000}, + {123456789123456789, 1, 200000000000000000}, + {123456789123456789, 2, 130000000000000000}, + {123456789123456789, 3, 124000000000000000}, + {123456789123456789, 4, 123500000000000000}, + {123456789123456789, 5, 123460000000000000}, + {123456789123456789, 6, 123457000000000000}, + {123456789123456789, 7, 123456800000000000}, + {123456789123456789, 8, 123456790000000000}, + {123456789123456789, 9, 123456790000000000}, + {123456789123456789, 10, 123456789200000000}, + {123456789123456789, 11, 123456789130000000}, + {123456789123456789, 12, 123456789124000000}, + {123456789123456789, 13, 123456789123500000}, + {123456789123456789, 14, 123456789123460000}, + {123456789123456789, 15, 123456789123457000}, + {123456789123456789, 16, 123456789123456800}, + {123456789123456789, 17, 123456789123456790}, + {123456789123456789, 18, 123456789123456789}, + {123456789123456789, 19, 123456789123456789}, + {123456789123456789, 20, 123456789123456789}, + {123456789123456789, 21, 123456789123456789}, + {123456789123456789, 22, 123456789123456789}, + {123456789123456789, 23, 123456789123456789}, + {123456789123456789, 24, 123456789123456789}, + {123456789123456789, 25, 123456789123456789}, + {123456789123456789, 26, 123456789123456789}}}}; + +// test cases for round and round_bankers function of decimal32 +const static DecimalTestDataSet round_decimal32_cases = { + {{1, 0}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{1, 1}, + { + {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, + {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0}, {1, 1, 1}, + {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, + {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, + }}, + {{2, 0}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 10}, + {12, 0, 12}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 1}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 10}, + {12, 1, 12}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{2, 2}, + { + {12, -4, 0}, + {12, -3, 0}, + {12, -2, 0}, + {12, -1, 0}, + {12, 0, 0}, + {12, 1, 10}, + {12, 2, 12}, + {12, 3, 12}, + {12, 4, 12}, + }}, + {{9, 0}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 100000000}, + {123456789, -7, 120000000}, {123456789, -6, 123000000}, {123456789, -5, 123500000}, + {123456789, -4, 123460000}, {123456789, -3, 123457000}, {123456789, -2, 123456800}, + {123456789, -1, 123456790}, {123456789, 0, 123456789}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 1}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 100000000}, {123456789, -6, 120000000}, {123456789, -5, 123000000}, + {123456789, -4, 123500000}, {123456789, -3, 123460000}, {123456789, -2, 123457000}, + {123456789, -1, 123456800}, {123456789, 0, 123456790}, {123456789, 1, 123456789}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 2}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 100000000}, {123456789, -5, 120000000}, + {123456789, -4, 123000000}, {123456789, -3, 123500000}, {123456789, -2, 123460000}, + {123456789, -1, 123457000}, {123456789, 0, 123456800}, {123456789, 1, 123456790}, + {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 3}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 100000000}, + {123456789, -4, 120000000}, {123456789, -3, 123000000}, {123456789, -2, 123500000}, + {123456789, -1, 123460000}, {123456789, 0, 123457000}, {123456789, 1, 123456800}, + {123456789, 2, 123456790}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 4}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 100000000}, {123456789, -3, 120000000}, {123456789, -2, 123000000}, + {123456789, -1, 123500000}, {123456789, 0, 123460000}, {123456789, 1, 123457000}, + {123456789, 2, 123456800}, {123456789, 3, 123456790}, {123456789, 4, 123456789}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 5}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 100000000}, {123456789, -2, 120000000}, + {123456789, -1, 123000000}, {123456789, 0, 123500000}, {123456789, 1, 123460000}, + {123456789, 2, 123457000}, {123456789, 3, 123456800}, {123456789, 4, 123456790}, + {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 6}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 100000000}, + {123456789, -1, 120000000}, {123456789, 0, 123000000}, {123456789, 1, 123500000}, + {123456789, 2, 123460000}, {123456789, 3, 123457000}, {123456789, 4, 123456800}, + {123456789, 5, 123456790}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 7}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 100000000}, {123456789, 0, 120000000}, {123456789, 1, 123000000}, + {123456789, 2, 123500000}, {123456789, 3, 123460000}, {123456789, 4, 123457000}, + {123456789, 5, 123456800}, {123456789, 6, 123456790}, {123456789, 7, 123456789}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 8}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 100000000}, {123456789, 1, 120000000}, + {123456789, 2, 123000000}, {123456789, 3, 123500000}, {123456789, 4, 123460000}, + {123456789, 5, 123457000}, {123456789, 6, 123456800}, {123456789, 7, 123456790}, + {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}, + {{9, 9}, + { + {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, + {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, + {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, + {123456789, -1, 0}, {123456789, 0, 0}, {123456789, 1, 100000000}, + {123456789, 2, 120000000}, {123456789, 3, 123000000}, {123456789, 4, 123500000}, + {123456789, 5, 123460000}, {123456789, 6, 123457000}, {123456789, 7, 123456800}, + {123456789, 8, 123456790}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, + }}}; + +// test cases for round and round_bankers function of decimal64 +const static DecimalTestDataSet round_decimal64_cases = { + {{10, 0}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 1000000000}, + {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, {1234567891, -6, 1235000000}, + {1234567891, -5, 1234600000}, {1234567891, -4, 1234570000}, {1234567891, -3, 1234568000}, + {1234567891, -2, 1234567900}, {1234567891, -1, 1234567890}, {1234567891, 0, 1234567891}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 1}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, {1234567891, -6, 1230000000}, + {1234567891, -5, 1235000000}, {1234567891, -4, 1234600000}, {1234567891, -3, 1234570000}, + {1234567891, -2, 1234568000}, {1234567891, -1, 1234567900}, {1234567891, 0, 1234567890}, + {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891} + + }}, + {{10, 2}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 1000000000}, {1234567891, -6, 1200000000}, + {1234567891, -5, 1230000000}, {1234567891, -4, 1235000000}, {1234567891, -3, 1234600000}, + {1234567891, -2, 1234570000}, {1234567891, -1, 1234568000}, {1234567891, 0, 1234567900}, + {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, + {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, + {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{10, 9}, + {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, + {1234567891, -8, 0}, {1234567891, -7, 0}, {1234567891, -6, 0}, + {1234567891, -5, 0}, {1234567891, -4, 0}, {1234567891, -3, 0}, + {1234567891, -2, 0}, {1234567891, -1, 0}, {1234567891, 0, 1000000000}, + {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000}, {1234567891, 3, 1235000000}, + {1234567891, 4, 1234600000}, {1234567891, 5, 1234570000}, {1234567891, 6, 1234568000}, + {1234567891, 7, 1234567900}, {1234567891, 8, 1234567890}, {1234567891, 9, 1234567891}, + {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, + {{18, 0}, + {{123456789123456789, -19, 0}, + {123456789123456789, -18, 0}, + {123456789123456789, -17, 100000000000000000}, + {123456789123456789, -16, 120000000000000000}, + {123456789123456789, -15, 123000000000000000}, + {123456789123456789, -14, 123500000000000000}, + {123456789123456789, -13, 123460000000000000}, + {123456789123456789, -12, 123457000000000000}, + {123456789123456789, -11, 123456800000000000}, + {123456789123456789, -10, 123456790000000000}, + {123456789123456789, -9, 123456789000000000}, + {123456789123456789, -8, 123456789100000000}, + {123456789123456789, -7, 123456789120000000}, + {123456789123456789, -6, 123456789123000000}, + {123456789123456789, -5, 123456789123500000}, + {123456789123456789, -4, 123456789123460000}, + {123456789123456789, -3, 123456789123457000}, + {123456789123456789, -2, 123456789123456800}, + {123456789123456789, -1, 123456789123456790}, + {123456789123456789, 0, 123456789123456789}, + {123456789123456789, 1, 123456789123456789}, + {123456789123456789, 2, 123456789123456789}, + {123456789123456789, 3, 123456789123456789}, + {123456789123456789, 4, 123456789123456789}, + {123456789123456789, 5, 123456789123456789}, + {123456789123456789, 6, 123456789123456789}, + {123456789123456789, 7, 123456789123456789}, + {123456789123456789, 8, 123456789123456789}, + {123456789123456789, 18, 123456789123456789}}}, + {{18, 18}, + {{123456789123456789, -1, 0}, + {123456789123456789, 0, 0}, + {123456789123456789, 1, 100000000000000000}, + {123456789123456789, 2, 120000000000000000}, + {123456789123456789, 3, 123000000000000000}, + {123456789123456789, 4, 123500000000000000}, + {123456789123456789, 5, 123460000000000000}, + {123456789123456789, 6, 123457000000000000}, + {123456789123456789, 7, 123456800000000000}, + {123456789123456789, 8, 123456790000000000}, + {123456789123456789, 9, 123456789000000000}, + {123456789123456789, 10, 123456789100000000}, + {123456789123456789, 11, 123456789120000000}, + {123456789123456789, 12, 123456789123000000}, + {123456789123456789, 13, 123456789123500000}, + {123456789123456789, 14, 123456789123460000}, + {123456789123456789, 15, 123456789123457000}, + {123456789123456789, 16, 123456789123456800}, + {123456789123456789, 17, 123456789123456790}, + {123456789123456789, 18, 123456789123456789}, + {123456789123456789, 19, 123456789123456789}, + {123456789123456789, 20, 123456789123456789}, + {123456789123456789, 21, 123456789123456789}, + {123456789123456789, 22, 123456789123456789}, + {123456789123456789, 23, 123456789123456789}, + {123456789123456789, 24, 123456789123456789}, + {123456789123456789, 25, 123456789123456789}, + {123456789123456789, 26, 123456789123456789}}}}; + +// test cases for truncate function of float32 +FloatTestDataSet trunc_float32_cases = { + {10.0, 1, 10.0}, {10.0, 0, 10.0}, {-124.3867, -2, -100}, {123.123456, 4, 123.123400}}; + +// test cases for truncate function of float64 +FloatTestDataSet trunc_float64_cases = {{10.0, 1, 10.0}, + {10.0, 0, 10.0}, + {-124.3867, -2, -100}, + {123.123456, 1, 123.100000}, + {123456789.123456, 4, 123456789.123400}}; + +// test cases for floor function of float32 +FloatTestDataSet floor_float32_cases = { + {10.1234, 1, 10.1}, {10.3, 0, 10.0}, {-124.3867, -2, -200}, {123.123456, 4, 123.123400}}; + +// test cases for floor function of float64 +FloatTestDataSet floor_float64_cases = {{10.1234, 1, 10.1}, + {10.3, 0, 10.0}, + {-124.3867, -2, -200}, + {123.123456, 4, 123.123400}, + {123456789.123456, 5, 123456789.123450}}; + +// test cases for ceil function of float32 +FloatTestDataSet ceil_float32_cases = { + {10.1234, 1, 10.2}, {10.3, 0, 11}, {-124.3867, -2, -100}, {123.123456, 4, 123.123500}}; + +// test cases for ceil function of float64 +FloatTestDataSet ceil_float64_cases = {{10.1234, 1, 10.2}, + {10.3, 0, 11}, + {-124.3867, -2, -100}, + {123.123456, 4, 123.123500}, + {123456789.123456, 4, 123456789.123500}}; + +// test cases for round function of float32 +FloatTestDataSet round_float32_cases = {{2.5, 0, 3.0}, + {10.1234, 1, 10.1}, + {10.3, 0, 10}, + {-124.3867, -2, -100}, + {123.123456, 4, 123.123500}}; + +// test cases for round function of float64 +FloatTestDataSet round_float64_cases = {{2.5, 0, 3.0}, + {10.1234, 1, 10.1}, + {10.3, 0, 10}, + {-124.3867, -2, -100}, + {123.123456, 4, 123.123500}, + {123456789.123456, 4, 123456789.123500}}; + +// test cases for round_bankers function of float32 +FloatTestDataSet round_bankers_float32_cases = {{2.5, 0, 2.0}, + {10.1234, 1, 10.1}, + {10.3, 0, 10}, + {-124.3867, -2, -100}, + {123.123456, 4, 123.123500}}; + +// test cases for round_bankers function of float64 +FloatTestDataSet round_bankers_float64_cases = {{2.5, 0, 2.0}, + {10.1234, 1, 10.1}, + {10.3, 0, 10}, + {-124.3867, -2, -100}, + {123.123456, 4, 123.123500}, + {123456789.123456, 4, 123456789.123500}}; + +template +static void decimal_checker(const DecimalTestDataSet& round_test_cases, bool decimal_col_is_const) { + static_assert(IsDecimalNumber); + auto func = std::dynamic_pointer_cast(FuncType::create()); + FunctionContext* context = nullptr; + + for (const auto& test_case : round_test_cases) { + Block block; + size_t res_idx = 2; + ColumnNumbers arguments = {0, 1, 2}; + const int precision = test_case.first.first; + const int scale = test_case.first.second; + const size_t input_rows_count = test_case.second.size(); + auto col_general = ColumnDecimal::create(input_rows_count, scale); + auto col_scale = ColumnInt32::create(); + auto col_res_expected = ColumnDecimal::create(input_rows_count, scale); + size_t rid = 0; + + for (const auto& test_date : test_case.second) { + auto input = std::get<0>(test_date); + auto scale_arg = std::get<1>(test_date); + auto expectation = std::get<2>(test_date); + col_general->get_element(rid) = DecimalType(input); + col_scale->insert(scale_arg); + col_res_expected->get_element(rid) = DecimalType(expectation); + rid++; + } + + if (decimal_col_is_const) { + block.insert({ColumnConst::create(col_general->clone_resized(1), 1), + std::make_shared>(precision, scale), + "col_general_const"}); + } else { + block.insert({col_general->clone(), + std::make_shared>(precision, scale), + "col_general"}); + } + + block.insert({col_scale->clone(), std::make_shared(), "col_scale"}); + block.insert({nullptr, std::make_shared>(precision, scale), + "col_res"}); + + auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count); + auto col_res = assert_cast&>( + *(block.get_by_position(res_idx).column)); + EXPECT_TRUE(status.ok()); + + for (size_t i = 0; i < input_rows_count; ++i) { + auto res = col_res.get_element(i); + auto res_expected = col_res_expected->get_element(i); + EXPECT_EQ(res, res_expected) + << "function " << func->get_name() << " decimal_type " + << TypeName().get() << " precision " << precision + << " input_scale " << scale << " input " << col_general->get_element(i) + << " scale_arg " << col_scale->get_element(i) << " decimal_col_is_const " + << decimal_col_is_const << " res " << res << " res_expected " << res_expected; + } + } +} + +template +static void float_checker(const FloatTestDataSet& round_test_cases, bool float_col_is_const) { + static_assert(IsNumber); + auto func = std::dynamic_pointer_cast(FuncType::create()); + FunctionContext* context = nullptr; + + for (const auto& test_case : round_test_cases) { + auto col_general = ColumnVector::create(1); + auto col_scale = ColumnInt32::create(); + auto col_res_expected = ColumnVector::create(1); + size_t rid = 0; + + Block block; + size_t res_idx = 2; + ColumnNumbers arguments = {0, 1, 2}; + + auto input = std::get<0>(test_case); + auto scale_arg = std::get<1>(test_case); + auto expectation = std::get<2>(test_case); + col_general->get_element(rid) = FloatType(input); + col_scale->insert(scale_arg); + col_res_expected->get_element(rid) = FloatType(expectation); + rid++; + + if (float_col_is_const) { + block.insert({ColumnConst::create(col_general->clone_resized(1), 1), + std::make_shared>(), "col_general_const"}); + } else { + block.insert({col_general->clone(), std::make_shared>(), + "col_general"}); + } + + block.insert({col_scale->clone(), std::make_shared(), "col_scale"}); + block.insert({nullptr, std::make_shared>(), "col_res"}); + + auto status = func->execute_impl(context, block, arguments, res_idx, 1); + auto col_res = assert_cast&>( + *(block.get_by_position(res_idx).column)); + EXPECT_TRUE(status.ok()); + + auto res = col_res.get_element(0); + auto res_expected = col_res_expected->get_element(0); + EXPECT_EQ(res, res_expected) + << "function " << func->get_name() << " float_type " << TypeName().get() + << " input " << col_general->get_element(0) << " scale_arg " + << col_scale->get_element(0) << " float_col_is_const " << float_col_is_const + << " res " << res << " res_expected " << res_expected; + } +} + +/// tests for func(Column, Column) with decimal input +TEST(RoundFunctionTest, normal_decimal) { + // truncate + decimal_checker(trunc_floor_decimal32_cases, false); + decimal_checker(trunc_floor_decimal64_cases, false); + + // floor + decimal_checker(trunc_floor_decimal32_cases, false); + decimal_checker(trunc_floor_decimal64_cases, false); + + // ceil + decimal_checker(ceil_decimal32_cases, false); + decimal_checker(ceil_decimal64_cases, false); + + // round + decimal_checker(round_decimal32_cases, false); + decimal_checker(round_decimal64_cases, false); + + // round_bankers + decimal_checker(round_decimal32_cases, false); + decimal_checker(round_decimal64_cases, false); +} + +/// tests for func(ColumnConst, Column) with decimal input +TEST(RoundFunctionTest, normal_decimal_const) { + // truncate + decimal_checker(trunc_floor_decimal32_cases, true); + decimal_checker(trunc_floor_decimal64_cases, true); + + // floor + decimal_checker(trunc_floor_decimal32_cases, true); + decimal_checker(trunc_floor_decimal64_cases, true); + + // ceil + decimal_checker(ceil_decimal32_cases, true); + decimal_checker(ceil_decimal64_cases, true); + + // round + decimal_checker(round_decimal32_cases, true); + decimal_checker(round_decimal64_cases, true); + + // round_bankers + decimal_checker(round_decimal32_cases, true); + decimal_checker(round_decimal64_cases, true); +} + +/// tests for func(Column, Column) with float input +TEST(RoundFunctionTest, normal_float) { + // truncate + float_checker(trunc_float32_cases, false); + float_checker(trunc_float64_cases, false); + + // floor + float_checker(floor_float32_cases, false); + float_checker(floor_float64_cases, false); + + // ceil + float_checker(ceil_float32_cases, false); + float_checker(ceil_float64_cases, false); + + // round + float_checker(round_float32_cases, false); + float_checker(round_float64_cases, false); + + // round_bankers + float_checker(round_bankers_float32_cases, false); + float_checker(round_bankers_float64_cases, false); +} + +/// tests for func(ColumnConst, Column) with float input +TEST(RoundFunctionTest, normal_float_const) { + // truncate + float_checker(trunc_float32_cases, true); + float_checker(trunc_float64_cases, true); + + // floor + float_checker(floor_float32_cases, true); + float_checker(floor_float64_cases, true); + + // ceil + float_checker(ceil_float32_cases, true); + float_checker(ceil_float64_cases, true); + + // round + float_checker(round_float32_cases, true); + float_checker(round_float64_cases, true); + + // round_bankers + float_checker(round_bankers_float32_cases, true); + float_checker(round_bankers_float64_cases, true); +} + +} // namespace doris::vectorized diff --git a/be/test/vec/function/function_truncate_decimal_test.cpp b/be/test/vec/function/function_truncate_decimal_test.cpp deleted file mode 100644 index 36fcaa14e67fa6..00000000000000 --- a/be/test/vec/function/function_truncate_decimal_test.cpp +++ /dev/null @@ -1,370 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "function_test_util.h" -#include "vec/columns/column.h" -#include "vec/columns/column_const.h" -#include "vec/columns/column_decimal.h" -#include "vec/columns/columns_number.h" -#include "vec/common/assert_cast.h" -#include "vec/core/column_numbers.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_number.h" -#include "vec/functions/function_truncate.h" - -namespace doris::vectorized { -// {precision, scale} -> {input, scale_arg, expectation} -using TestDataSet = std::map, std::vector>>; - -const static TestDataSet truncate_decimal32_cases = { - {{1, 0}, - { - {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, - {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 1}, {1, 1, 1}, - {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, - {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, - }}, - {{1, 1}, - { - {1, -10, 0}, {1, -9, 0}, {1, -8, 0}, {1, -7, 0}, {1, -6, 0}, {1, -5, 0}, - {1, -4, 0}, {1, -3, 0}, {1, -2, 0}, {1, -1, 0}, {1, 0, 0}, {1, 1, 1}, - {1, 2, 1}, {1, 3, 1}, {1, 4, 1}, {1, 5, 1}, {1, 6, 1}, {1, 7, 1}, - {1, 8, 1}, {1, 9, 1}, {1, 10, 1}, - }}, - {{2, 0}, - { - {12, -4, 0}, - {12, -3, 0}, - {12, -2, 0}, - {12, -1, 10}, - {12, 0, 12}, - {12, 1, 12}, - {12, 2, 12}, - {12, 3, 12}, - {12, 4, 12}, - }}, - {{2, 1}, - { - {12, -4, 0}, - {12, -3, 0}, - {12, -2, 0}, - {12, -1, 0}, - {12, 0, 10}, - {12, 1, 12}, - {12, 2, 12}, - {12, 3, 12}, - {12, 4, 12}, - }}, - {{2, 2}, - { - {12, -4, 0}, - {12, -3, 0}, - {12, -2, 0}, - {12, -1, 0}, - {12, 0, 0}, - {12, 1, 10}, - {12, 2, 12}, - {12, 3, 12}, - {12, 4, 12}, - }}, - {{9, 0}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 100000000}, - {123456789, -7, 120000000}, {123456789, -6, 123000000}, {123456789, -5, 123400000}, - {123456789, -4, 123450000}, {123456789, -3, 123456000}, {123456789, -2, 123456700}, - {123456789, -1, 123456780}, {123456789, 0, 123456789}, {123456789, 1, 123456789}, - {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 1}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 100000000}, {123456789, -6, 120000000}, {123456789, -5, 123000000}, - {123456789, -4, 123400000}, {123456789, -3, 123450000}, {123456789, -2, 123456000}, - {123456789, -1, 123456700}, {123456789, 0, 123456780}, {123456789, 1, 123456789}, - {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 2}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 100000000}, {123456789, -5, 120000000}, - {123456789, -4, 123000000}, {123456789, -3, 123400000}, {123456789, -2, 123450000}, - {123456789, -1, 123456000}, {123456789, 0, 123456700}, {123456789, 1, 123456780}, - {123456789, 2, 123456789}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 3}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 100000000}, - {123456789, -4, 120000000}, {123456789, -3, 123000000}, {123456789, -2, 123400000}, - {123456789, -1, 123450000}, {123456789, 0, 123456000}, {123456789, 1, 123456700}, - {123456789, 2, 123456780}, {123456789, 3, 123456789}, {123456789, 4, 123456789}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 4}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 100000000}, {123456789, -3, 120000000}, {123456789, -2, 123000000}, - {123456789, -1, 123400000}, {123456789, 0, 123450000}, {123456789, 1, 123456000}, - {123456789, 2, 123456700}, {123456789, 3, 123456780}, {123456789, 4, 123456789}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 5}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 0}, {123456789, -3, 100000000}, {123456789, -2, 120000000}, - {123456789, -1, 123000000}, {123456789, 0, 123400000}, {123456789, 1, 123450000}, - {123456789, 2, 123456000}, {123456789, 3, 123456700}, {123456789, 4, 123456780}, - {123456789, 5, 123456789}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 6}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 100000000}, - {123456789, -1, 120000000}, {123456789, 0, 123000000}, {123456789, 1, 123400000}, - {123456789, 2, 123450000}, {123456789, 3, 123456000}, {123456789, 4, 123456700}, - {123456789, 5, 123456780}, {123456789, 6, 123456789}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 7}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, - {123456789, -1, 100000000}, {123456789, 0, 120000000}, {123456789, 1, 123000000}, - {123456789, 2, 123400000}, {123456789, 3, 123450000}, {123456789, 4, 123456000}, - {123456789, 5, 123456700}, {123456789, 6, 123456780}, {123456789, 7, 123456789}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 8}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, - {123456789, -1, 0}, {123456789, 0, 100000000}, {123456789, 1, 120000000}, - {123456789, 2, 123000000}, {123456789, 3, 123400000}, {123456789, 4, 123450000}, - {123456789, 5, 123456000}, {123456789, 6, 123456700}, {123456789, 7, 123456780}, - {123456789, 8, 123456789}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}, - {{9, 9}, - { - {123456789, -10, 0}, {123456789, -9, 0}, {123456789, -8, 0}, - {123456789, -7, 0}, {123456789, -6, 0}, {123456789, -5, 0}, - {123456789, -4, 0}, {123456789, -3, 0}, {123456789, -2, 0}, - {123456789, -1, 0}, {123456789, 0, 0}, {123456789, 1, 100000000}, - {123456789, 2, 120000000}, {123456789, 3, 123000000}, {123456789, 4, 123400000}, - {123456789, 5, 123450000}, {123456789, 6, 123456000}, {123456789, 7, 123456700}, - {123456789, 8, 123456780}, {123456789, 9, 123456789}, {123456789, 10, 123456789}, - }}}; - -const static TestDataSet truncate_decimal64_cases = { - {{10, 0}, - {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 1000000000}, - {1234567891, -8, 1200000000}, {1234567891, -7, 1230000000}, {1234567891, -6, 1234000000}, - {1234567891, -5, 1234500000}, {1234567891, -4, 1234560000}, {1234567891, -3, 1234567000}, - {1234567891, -2, 1234567800}, {1234567891, -1, 1234567890}, {1234567891, 0, 1234567891}, - {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, - {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, - {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, - {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, - {{10, 1}, - {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, - {1234567891, -8, 1000000000}, {1234567891, -7, 1200000000}, {1234567891, -6, 1230000000}, - {1234567891, -5, 1234000000}, {1234567891, -4, 1234500000}, {1234567891, -3, 1234560000}, - {1234567891, -2, 1234567000}, {1234567891, -1, 1234567800}, {1234567891, 0, 1234567890}, - {1234567891, 1, 1234567891}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, - {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, - {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, - {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891} - - }}, - {{10, 2}, - {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, - {1234567891, -8, 0}, {1234567891, -7, 1000000000}, {1234567891, -6, 1200000000}, - {1234567891, -5, 1230000000}, {1234567891, -4, 1234000000}, {1234567891, -3, 1234500000}, - {1234567891, -2, 1234560000}, {1234567891, -1, 1234567000}, {1234567891, 0, 1234567800}, - {1234567891, 1, 1234567890}, {1234567891, 2, 1234567891}, {1234567891, 3, 1234567891}, - {1234567891, 4, 1234567891}, {1234567891, 5, 1234567891}, {1234567891, 6, 1234567891}, - {1234567891, 7, 1234567891}, {1234567891, 8, 1234567891}, {1234567891, 9, 1234567891}, - {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, - {{10, 9}, - {{1234567891, -11, 0}, {1234567891, -10, 0}, {1234567891, -9, 0}, - {1234567891, -8, 0}, {1234567891, -7, 0}, {1234567891, -6, 0}, - {1234567891, -5, 0}, {1234567891, -4, 0}, {1234567891, -3, 0}, - {1234567891, -2, 0}, {1234567891, -1, 0}, {1234567891, 0, 1000000000}, - {1234567891, 1, 1200000000}, {1234567891, 2, 1230000000}, {1234567891, 3, 1234000000}, - {1234567891, 4, 1234500000}, {1234567891, 5, 1234560000}, {1234567891, 6, 1234567000}, - {1234567891, 7, 1234567800}, {1234567891, 8, 1234567890}, {1234567891, 9, 1234567891}, - {1234567891, 10, 1234567891}, {1234567891, 11, 1234567891}}}, - {{18, 0}, - {{123456789123456789, -19, 0}, - {123456789123456789, -18, 0}, - {123456789123456789, -17, 100000000000000000}, - {123456789123456789, -16, 120000000000000000}, - {123456789123456789, -15, 123000000000000000}, - {123456789123456789, -14, 123400000000000000}, - {123456789123456789, -13, 123450000000000000}, - {123456789123456789, -12, 123456000000000000}, - {123456789123456789, -11, 123456700000000000}, - {123456789123456789, -10, 123456780000000000}, - {123456789123456789, -9, 123456789000000000}, - {123456789123456789, -8, 123456789100000000}, - {123456789123456789, -7, 123456789120000000}, - {123456789123456789, -6, 123456789123000000}, - {123456789123456789, -5, 123456789123400000}, - {123456789123456789, -4, 123456789123450000}, - {123456789123456789, -3, 123456789123456000}, - {123456789123456789, -2, 123456789123456700}, - {123456789123456789, -1, 123456789123456780}, - {123456789123456789, 0, 123456789123456789}, - {123456789123456789, 1, 123456789123456789}, - {123456789123456789, 2, 123456789123456789}, - {123456789123456789, 3, 123456789123456789}, - {123456789123456789, 4, 123456789123456789}, - {123456789123456789, 5, 123456789123456789}, - {123456789123456789, 6, 123456789123456789}, - {123456789123456789, 7, 123456789123456789}, - {123456789123456789, 8, 123456789123456789}, - {123456789123456789, 18, 123456789123456789}}}, - {{18, 18}, - {{123456789123456789, -1, 0}, - {123456789123456789, 0, 0}, - {123456789123456789, 1, 100000000000000000}, - {123456789123456789, 2, 120000000000000000}, - {123456789123456789, 3, 123000000000000000}, - {123456789123456789, 4, 123400000000000000}, - {123456789123456789, 5, 123450000000000000}, - {123456789123456789, 6, 123456000000000000}, - {123456789123456789, 7, 123456700000000000}, - {123456789123456789, 8, 123456780000000000}, - {123456789123456789, 9, 123456789000000000}, - {123456789123456789, 10, 123456789100000000}, - {123456789123456789, 11, 123456789120000000}, - {123456789123456789, 12, 123456789123000000}, - {123456789123456789, 13, 123456789123400000}, - {123456789123456789, 14, 123456789123450000}, - {123456789123456789, 15, 123456789123456000}, - {123456789123456789, 16, 123456789123456700}, - {123456789123456789, 17, 123456789123456780}, - {123456789123456789, 18, 123456789123456789}, - {123456789123456789, 19, 123456789123456789}, - {123456789123456789, 20, 123456789123456789}, - {123456789123456789, 21, 123456789123456789}, - {123456789123456789, 22, 123456789123456789}, - {123456789123456789, 23, 123456789123456789}, - {123456789123456789, 24, 123456789123456789}, - {123456789123456789, 25, 123456789123456789}, - {123456789123456789, 26, 123456789123456789}}}}; - -template -static void checker(const TestDataSet& truncate_test_cases, bool decimal_col_is_const) { - static_assert(IsDecimalNumber); - auto func = std::dynamic_pointer_cast(FuncType::create()); - FunctionContext* context = nullptr; - - for (const auto& test_case : truncate_test_cases) { - Block block; - size_t res_idx = 2; - ColumnNumbers arguments = {0, 1, 2}; - const int precision = test_case.first.first; - const int scale = test_case.first.second; - const size_t input_rows_count = test_case.second.size(); - auto col_general = ColumnDecimal::create(input_rows_count, scale); - auto col_scale = ColumnInt32::create(); - auto col_res_expected = ColumnDecimal::create(input_rows_count, scale); - size_t rid = 0; - - for (const auto& test_date : test_case.second) { - auto input = std::get<0>(test_date); - auto scale_arg = std::get<1>(test_date); - auto expectation = std::get<2>(test_date); - col_general->get_element(rid) = DecimalType(input); - col_scale->insert(scale_arg); - col_res_expected->get_element(rid) = DecimalType(expectation); - rid++; - } - - if (decimal_col_is_const) { - block.insert({ColumnConst::create(col_general->clone_resized(1), 1), - std::make_shared>(precision, scale), - "col_general_const"}); - } else { - block.insert({col_general->clone(), - std::make_shared>(precision, scale), - "col_general"}); - } - - block.insert({col_scale->clone(), std::make_shared(), "col_scale"}); - block.insert({nullptr, std::make_shared>(precision, scale), - "col_res"}); - - auto status = func->execute_impl(context, block, arguments, res_idx, input_rows_count); - auto col_res = assert_cast&>( - *(block.get_by_position(res_idx).column)); - EXPECT_TRUE(status.ok()); - - for (size_t i = 0; i < input_rows_count; ++i) { - auto res = col_res.get_element(i); - auto res_expected = col_res_expected->get_element(i); - EXPECT_EQ(res, res_expected) - << "precision " << precision << " input_scale " << scale << " input " - << col_general->get_element(i) << " scale_arg " << col_scale->get_element(i) - << " res " << res << " res_expected " << res_expected; - } - } -} -TEST(TruncateFunctionTest, normal_decimal) { - checker, Decimal32>(truncate_decimal32_cases, - false); - checker, Decimal64>(truncate_decimal64_cases, - false); -} - -TEST(TruncateFunctionTest, normal_decimal_const) { - checker, Decimal32>(truncate_decimal32_cases, true); - checker, Decimal64>(truncate_decimal64_cases, true); -} - -} // namespace doris::vectorized From 99d7fa84f9144d80c50a449d9501aecaca396913 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sat, 4 May 2024 12:25:10 +0800 Subject: [PATCH 3/7] suit fe code --- .../doris/analysis/FunctionCallExpr.java | 30 +++---------------- .../functions/ComputePrecisionForRound.java | 16 +++++----- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 34cd8c7b120b7d..547fd52dcbd55e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -115,30 +115,8 @@ public class FunctionCallExpr extends Expr { return returnType; } }; - java.util.function.BiFunction, Type, Type> roundRule = (children, returnType) -> { - Preconditions.checkArgument(children != null && children.size() > 0); - if (children.size() == 1 && children.get(0).getType().isDecimalV3()) { - return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0); - } else if (children.size() == 2) { - Preconditions.checkArgument(children.get(1) instanceof IntLiteral - || (children.get(1) instanceof CastExpr - && children.get(1).getChild(0) instanceof IntLiteral), - "2nd argument of function round/floor/ceil must be literal"); - if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) { - children.get(1).getChild(0).setType(children.get(1).getType()); - children.set(1, children.get(1).getChild(0)); - } else { - children.get(1).setType(Type.INT); - } - int scaleArg = (int) (((IntLiteral) children.get(1)).getValue()); - return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), - Math.min(Math.max(scaleArg, 0), ((ScalarType) children.get(0).getType()).decimalScale())); - } else { - return returnType; - } - }; - java.util.function.BiFunction, Type, Type> truncateRule = (children, returnType) -> { + java.util.function.BiFunction, Type, Type> roundRule = (children, returnType) -> { Preconditions.checkArgument(children != null && children.size() > 0); if (children.size() == 1 && children.get(0).getType().isDecimalV3()) { return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0); @@ -264,11 +242,11 @@ public class FunctionCallExpr extends Expr { PRECISION_INFER_RULE.put("round", roundRule); PRECISION_INFER_RULE.put("round_bankers", roundRule); PRECISION_INFER_RULE.put("ceil", roundRule); - PRECISION_INFER_RULE.put("floor", truncateRule); + PRECISION_INFER_RULE.put("floor", roundRule); PRECISION_INFER_RULE.put("dround", roundRule); PRECISION_INFER_RULE.put("dceil", roundRule); - PRECISION_INFER_RULE.put("dfloor", truncateRule); - PRECISION_INFER_RULE.put("truncate", truncateRule); + PRECISION_INFER_RULE.put("dfloor", roundRule); + PRECISION_INFER_RULE.put("truncate", roundRule); } public static final ImmutableSet TIME_FUNCTIONS_WITH_PRECISION = new ImmutableSortedSet.Builder( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index 40bf83c7406861..bea55b1a1f160a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -20,7 +20,10 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Ceil; import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Round; +import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers; import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DecimalV3Type; @@ -41,10 +44,10 @@ default FunctionSignature computePrecision(FunctionSignature signature) { Expression floatLength = getArgument(1); int scale; - if (this instanceof Truncate || this instanceof Floor) { - if (floatLength.isLiteral() || ( - floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { + if (this instanceof Truncate || this instanceof Floor || this instanceof Ceil || this instanceof Round + || this instanceof RoundBankers) { + if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { // Scale argument is a literal or cast from other literal if (floatLength instanceof Cast) { scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); @@ -59,9 +62,8 @@ default FunctionSignature computePrecision(FunctionSignature signature) { } } else { Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType - && (floatLength.isLiteral() || ( - floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)), + && (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)), "2nd argument of function round/floor/ceil must be literal"); if (floatLength instanceof Cast) { scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); From 3d838afa1b5a95f0e8c58ef0e8fa84694c01b269 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sat, 4 May 2024 21:22:34 +0800 Subject: [PATCH 4/7] add reg, fmt, fix, opt --- be/src/vec/functions/round.cpp | 39 ++--- be/src/vec/functions/round.h | 58 -------- be/test/vec/function/function_round_test.cpp | 137 ++++++++---------- .../math_functions/test_function_truncate.out | 3 + .../math_functions/test_round.out | 79 ++++++++++ .../test_function_truncate.groovy | 3 + .../math_functions/test_round.groovy | 121 ++++++++++++++++ 7 files changed, 275 insertions(+), 165 deletions(-) diff --git a/be/src/vec/functions/round.cpp b/be/src/vec/functions/round.cpp index 6b504839729c63..2ac0e3f6bda24b 100644 --- a/be/src/vec/functions/round.cpp +++ b/be/src/vec/functions/round.cpp @@ -17,25 +17,6 @@ #include "vec/functions/round.h" -#include -#include -#include -#include - -#include "common/exception.h" -#include "common/status.h" -#include "olap/olap_common.h" -#include "vec/columns/column.h" -#include "vec/columns/column_const.h" -#include "vec/columns/column_decimal.h" -#include "vec/columns/column_vector.h" -#include "vec/common/assert_cast.h" -#include "vec/core/call_on_type_index.h" -#include "vec/core/field.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_number.h" #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { @@ -43,16 +24,16 @@ namespace doris::vectorized { // We split round funcs from register_function_math() in math.cpp to here, // so that to speed up compile time and make code more readable. void register_function_round(SimpleFunctionFactory& factory) { -#define REGISTER_ROUND_FUNCTIONS(IMPL) \ - factory.register_function, RoundingMode::Trunc, \ - TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ - factory.register_function< \ - FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ - factory.register_function, RoundingMode::Round, \ +#define REGISTER_ROUND_FUNCTIONS(IMPL) \ + factory.register_function< \ + FunctionRounding, RoundingMode::Trunc, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>>(); \ + factory.register_function< \ + FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>>(); \ + factory.register_function, RoundingMode::Round, \ TieBreakingMode::Bankers>>(); REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl) diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 417b88d3e6f9fc..470114d81c1451 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -851,64 +851,6 @@ class FunctionRounding : public IFunction { } }; -struct TruncateFloatOneArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } -}; - -struct TruncateFloatTwoArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared(), std::make_shared()}; - } -}; - -struct TruncateDecimalOneArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function - // So we can just register Decimal32 here - return {std::make_shared>(9, 0)}; - } -}; - -struct TruncateDecimalTwoArgImpl { - static constexpr auto name = "truncate"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0), - std::make_shared()}; - } -}; - -struct FloorFloatOneArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { return {std::make_shared()}; } -}; - -struct FloorFloatTwoArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared(), std::make_shared()}; - } -}; - -struct FloorDecimalOneArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - // All Decimal types are named Decimal, and real scale will be passed as type argument for execute function - // So we can just register Decimal32 here - return {std::make_shared>(9, 0)}; - } -}; - -struct FloorDecimalTwoArgImpl { - static constexpr auto name = "floor"; - static DataTypes get_variadic_argument_types() { - return {std::make_shared>(9, 0), - std::make_shared()}; - } -}; - struct TruncateName { static constexpr auto name = "truncate"; }; diff --git a/be/test/vec/function/function_round_test.cpp b/be/test/vec/function/function_round_test.cpp index b162000d7baa90..df9ca2c66b7800 100644 --- a/be/test/vec/function/function_round_test.cpp +++ b/be/test/vec/function/function_round_test.cpp @@ -15,18 +15,11 @@ // specific language governing permissions and limitations // under the License. -#include -#include - #include #include #include -#include -#include -#include #include #include -#include #include #include #include @@ -51,20 +44,27 @@ using DecimalTestDataSet = // input, scale_arg, expectation using FloatTestDataSet = std::vector>; -using TruncateFunction = FunctionRounding, RoundingMode::Trunc, - TieBreakingMode::Auto>; - -using FloorFunction = +using DecimalTruncateFunction = FunctionRounding, + RoundingMode::Trunc, TieBreakingMode::Auto>; +using DecimalFloorFunction = FunctionRounding, RoundingMode::Floor, + TieBreakingMode::Auto>; +using DecimalCeilFunction = + FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>; +using DecimalRoundFunction = FunctionRounding, RoundingMode::Round, + TieBreakingMode::Auto>; +using DecimalRoundBankersFunction = FunctionRounding, + RoundingMode::Round, TieBreakingMode::Bankers>; + +using FloatTruncateFunction = FunctionRounding, + RoundingMode::Trunc, TieBreakingMode::Auto>; +using FloatFloorFunction = FunctionRounding, RoundingMode::Floor, TieBreakingMode::Auto>; - -using CeilFunction = +using FloatCeilFunction = FunctionRounding, RoundingMode::Ceil, TieBreakingMode::Auto>; - -using RoundFunction = +using FloatRoundFunction = FunctionRounding, RoundingMode::Round, TieBreakingMode::Auto>; - -using RoundBankersFunction = FunctionRounding, - RoundingMode::Round, TieBreakingMode::Bankers>; +using FloatRoundBankersFunction = FunctionRounding, + RoundingMode::Round, TieBreakingMode::Bankers>; // test cases for truncate and floor function of decimal32 const static DecimalTestDataSet trunc_floor_decimal32_cases = { @@ -319,6 +319,7 @@ const static DecimalTestDataSet trunc_floor_decimal64_cases = { {123456789123456789, 25, 123456789123456789}, {123456789123456789, 26, 123456789123456789}}}}; +// test cases for ceil function of decimal32 const static DecimalTestDataSet ceil_decimal32_cases = { {{1, 0}, { @@ -1053,94 +1054,74 @@ static void float_checker(const FloatTestDataSet& round_test_cases, bool float_c /// tests for func(Column, Column) with decimal input TEST(RoundFunctionTest, normal_decimal) { - // truncate - decimal_checker(trunc_floor_decimal32_cases, false); - decimal_checker(trunc_floor_decimal64_cases, false); + decimal_checker(trunc_floor_decimal32_cases, false); + decimal_checker(trunc_floor_decimal64_cases, false); - // floor - decimal_checker(trunc_floor_decimal32_cases, false); - decimal_checker(trunc_floor_decimal64_cases, false); + decimal_checker(trunc_floor_decimal32_cases, false); + decimal_checker(trunc_floor_decimal64_cases, false); - // ceil - decimal_checker(ceil_decimal32_cases, false); - decimal_checker(ceil_decimal64_cases, false); + decimal_checker(ceil_decimal32_cases, false); + decimal_checker(ceil_decimal64_cases, false); - // round - decimal_checker(round_decimal32_cases, false); - decimal_checker(round_decimal64_cases, false); + decimal_checker(round_decimal32_cases, false); + decimal_checker(round_decimal64_cases, false); - // round_bankers - decimal_checker(round_decimal32_cases, false); - decimal_checker(round_decimal64_cases, false); + decimal_checker(round_decimal32_cases, false); + decimal_checker(round_decimal64_cases, false); } /// tests for func(ColumnConst, Column) with decimal input TEST(RoundFunctionTest, normal_decimal_const) { - // truncate - decimal_checker(trunc_floor_decimal32_cases, true); - decimal_checker(trunc_floor_decimal64_cases, true); + decimal_checker(trunc_floor_decimal32_cases, true); + decimal_checker(trunc_floor_decimal64_cases, true); - // floor - decimal_checker(trunc_floor_decimal32_cases, true); - decimal_checker(trunc_floor_decimal64_cases, true); + decimal_checker(trunc_floor_decimal32_cases, true); + decimal_checker(trunc_floor_decimal64_cases, true); - // ceil - decimal_checker(ceil_decimal32_cases, true); - decimal_checker(ceil_decimal64_cases, true); + decimal_checker(ceil_decimal32_cases, true); + decimal_checker(ceil_decimal64_cases, true); - // round - decimal_checker(round_decimal32_cases, true); - decimal_checker(round_decimal64_cases, true); + decimal_checker(round_decimal32_cases, true); + decimal_checker(round_decimal64_cases, true); - // round_bankers - decimal_checker(round_decimal32_cases, true); - decimal_checker(round_decimal64_cases, true); + decimal_checker(round_decimal32_cases, true); + decimal_checker(round_decimal64_cases, true); } /// tests for func(Column, Column) with float input TEST(RoundFunctionTest, normal_float) { - // truncate - float_checker(trunc_float32_cases, false); - float_checker(trunc_float64_cases, false); + float_checker(trunc_float32_cases, false); + float_checker(trunc_float64_cases, false); - // floor - float_checker(floor_float32_cases, false); - float_checker(floor_float64_cases, false); + float_checker(floor_float32_cases, false); + float_checker(floor_float64_cases, false); - // ceil - float_checker(ceil_float32_cases, false); - float_checker(ceil_float64_cases, false); + float_checker(ceil_float32_cases, false); + float_checker(ceil_float64_cases, false); - // round - float_checker(round_float32_cases, false); - float_checker(round_float64_cases, false); + float_checker(round_float32_cases, false); + float_checker(round_float64_cases, false); - // round_bankers - float_checker(round_bankers_float32_cases, false); - float_checker(round_bankers_float64_cases, false); + float_checker(round_bankers_float32_cases, false); + float_checker(round_bankers_float64_cases, false); } /// tests for func(ColumnConst, Column) with float input TEST(RoundFunctionTest, normal_float_const) { - // truncate - float_checker(trunc_float32_cases, true); - float_checker(trunc_float64_cases, true); + float_checker(trunc_float32_cases, true); + float_checker(trunc_float64_cases, true); - // floor - float_checker(floor_float32_cases, true); - float_checker(floor_float64_cases, true); + float_checker(floor_float32_cases, true); + float_checker(floor_float64_cases, true); - // ceil - float_checker(ceil_float32_cases, true); - float_checker(ceil_float64_cases, true); + float_checker(ceil_float32_cases, true); + float_checker(ceil_float64_cases, true); - // round - float_checker(round_float32_cases, true); - float_checker(round_float64_cases, true); + float_checker(round_float32_cases, true); + float_checker(round_float64_cases, true); - // round_bankers - float_checker(round_bankers_float32_cases, true); - float_checker(round_bankers_float64_cases, true); + float_checker(round_bankers_float32_cases, true); + float_checker(round_bankers_float64_cases, true); } } // namespace doris::vectorized diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out index 24f675ffbe29a2..80d77f50dc3cee 100644 --- a/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_function_truncate.out @@ -1,4 +1,7 @@ -- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +10.0 10.0 + -- !sql -- 0 123.3 1 123.3 diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out index 50d15b2843be16..1ebc9cf5b894b7 100644 --- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out +++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out @@ -140,3 +140,82 @@ -- !query -- 0.000 0.000 0.000 +-- !floor_dec9 -- +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 + +-- !floor_dec10 -- +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 + +-- !floor_flo -- +1 12345.123 12345.12 1.2345678912345679E8 1.23456789123E8 +1 12345.123 12345.12 1.2345678912345679E8 1.23456789123E8 + +-- !ceil_dec9 -- +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.200000000 +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.200000000 + +-- !ceil_dec10 -- +1 123456789 123456789 1.123456789 1.200000000 0.1234567890 0.2000000000 +1 123456789 123456789 1.123456789 1.200000000 0.1234567890 0.2000000000 + +-- !ceil_flo -- +1 12345.123 12346.0 1.2345678912345679E8 1.2345679E8 +1 12345.123 12346.0 1.2345678912345679E8 1.2345679E8 + +-- !round_dec9 -- +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 + +-- !round_dec10 -- +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 + +-- !round_flo -- +1 12345.123 12350.0 1.2345678912345679E8 1.234568E8 +1 12345.123 12350.0 1.2345678912345679E8 1.234568E8 + +-- !round_bankers_dec9 -- +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 +1 123456789 123456789 12345678.1 12345678.1 0.123456789 0.100000000 + +-- !round_bankers_dec10 -- +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 +1 123456789 123456789 1.123456789 1.100000000 0.1234567890 0.1000000000 + +-- !round_bankers_flo -- +1 12345.123 12350.0 1.2345678912345679E8 1.234568E8 +1 12345.123 12350.0 1.2345678912345679E8 1.234568E8 + +-- !all_funcs_compare_dec -- +5 1.123456789 1.123450000 1.123450000 1.123460000 1.123460000 1.123460000 + +-- !bankers_compare -- +2.5 0 3.0 2.0 + +-- !nested_func -- +1 2 + +-- !pos_zero_neg_compare -- +1 1.2345678912345679E8 1.234568E8 1.2345679E8 1.2345678913E8 +1 1.2345678912345679E8 1.234568E8 1.2345679E8 1.2345678913E8 + +-- !cast_dec -- +0E-8 0.00 + +-- !col_const_compare -- +1 1.123456789 1.100000000 1.1 1.100000000 1.1 + +-- !floor_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891 1234567891.100000000 0.1234567891234567891 0 0.1000000000000000000 + +-- !ceil_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891234567891 1234567891.123456789 1234567892 1234567891.200000000 0.1234567891234567891 1 0.2000000000000000000 + +-- !round_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891 1234567891.100000000 0.1234567891234567891 0 0.1000000000000000000 + +-- !round_bankers_dec128 -- +1 1234567891234567891 1234567891234567891 1234567891234567891 1234567891.123456789 1234567891 1234567891.100000000 0.1234567891234567891 0 0.1000000000000000000 + diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy index 767140e7a6ff85..71d4f67f25b304 100644 --- a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy @@ -16,6 +16,9 @@ // under the License. suite("test_function_truncate") { + // this single parameter test should has the same result as before + qt_sql """SELECT truncate(10.12345), truncate(cast(10.12345 as decimal(7, 5)));""" + qt_sql """ SELECT number, truncate(123.345 , 1) FROM numbers("number"="10"); """ diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy index efdc003fbd4a73..1d8bbb9df49513 100644 --- a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy @@ -143,4 +143,125 @@ qt_query """ select cast(round(sum(d1), 2) as decimalv3(27, 3)), cast(round(sum(d2), 2) as decimalv3(27, 3)), cast(round(sum(d3),2) as decimalv3(27, 3)) from ${tableName3} """ qt_query """ select cast(round(sum(d1), -2) as decimalv3(27, 3)), cast(round(sum(d2), -2) as decimalv3(27, 3)), cast(round(sum(d3), -2) as decimalv3(27, 3)) from ${tableName3} """ qt_query """ select cast(round(sum(d1), -4) as decimalv3(27, 3)), cast(round(sum(d2), -4) as decimalv3(27, 3)), cast(round(sum(d3), -4) as decimalv3(27, 3)) from ${tableName3} """ + + /// Testing with enhanced round function, which can deal with scale being a column, like this: + /// func(Column, Column), func(ColumnConst, Column). + /// Consider truncate() has been tested in test_function_truncate.groovy, so we focus on the rest here. + sql """DROP TABLE IF EXISTS test_enhanced_round;""" + sql """ + CREATE TABLE test_enhanced_round ( + rid int, flo float, dou double, + dec90 decimal(9, 0), dec91 decimal(9, 1), dec99 decimal(9, 9), + dec100 decimal(10,0), dec109 decimal(10,9), dec1010 decimal(10,10), + number int DEFAULT 1) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + sql """ + INSERT INTO test_enhanced_round + VALUES + (1, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + sql """ + INSERT INTO test_enhanced_round + VALUES + (2, 12345.123, 123456789.123456789, + 123456789, 12345678.1, 0.123456789, + 123456789.1, 1.123456789, 0.123456789, 1); + """ + qt_floor_dec9 """ + SELECT number, dec90, floor(dec90, number), dec91, floor(dec91, number), dec99, floor(dec99, number) FROM test_enhanced_round order by rid; + """ + qt_floor_dec10 """ + SELECT number, dec100, floor(dec100, number), dec109, floor(dec109, number), dec1010, floor(dec1010, number) FROM test_enhanced_round order by rid; + """ + qt_floor_flo """ + SELECT number, flo, floor(flo, number + 1), dou, floor(dou, number + 2) FROM test_enhanced_round order by rid; + """ + qt_ceil_dec9 """ + SELECT number, dec90, ceil(dec90, number), dec91, ceil(dec91, number), dec99, ceil(dec99, number) FROM test_enhanced_round order by rid; + """ + qt_ceil_dec10 """ + SELECT number, dec100, ceil(dec100, number), dec109, ceil(dec109, number), dec1010, ceil(dec1010, number) FROM test_enhanced_round order by rid; + """ + qt_ceil_flo """ + SELECT number, flo, ceil(flo, number - 1), dou, ceil(dou, number - 2) FROM test_enhanced_round order by rid; + """ + qt_round_dec9 """ + SELECT number, dec90, round(dec90, number), dec91, round(dec91, number), dec99, round(dec99, number) FROM test_enhanced_round order by rid; + """ + qt_round_dec10 """ + SELECT number, dec100, round(dec100, number), dec109, round(dec109, number), dec1010, round(dec1010, number) FROM test_enhanced_round order by rid; + """ + qt_round_flo """ + SELECT number, flo, round(flo, number - 2), dou, round(dou, number - 3) FROM test_enhanced_round order by rid; + """ + qt_round_bankers_dec9 """ + SELECT number, dec90, round_bankers(dec90, number), dec91, round_bankers(dec91, number), dec99, round_bankers(dec99, number) FROM test_enhanced_round order by rid; + """ + qt_round_bankers_dec10 """ + SELECT number, dec100, round_bankers(dec100, number), dec109, round_bankers(dec109, number), dec1010, round_bankers(dec1010, number) FROM test_enhanced_round order by rid; + """ + qt_round_bankers_flo """ + SELECT number, flo, round_bankers(flo, number - 2), dou, round_bankers(dou, number - 3) FROM test_enhanced_round order by rid; + """ + + qt_all_funcs_compare_dec """ + SELECT number + 4 as new_number, dec109, truncate(dec109, number + 4) as t_res, floor(dec109, number + 4) as f_res, ceil(dec109, number + 4) as c_res, round(dec109, number + 4) as r_res, + round_bankers(dec109, number + 4) as rb_res FROM test_enhanced_round where rid = 1; + """ + qt_bankers_compare """ + SELECT number * 2.5 as input1, number - 1 as input2, round(number * 2.5, number - 1) as r_res, round_bankers(number * 2.5, number - 1) as rb_res FROM test_enhanced_round where rid = 1; + """ + qt_nested_func """ + SELECT number, floor(floor(number * floor(number) + 1), ceil(floor(number))) as nested_col FROM test_enhanced_round where rid = 1; + """ + qt_pos_zero_neg_compare """ + SELECT number, dou, ceil(dou, (-2) * number), ceil(dou, 0 * number), ceil(dou, 2 * number) FROM test_enhanced_round; + """ + qt_cast_dec """ + SELECT round(cast(0 as Decimal(9,8)), 10), round(cast(0 as Decimal(9,8)), 2); + """ + //For func(x, d), if d is a column and x has Decimal type, scale of result Decimal will always be same with input Decimal. + qt_col_const_compare """ + SELECT number, dec109, floor(dec109, number) as f_col_col, floor(dec109, 1) as f_col_const, + floor(1.123456789, number) as f_const_col, floor(1.123456789, 1) as f_const_const FROM test_enhanced_round limit 1; + """ + + sql """DROP TABLE IF EXISTS test_enhanced_round_dec128;""" + sql """ + CREATE TABLE test_enhanced_round_dec128 ( + rid int, dec190 decimal(19,0), dec199 decimal(19,9), dec1919 decimal(19,19), + dec380 decimal(38,0), dec3819 decimal(38,19), dec3838 decimal(38,38), + number int DEFAULT 1 + ) + DISTRIBUTED BY HASH(rid) + PROPERTIES("replication_num" = "1" ); + """ + sql """ + INSERT INTO test_enhanced_round_dec128 + VALUES + (1, 1234567891234567891.0, 1234567891.123456789, 0.1234567891234567891, + 12345678912345678912345678912345678912.0, + 1234567891234567891.1234567891234567891, + 0.12345678912345678912345678912345678912345678912345678912345678912345678912, 1); + """ + qt_floor_dec128 """ + SELECT number, dec190, floor(dec190, 0), floor(dec190, number), dec199, floor(dec199, 0), floor(dec199, number), + dec1919, floor(dec1919, 0), floor(dec1919, number) FROM test_enhanced_round_dec128 order by rid; + """ + qt_ceil_dec128 """ + SELECT number, dec190, ceil(dec190, 0), ceil(dec190, number), dec199, ceil(dec199, 0), ceil(dec199, number), + dec1919, ceil(dec1919, 0), ceil(dec1919, number) FROM test_enhanced_round_dec128 order by rid; + """ + qt_round_dec128 """ + SELECT number, dec190, round(dec190, 0), round(dec190, number), dec199, round(dec199, 0), round(dec199, number), + dec1919, round(dec1919, 0), round(dec1919, number) FROM test_enhanced_round_dec128 order by rid; + """ + qt_round_bankers_dec128 """ + SELECT number, dec190, round_bankers(dec190, 0), round_bankers(dec190, number), dec199, round_bankers(dec199, 0), round_bankers(dec199, number), + dec1919, round_bankers(dec1919, 0), round_bankers(dec1919, number) FROM test_enhanced_round_dec128 order by rid; + """ } From cbd0f2eee453123ae8c984adca827c933bf0876d Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sun, 5 May 2024 11:45:12 +0800 Subject: [PATCH 5/7] opt --- be/src/vec/functions/math.cpp | 16 ++------- be/src/vec/functions/round.cpp | 6 +++- be/src/vec/functions/round.h | 6 ++-- be/test/vec/function/function_round_test.cpp | 6 ++++ .../functions/ComputePrecisionForRound.java | 35 ++++--------------- .../test_function_truncate.groovy | 3 +- 6 files changed, 25 insertions(+), 47 deletions(-) diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp index 3596d2ebae6fb0..bc086ce444730f 100644 --- a/be/src/vec/functions/math.cpp +++ b/be/src/vec/functions/math.cpp @@ -15,29 +15,20 @@ // specific language governing permissions and limitations // under the License. -#include -#include +#include +#include -#include // IWYU pragma: no_include -#include #include -#include #include #include -#include #include "common/status.h" -#include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column.h" #include "vec/columns/column_string.h" #include "vec/columns/column_vector.h" #include "vec/columns/columns_number.h" #include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_decimal.h" -#include "vec/data_types/data_type_nullable.h" -#include "vec/data_types/data_type_number.h" #include "vec/data_types/data_type_string.h" #include "vec/data_types/number_traits.h" #include "vec/functions/function_binary_arithmetic.h" @@ -337,8 +328,6 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); - factory.register_alias("ceil", "dceil"); - factory.register_alias("ceil", "ceiling"); factory.register_function(); factory.register_alias("ln", "dlog1"); factory.register_function(); @@ -357,7 +346,6 @@ void register_function_math(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); - factory.register_alias("floor", "dfloor"); factory.register_function(); factory.register_alias("pow", "power"); factory.register_alias("pow", "dpow"); diff --git a/be/src/vec/functions/round.cpp b/be/src/vec/functions/round.cpp index 2ac0e3f6bda24b..6b6fc42d653c33 100644 --- a/be/src/vec/functions/round.cpp +++ b/be/src/vec/functions/round.cpp @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "vec/functions/round.h" +#include "round.h" #include "vec/functions/simple_function_factory.h" @@ -40,6 +40,10 @@ void register_function_round(SimpleFunctionFactory& factory) { REGISTER_ROUND_FUNCTIONS(DecimalRoundTwoImpl) REGISTER_ROUND_FUNCTIONS(DoubleRoundOneImpl) REGISTER_ROUND_FUNCTIONS(DoubleRoundTwoImpl) + + factory.register_alias("ceil", "dceil"); + factory.register_alias("ceil", "ceiling"); + factory.register_alias("floor", "dfloor"); factory.register_alias("round", "dround"); } diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index 470114d81c1451..ecfda317675893 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -21,7 +21,6 @@ #pragma once #include -#include #include "common/exception.h" #include "common/status.h" @@ -557,7 +556,7 @@ struct Dispatcher { static ColumnPtr apply_const_vec(const ColumnConst* const_col_general, const IColumn* col_scale) { - const ColumnInt32& col_scale_i32 = assert_cast(*col_scale); + const auto& col_scale_i32 = assert_cast(*col_scale); const size_t input_rows_count = col_scale->size(); for (size_t i = 0; i < input_rows_count; ++i) { @@ -685,6 +684,7 @@ class FunctionRounding : public IFunction { // should NOT behave like two column arguments, so we can not use const column default implementation bool use_default_implementation_for_constants() const override { return false; } + // we moved the execute logic of function_truncate.h from PR#32746 and make it suitable for all functions Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); @@ -744,7 +744,7 @@ class FunctionRounding : public IFunction { if (arguments.size() == 2) { RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); } else if (is_column_const(*column_general.column)) { - // if we only have one ColumnConst + // if we only have one ColumnConst, we should cast it, otherwise it would cause BE crash col_general = assert_cast(*column_general.column) .get_data_column_ptr(); } diff --git a/be/test/vec/function/function_round_test.cpp b/be/test/vec/function/function_round_test.cpp index df9ca2c66b7800..10ad1c2c0f18f3 100644 --- a/be/test/vec/function/function_round_test.cpp +++ b/be/test/vec/function/function_round_test.cpp @@ -36,6 +36,12 @@ #include "vec/data_types/data_type_number.h" #include "vec/functions/round.h" +/** + This BE UT focus on enhancement of round based function, which enables them + to use column as scale argument. We test truncate/floor/ceil/round/round_bankers + together by moving test cases of function_truncate_test.cpp here. +*/ + namespace doris::vectorized { // {precision, scale} -> {input, scale_arg, expectation} using DecimalTestDataSet = diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java index bea55b1a1f160a..eedbfea6df9ac6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java @@ -20,17 +20,10 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.scalar.Ceil; -import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor; -import org.apache.doris.nereids.trees.expressions.functions.scalar.Round; -import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers; -import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.Int32OrLessType; -import com.google.common.base.Preconditions; - /** ComputePrecisionForRound */ public interface ComputePrecisionForRound extends ComputePrecision { @Override @@ -44,33 +37,19 @@ default FunctionSignature computePrecision(FunctionSignature signature) { Expression floatLength = getArgument(1); int scale; - if (this instanceof Truncate || this instanceof Floor || this instanceof Ceil || this instanceof Round - || this instanceof RoundBankers) { - if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { - // Scale argument is a literal or cast from other literal - if (floatLength instanceof Cast) { - scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); - } else { - scale = ((IntegerLikeLiteral) floatLength).getIntValue(); - } - scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); - } else { - // Truncate could use Column as its scale argument. - // Result scale will always same with input Decimal in this situation. - scale = decimalV3Type.getScale(); - } - } else { - Preconditions.checkArgument(floatLength.getDataType() instanceof Int32OrLessType - && (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() - && floatLength.child(0).getDataType() instanceof Int32OrLessType)), - "2nd argument of function round/floor/ceil must be literal"); + if (floatLength.isLiteral() || (floatLength instanceof Cast && floatLength.child(0).isLiteral() + && floatLength.child(0).getDataType() instanceof Int32OrLessType)) { + // Scale argument is a literal or cast from other literal if (floatLength instanceof Cast) { scale = ((IntegerLikeLiteral) floatLength.child(0)).getIntValue(); } else { scale = ((IntegerLikeLiteral) floatLength).getIntValue(); } scale = Math.min(Math.max(scale, 0), decimalV3Type.getScale()); + } else { + // Func could use Column as its scale argument. + // Result scale will always same with input Decimal in this situation. + scale = decimalV3Type.getScale(); } return signature.withArgumentType(0, decimalV3Type) diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy index 71d4f67f25b304..b7c36dfbaa1151 100644 --- a/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy +++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_function_truncate.groovy @@ -16,7 +16,8 @@ // under the License. suite("test_function_truncate") { - // this single parameter test should has the same result as before + // NOTICE: This single const argument test should never cause BE crash, + // like branch2.0's behavior, so we added it to check. qt_sql """SELECT truncate(10.12345), truncate(cast(10.12345 as decimal(7, 5)));""" qt_sql """ From 262d824d1afa3ff9c730acc83041d5402ef7acba Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sun, 5 May 2024 23:02:35 +0800 Subject: [PATCH 6/7] opt execute_impl --- be/src/vec/functions/round.h | 200 +++++++++-------------------------- 1 file changed, 49 insertions(+), 151 deletions(-) diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index ecfda317675893..e1dab731b9fd95 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -516,7 +516,7 @@ struct Dispatcher { } else if constexpr (IsDecimalNumber) { const auto* decimal_col = assert_cast*>(col_general); - // For truncate, ALWAYS use SAME scale with source Decimal column + // ALWAYS use SAME scale with source Decimal column const Int32 input_scale = decimal_col->get_scale(); auto col_res = ColumnDecimal::create(input_row_count, input_scale); @@ -679,171 +679,69 @@ class FunctionRounding : public IFunction { return Status::OK(); } - ColumnNumbers get_arguments_that_are_always_constant() const override { return {}; } - // SELECT number, func(123.345, 1) FROM numbers("number"="10") - // should NOT behave like two column arguments, so we can not use const column default implementation - bool use_default_implementation_for_constants() const override { return false; } - - // we moved the execute logic of function_truncate.h from PR#32746 and make it suitable for all functions + //// We moved and optimized the execute_impl logic of function_truncate.h from PR#32746, + //// as well as make it suitable for all functions. Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); + const auto* col_general = column_general.column.get(); ColumnPtr res; - // potential argument types: - // 0. func(ColumnConst, ColumnConst) - // 1. func(Column), func(ColumnConst), func(Column, ColumnConst) - // 2. func(Column, Column) - // 3. func(ColumnConst, Column) - - if (arguments.size() == 2 && is_column_const(*block.get_by_position(arguments[0]).column) && - is_column_const(*block.get_by_position(arguments[1]).column)) { - // func(ColumnConst, ColumnConst) - auto col_general = - assert_cast(*column_general.column).get_data_column_ptr(); - Int16 scale_arg = 0; - RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; + /// potential argument types(optimized from four types in previous PR to two): + /// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have 1st type: + /// 1. func(Column), func(ColumnConst), func(Column, ColumnConst), func(ColumnConst, ColumnConst) + /// otherwise, the SECOND arugment is COLUMN, we have another type: + /// 2. func(Column, Column), func(ColumnConst, Column) + + auto call = [&](const auto& types) -> bool { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { + using FieldType = typename DataType::FieldType; + // the SECOND argument is MISSING or CONST + if (arguments.size() == 1 || + is_column_const(*block.get_by_position(arguments[1]).column)) { + Int16 scale_arg = 0; + if (arguments.size() == 2) { + RETURN_IF_ERROR( + get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); + } res = Dispatcher::apply_vec_const( col_general, scale_arg); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), get_name()); - } - // Important, make sure the result column has the same size as the input column - res = ColumnConst::create(std::move(res), input_rows_count); - } else if (arguments.size() == 1 || - (arguments.size() == 2 && - is_column_const(*block.get_by_position(arguments[1]).column))) { - // func(Column) or func(ColumnConst) or func(Column, ColumnConst) - Int16 scale_arg = 0; - const auto* col_general = column_general.column.get(); - if (arguments.size() == 2) { - RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); - } else if (is_column_const(*column_general.column)) { - // if we only have one ColumnConst, we should cast it, otherwise it would cause BE crash - col_general = assert_cast(*column_general.column) - .get_data_column_ptr(); - } - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_vec_const( - col_general, scale_arg); - return true; - } - - return false; - }; -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } - } -#endif - - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), get_name()); - } - - } else if (is_column_const(*block.get_by_position(arguments[0]).column)) { - // func(ColumnConst, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - const auto& const_col_general = assert_cast(*column_general.column); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_const_vec( - &const_col_general, column_scale.column.get()); - return true; - } - - return false; - }; - -#if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); + } else { + // the SECOND arugment is COLUMN + if (is_column_const(*column_general.column)) { + const auto& const_col_general = + assert_cast(*column_general.column); + res = Dispatcher:: + apply_const_vec(&const_col_general, + block.get_by_position(arguments[1]).column.get()); + } else { + res = Dispatcher:: + apply_vec_vec(col_general, + block.get_by_position(arguments[1]).column.get()); + } } + return true; } -#endif - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), get_name()); - } - } else { - // func(Column, Column) - const ColumnWithTypeAndName& column_scale = block.get_by_position(arguments[1]); - - auto call = [&](const auto& types) -> bool { - using Types = std::decay_t; - using DataType = typename Types::LeftType; - - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { - using FieldType = typename DataType::FieldType; - res = Dispatcher::apply_vec_vec( - column_general.column.get(), column_scale.column.get()); - return true; - } - return false; - }; + return false; + }; #if !defined(__SSE4_1__) && !defined(__aarch64__) - /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. - /// Actually it is by default. But we will set it just in case. - - if constexpr (rounding_mode == RoundingMode::Round) { - if (0 != fesetround(FE_TONEAREST)) { - return Status::InvalidArgument("Cannot set floating point rounding mode"); - } + /// In case of "nearbyint" function is used, we should ensure the expected rounding mode for the Banker's rounding. + /// Actually it is by default. But we will set it just in case. + if constexpr (rounding_mode == RoundingMode::Round) { + if (0 != fesetround(FE_TONEAREST)) { + return Status::InvalidArgument("Cannot set floating point rounding mode"); } + } #endif - if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { - return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), get_name()); - } + if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { + return Status::InvalidArgument("Invalid argument type {} for function {}", + column_general.type->get_name(), get_name()); } block.replace_by_position(result, std::move(res)); From 246a386afb77fc6c94a82e53f88ef984216c87f7 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Mon, 6 May 2024 17:09:34 +0800 Subject: [PATCH 7/7] fix execute_impl --- be/src/vec/functions/round.h | 40 ++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h index e1dab731b9fd95..70d41bc5fe05dd 100644 --- a/be/src/vec/functions/round.h +++ b/be/src/vec/functions/round.h @@ -630,9 +630,9 @@ struct Dispatcher { return col_res; } else { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Unsupported column {} for function truncate", - const_col_general->get_name()); + LOG(FATAL) << "__builtin_unreachable"; + __builtin_unreachable(); + return nullptr; } } }; @@ -679,16 +679,25 @@ class FunctionRounding : public IFunction { return Status::OK(); } + /// SELECT number, truncate(123.345, 1) FROM number("numbers"="10") + /// should NOT behave like two column arguments, so we can not use const column default implementation + bool use_default_implementation_for_constants() const override { return false; } + //// We moved and optimized the execute_impl logic of function_truncate.h from PR#32746, //// as well as make it suitable for all functions. Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { const ColumnWithTypeAndName& column_general = block.get_by_position(arguments[0]); - const auto* col_general = column_general.column.get(); + const bool is_col_general_const = is_column_const(*column_general.column); + const auto* col_general = is_col_general_const + ? assert_cast(*column_general.column) + .get_data_column_ptr() + : column_general.column.get(); + ColumnPtr res; - /// potential argument types(optimized from four types in previous PR to two): - /// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have 1st type: + /// potential argument types: + /// if the SECOND argument is MISSING(would be considered as ZERO const) or CONST, then we have the following type: /// 1. func(Column), func(ColumnConst), func(Column, ColumnConst), func(ColumnConst, ColumnConst) /// otherwise, the SECOND arugment is COLUMN, we have another type: /// 2. func(Column, Column), func(ColumnConst, Column) @@ -699,24 +708,29 @@ class FunctionRounding : public IFunction { if constexpr (IsDataTypeNumber || IsDataTypeDecimal) { using FieldType = typename DataType::FieldType; - // the SECOND argument is MISSING or CONST if (arguments.size() == 1 || is_column_const(*block.get_by_position(arguments[1]).column)) { + // the SECOND argument is MISSING or CONST Int16 scale_arg = 0; if (arguments.size() == 2) { RETURN_IF_ERROR( get_scale_arg(block.get_by_position(arguments[1]), &scale_arg)); } + res = Dispatcher::apply_vec_const( col_general, scale_arg); + + if (arguments.size() == 2 && is_col_general_const) { + // Important, make sure the result column has the same size as the input column + res = ColumnConst::create(std::move(res), input_rows_count); + } } else { // the SECOND arugment is COLUMN - if (is_column_const(*column_general.column)) { - const auto& const_col_general = - assert_cast(*column_general.column); + if (is_col_general_const) { res = Dispatcher:: - apply_const_vec(&const_col_general, - block.get_by_position(arguments[1]).column.get()); + apply_const_vec( + &assert_cast(*column_general.column), + block.get_by_position(arguments[1]).column.get()); } else { res = Dispatcher:: apply_vec_vec(col_general, @@ -741,7 +755,7 @@ class FunctionRounding : public IFunction { if (!call_on_index_and_data_type(column_general.type->get_type_id(), call)) { return Status::InvalidArgument("Invalid argument type {} for function {}", - column_general.type->get_name(), get_name()); + column_general.type->get_name(), name); } block.replace_by_position(result, std::move(res));