diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index f8b90085010..5e04440e5b1 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -17,6 +17,7 @@ #include "arrow/compute/kernels/codegen_internal.h" +#include #include #include #include @@ -341,6 +342,91 @@ std::shared_ptr CommonBinary(const std::vector& descrs) { return large_binary(); } +Status CastBinaryDecimalArgs(DecimalPromotion promotion, + std::vector* descrs) { + auto& left_type = (*descrs)[0].type; + auto& right_type = (*descrs)[1].type; + DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); + + // decimal + float = float + if (is_floating(left_type->id())) { + right_type = left_type; + return Status::OK(); + } else if (is_floating(right_type->id())) { + left_type = right_type; + return Status::OK(); + } + + // precision, scale of left and right args + int32_t p1, s1, p2, s2; + + // decimal + integer = decimal + if (is_decimal(left_type->id())) { + auto decimal = checked_cast(left_type.get()); + p1 = decimal->precision(); + s1 = decimal->scale(); + } else { + DCHECK(is_integer(left_type->id())); + ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id())); + s1 = 0; + } + if (is_decimal(right_type->id())) { + auto decimal = checked_cast(right_type.get()); + p2 = decimal->precision(); + s2 = decimal->scale(); + } else { + DCHECK(is_integer(right_type->id())); + ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id())); + s2 = 0; + } + if (s1 < 0 || s2 < 0) { + return Status::NotImplemented("Decimals with negative scales not supported"); + } + + // decimal128 + decimal256 = decimal256 + Type::type casted_type_id = Type::DECIMAL128; + if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } + + // decimal promotion rules compatible with amazon redshift + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + int32_t left_scaleup, right_scaleup; + + switch (promotion) { + case DecimalPromotion::kAdd: { + left_scaleup = std::max(s1, s2) - s1; + right_scaleup = std::max(s1, s2) - s2; + break; + } + case DecimalPromotion::kMultiply: { + left_scaleup = right_scaleup = 0; + break; + } + case DecimalPromotion::kDivide: { + left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; + right_scaleup = 0; + break; + } + default: + DCHECK(false) << "Invalid DecimalPromotion value " << static_cast(promotion); + } + ARROW_ASSIGN_OR_RAISE( + left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); + ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, + s2 + right_scaleup)); + return Status::OK(); +} + +bool HasDecimal(const std::vector& descrs) { + for (const auto& descr : descrs) { + if (is_decimal(descr.type->id())) { + return true; + } + } + return false; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 9c8b2cef198..951b502b991 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1313,6 +1313,19 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) ARROW_EXPORT std::shared_ptr CommonBinary(const std::vector& descrs); +/// How to promote decimal precision/scale in CastBinaryDecimalArgs. +enum class DecimalPromotion : uint8_t { + kAdd, + kMultiply, + kDivide, +}; + +ARROW_EXPORT +Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* descrs); + +ARROW_EXPORT +bool HasDecimal(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 7692f037124..51b79ab78de 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -965,78 +965,6 @@ ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) { } } -Status CastBinaryDecimalArgs(const std::string& func_name, - std::vector* values) { - auto& left_type = (*values)[0].type; - auto& right_type = (*values)[1].type; - DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); - - // decimal + float = float - if (is_floating(left_type->id())) { - right_type = left_type; - return Status::OK(); - } else if (is_floating(right_type->id())) { - left_type = right_type; - return Status::OK(); - } - - // precision, scale of left and right args - int32_t p1, s1, p2, s2; - - // decimal + integer = decimal - if (is_decimal(left_type->id())) { - auto decimal = checked_cast(left_type.get()); - p1 = decimal->precision(); - s1 = decimal->scale(); - } else { - DCHECK(is_integer(left_type->id())); - p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); - s1 = 0; - } - if (is_decimal(right_type->id())) { - auto decimal = checked_cast(right_type.get()); - p2 = decimal->precision(); - s2 = decimal->scale(); - } else { - DCHECK(is_integer(right_type->id())); - p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); - s2 = 0; - } - if (s1 < 0 || s2 < 0) { - return Status::NotImplemented("Decimals with negative scales not supported"); - } - - // decimal128 + decimal256 = decimal256 - Type::type casted_type_id = Type::DECIMAL128; - if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { - casted_type_id = Type::DECIMAL256; - } - - // decimal promotion rules compatible with amazon redshift - // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html - int32_t left_scaleup, right_scaleup; - - // "add_checked" -> "add" - const std::string op = func_name.substr(0, func_name.find("_")); - if (op == "add" || op == "subtract") { - left_scaleup = std::max(s1, s2) - s1; - right_scaleup = std::max(s1, s2) - s2; - } else if (op == "multiply") { - left_scaleup = right_scaleup = 0; - } else if (op == "divide") { - left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; - right_scaleup = 0; - } else { - return Status::Invalid("Invalid decimal function: ", func_name); - } - - ARROW_ASSIGN_OR_RAISE( - left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); - ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, - s2 + right_scaleup)); - return Status::OK(); -} - // resolve decimal binary operation output type per *casted* args template Result ResolveDecimalBinaryOperationOutput( @@ -1166,17 +1094,21 @@ struct ArithmeticFunction : ScalarFunction { } Status CheckDecimals(std::vector* values) const { - bool has_decimal = false; - for (const auto& value : *values) { - if (is_decimal(value.type->id())) { - has_decimal = true; - break; - } - } - if (!has_decimal) return Status::OK(); + if (!HasDecimal(*values)) return Status::OK(); if (values->size() == 2) { - return CastBinaryDecimalArgs(name(), values); + // "add_checked" -> "add" + const auto func_name = name(); + const std::string op = func_name.substr(0, func_name.find("_")); + if (op == "add" || op == "subtract") { + return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values); + } else if (op == "multiply") { + return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values); + } else if (op == "divide") { + return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values); + } else { + return Status::Invalid("Invalid decimal function: ", func_name); + } } return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index e9cf9284ceb..1ce6896deba 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -420,11 +420,8 @@ struct CastFunctor decimal_digits{3, 5, 10, 19}; - using ctype = typename I::c_type; - static_assert(sizeof(ctype) <= 8, ""); - const int precision = decimal_digits[BitUtil::Log2(sizeof(ctype))] + out_scale; + ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id)); + precision += out_scale; if (out_precision < precision) { return Status::Invalid( "Precision is not great enough for the result. " diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 4342d776c38..6d571867ec5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -172,6 +172,8 @@ struct CompareFunction : ScalarFunction { ReplaceTypes(type, values); } else if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); + } else if (HasDecimal(*values)) { + RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values)); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -259,6 +261,12 @@ std::shared_ptr MakeCompareFunction(std::string name, DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); } + for (const auto id : DecimalTypeIds()) { + auto exec = GenerateDecimal(id); + DCHECK_OK( + func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); + } + return func; } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 37680945a3e..a5bc89d87f3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -453,6 +453,109 @@ TEST(TestCompareTimestamps, Basics) { CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]"); } +template +class TestCompareDecimal : public ::testing::Test {}; +TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes); + +TYPED_TEST(TestCompareDecimal, ArrayScalar) { + auto ty = std::make_shared(3, 2); + + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, null]"), + std::make_pair("not_equal", "[0, 1, 1, null]"), + std::make_pair("less", "[0, 0, 1, null]"), + std::make_pair("less_equal", "[1, 0, 1, null]"), + std::make_pair("greater", "[0, 1, 0, null]"), + std::make_pair("greater_equal", "[1, 1, 0, null]"), + }; + + auto lhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"); + auto lhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]"); + auto lhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])"); + auto rhs = ScalarFromJSON(ty, R"("1.23")"); + auto rhs_float = ScalarFromJSON(float64(), "1.23"); + auto rhs_int = ScalarFromJSON(int64(), "1"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected)); + } +} + +TYPED_TEST(TestCompareDecimal, ScalarArray) { + auto ty = std::make_shared(3, 2); + + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, null]"), + std::make_pair("not_equal", "[0, 1, 1, null]"), + std::make_pair("less", "[0, 1, 0, null]"), + std::make_pair("less_equal", "[1, 1, 0, null]"), + std::make_pair("greater", "[0, 0, 1, null]"), + std::make_pair("greater_equal", "[1, 0, 1, null]"), + }; + + auto lhs = ScalarFromJSON(ty, R"("1.23")"); + auto lhs_float = ScalarFromJSON(float64(), "1.23"); + auto lhs_int = ScalarFromJSON(int64(), "1"); + auto rhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"); + auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]"); + auto rhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_int, rhs_intlike, ArrayFromJSON(boolean(), expected)); + } +} + +TYPED_TEST(TestCompareDecimal, ArrayArray) { + auto ty = std::make_shared(3, 2); + + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, 1, 0, 0, null, null]"), + std::make_pair("not_equal", "[0, 1, 1, 0, 1, 1, null, null]"), + std::make_pair("less", "[0, 1, 0, 0, 1, 0, null, null]"), + std::make_pair("less_equal", "[1, 1, 0, 1, 1, 0, null, null]"), + std::make_pair("greater", "[0, 0, 1, 0, 0, 1, null, null]"), + std::make_pair("greater_equal", "[1, 0, 1, 1, 0, 1, null, null]"), + }; + + auto lhs = ArrayFromJSON( + ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])"); + auto lhs_float = + ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]"); + auto lhs_intlike = ArrayFromJSON( + ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])"); + auto rhs = ArrayFromJSON( + ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])"); + auto rhs_float = + ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]"); + auto rhs_int = ArrayFromJSON(int64(), "[1, 2, 1, -1, 1, -1, null, 1]"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"), + ArrayFromJSON(boolean(), "[]")); + CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"), + ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]")); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected)); + } +} + TEST(TestCompareKernel, DispatchBest) { for (std::string name : {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { @@ -490,6 +593,17 @@ TEST(TestCompareKernel, DispatchBest) { CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()}); CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()}); + + CheckDispatchBest(name, {decimal128(3, 2), decimal128(6, 3)}, + {decimal128(4, 3), decimal128(6, 3)}); + CheckDispatchBest(name, {decimal128(3, 2), decimal256(3, 2)}, + {decimal256(3, 2), decimal256(3, 2)}); + CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()}); + CheckDispatchBest(name, {decimal128(3, 2), int64()}, + {decimal128(3, 2), decimal128(3, 2)}); + CheckDispatchBest(name, {int64(), decimal128(3, 2)}, + {decimal128(3, 2), decimal128(3, 2)}); } } diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index 4a158728833..da88fbeb379 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -25,6 +25,7 @@ #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_fwd.h" #include "arrow/util/basic_decimal.h" #include "arrow/util/string_view.h" @@ -288,4 +289,26 @@ struct Decimal256::ToRealConversion { } }; +/// For an integer type, return the max number of decimal digits +/// (=minimal decimal precision) it can represent. +inline Result MaxDecimalDigitsForInteger(Type::type type_id) { + switch (type_id) { + case Type::INT8: + case Type::UINT8: + return 3; + case Type::INT16: + case Type::UINT16: + return 5; + case Type::INT32: + case Type::UINT32: + return 10; + case Type::INT64: + case Type::UINT64: + return 19; + default: + break; + } + return Status::Invalid("Not an integer type: ", type_id); +} + } // namespace arrow diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 21955575d61..3b3bd336740 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -368,6 +368,12 @@ then typically wraps around). Most functions are also available in an overflow-checking variant, suffixed ``_checked``, which returns an ``Invalid`` :class:`Status` when overflow is detected. +For functions which support decimal inputs (currently ``add``, ``subtract``, +``multiply``, and ``divide`` and their checked variants), decimals of different +precisions/scales will be promoted appropriately. Mixed decimal and +floating-point arguments will cast all arguments to floating-point, while mixed +decimal and integer arguments will cast all arguments to decimals. + +------------------+--------+----------------+----------------------+-------+ | Function name | Arity | Input types | Output type | Notes | +==================+========+================+======================+=======+ @@ -542,7 +548,8 @@ cast to the :ref:`common numeric type ` before comparison), or two inputs of Binary- or String-like types, or two inputs of Temporal types. If any input is dictionary encoded it will be expanded for the purposes of comparison. If any of the input elements in a pair is null, the corresponding -output element is null. +output element is null. Decimal arguments will be promoted in the same way as +for ``add`` and ``subtract``. +--------------------------+------------+---------------------------------------------+---------------------+ | Function names | Arity | Input types | Output type |