From 332699bafdb9c35833c33c364b212a11c209b672 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 10 Sep 2021 09:11:14 -0400 Subject: [PATCH 1/3] ARROW-13966: [C++] Support decimals in comparisons --- .../arrow/compute/kernels/codegen_internal.cc | 86 +++++++++++++++++ .../arrow/compute/kernels/codegen_internal.h | 13 +++ .../compute/kernels/scalar_arithmetic.cc | 94 +++---------------- .../arrow/compute/kernels/scalar_compare.cc | 8 ++ .../compute/kernels/scalar_compare_test.cc | 59 ++++++++++++ docs/source/cpp/compute.rst | 21 +++-- 6 files changed, 190 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index f8b90085010..ac1a55b6433 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -17,6 +17,7 @@ #include "arrow/compute/kernels/codegen_internal.h" +#include #include #include #include @@ -341,6 +342,91 @@ std::shared_ptr CommonBinary(const std::vector& descrs) { return large_binary(); } +Status CastBinaryDecimalArgs(DecimalPromotion promotion, + std::vector* descrs) { + auto& left_type = (*descrs)[0].type; + auto& right_type = (*descrs)[1].type; + DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); + + // decimal + float = float + if (is_floating(left_type->id())) { + right_type = left_type; + return Status::OK(); + } else if (is_floating(right_type->id())) { + left_type = right_type; + return Status::OK(); + } + + // precision, scale of left and right args + int32_t p1, s1, p2, s2; + + // decimal + integer = decimal + if (is_decimal(left_type->id())) { + auto decimal = checked_cast(left_type.get()); + p1 = decimal->precision(); + s1 = decimal->scale(); + } else { + DCHECK(is_integer(left_type->id())); + p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); + s1 = 0; + } + if (is_decimal(right_type->id())) { + auto decimal = checked_cast(right_type.get()); + p2 = 
decimal->precision(); + s2 = decimal->scale(); + } else { + DCHECK(is_integer(right_type->id())); + p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); + s2 = 0; + } + if (s1 < 0 || s2 < 0) { + return Status::NotImplemented("Decimals with negative scales not supported"); + } + + // decimal128 + decimal256 = decimal256 + Type::type casted_type_id = Type::DECIMAL128; + if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } + + // decimal promotion rules compatible with amazon redshift + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + int32_t left_scaleup, right_scaleup; + + switch (promotion) { + case DecimalPromotion::kAdd: { + left_scaleup = std::max(s1, s2) - s1; + right_scaleup = std::max(s1, s2) - s2; + break; + } + case DecimalPromotion::kMultiply: { + left_scaleup = right_scaleup = 0; + break; + } + case DecimalPromotion::kDivide: { + left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; + right_scaleup = 0; + break; + } + default: + DCHECK(false) << "Invalid DecimalPromotion value " << static_cast(promotion); + } + ARROW_ASSIGN_OR_RAISE( + left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); + ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, + s2 + right_scaleup)); + return Status::OK(); +} + +bool HasDecimal(const std::vector& descrs) { + for (const auto& descr : descrs) { + if (is_decimal(descr.type->id())) { + return true; + } + } + return false; +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 9c8b2cef198..951b502b991 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1313,6 +1313,19 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) 
ARROW_EXPORT std::shared_ptr CommonBinary(const std::vector& descrs); +/// How to promote decimal precision/scale in CastBinaryDecimalArgs. +enum class DecimalPromotion : uint8_t { + kAdd, + kMultiply, + kDivide, +}; + +ARROW_EXPORT +Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* descrs); + +ARROW_EXPORT +bool HasDecimal(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 7692f037124..51b79ab78de 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -965,78 +965,6 @@ ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) { } } -Status CastBinaryDecimalArgs(const std::string& func_name, - std::vector* values) { - auto& left_type = (*values)[0].type; - auto& right_type = (*values)[1].type; - DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); - - // decimal + float = float - if (is_floating(left_type->id())) { - right_type = left_type; - return Status::OK(); - } else if (is_floating(right_type->id())) { - left_type = right_type; - return Status::OK(); - } - - // precision, scale of left and right args - int32_t p1, s1, p2, s2; - - // decimal + integer = decimal - if (is_decimal(left_type->id())) { - auto decimal = checked_cast(left_type.get()); - p1 = decimal->precision(); - s1 = decimal->scale(); - } else { - DCHECK(is_integer(left_type->id())); - p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); - s1 = 0; - } - if (is_decimal(right_type->id())) { - auto decimal = checked_cast(right_type.get()); - p2 = decimal->precision(); - s2 = decimal->scale(); - } else { - DCHECK(is_integer(right_type->id())); - p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); - s2 = 0; - } - if (s1 < 0 || s2 < 0) { - return Status::NotImplemented("Decimals with 
negative scales not supported"); - } - - // decimal128 + decimal256 = decimal256 - Type::type casted_type_id = Type::DECIMAL128; - if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { - casted_type_id = Type::DECIMAL256; - } - - // decimal promotion rules compatible with amazon redshift - // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html - int32_t left_scaleup, right_scaleup; - - // "add_checked" -> "add" - const std::string op = func_name.substr(0, func_name.find("_")); - if (op == "add" || op == "subtract") { - left_scaleup = std::max(s1, s2) - s1; - right_scaleup = std::max(s1, s2) - s2; - } else if (op == "multiply") { - left_scaleup = right_scaleup = 0; - } else if (op == "divide") { - left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; - right_scaleup = 0; - } else { - return Status::Invalid("Invalid decimal function: ", func_name); - } - - ARROW_ASSIGN_OR_RAISE( - left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); - ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, - s2 + right_scaleup)); - return Status::OK(); -} - // resolve decimal binary operation output type per *casted* args template Result ResolveDecimalBinaryOperationOutput( @@ -1166,17 +1094,21 @@ struct ArithmeticFunction : ScalarFunction { } Status CheckDecimals(std::vector* values) const { - bool has_decimal = false; - for (const auto& value : *values) { - if (is_decimal(value.type->id())) { - has_decimal = true; - break; - } - } - if (!has_decimal) return Status::OK(); + if (!HasDecimal(*values)) return Status::OK(); if (values->size() == 2) { - return CastBinaryDecimalArgs(name(), values); + // "add_checked" -> "add" + const auto func_name = name(); + const std::string op = func_name.substr(0, func_name.find("_")); + if (op == "add" || op == "subtract") { + return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values); + } else if (op == "multiply") { + return 
CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values); + } else if (op == "divide") { + return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values); + } else { + return Status::Invalid("Invalid decimal function: ", func_name); + } } return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 4342d776c38..6d571867ec5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -172,6 +172,8 @@ struct CompareFunction : ScalarFunction { ReplaceTypes(type, values); } else if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); + } else if (HasDecimal(*values)) { + RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values)); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -259,6 +261,12 @@ std::shared_ptr MakeCompareFunction(std::string name, DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); } + for (const auto id : DecimalTypeIds()) { + auto exec = GenerateDecimal(id); + DCHECK_OK( + func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); + } + return func; } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 37680945a3e..da3f0c242ac 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -453,6 +453,54 @@ TEST(TestCompareTimestamps, Basics) { CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]"); } +template +class TestCompareDecimal : public ::testing::Test {}; +TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes); + +TYPED_TEST(TestCompareDecimal, ArrayScalar) {} + +TYPED_TEST(TestCompareDecimal, ScalarArray) {} + +TYPED_TEST(TestCompareDecimal, ArrayArray) { + auto ty = std::make_shared(3, 2); + + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, 1, 0, 0, 
null, null]"), + std::make_pair("not_equal", "[0, 1, 1, 0, 1, 1, null, null]"), + std::make_pair("less", "[0, 1, 0, 0, 1, 0, null, null]"), + std::make_pair("less_equal", "[1, 1, 0, 1, 1, 0, null, null]"), + std::make_pair("greater", "[0, 0, 1, 0, 0, 1, null, null]"), + std::make_pair("greater_equal", "[1, 0, 1, 1, 0, 1, null, null]"), + }; + + auto lhs = ArrayFromJSON( + ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])"); + auto lhs_float = + ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]"); + auto rhs = ArrayFromJSON( + ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])"); + auto rhs_float = + ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]"); + auto lhs_intlike = ArrayFromJSON( + ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])"); + auto rhs_int = ArrayFromJSON(ty, R"([1, 2, 1, -1, 1, -1, null, 1])"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + + CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"), + ArrayFromJSON(boolean(), "[]")); + CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"), + ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]")); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected)); + } +} + TEST(TestCompareKernel, DispatchBest) { for (std::string name : {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { @@ -490,6 +538,17 @@ TEST(TestCompareKernel, DispatchBest) { CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()}); CheckDispatchBest(name, {large_utf8(), binary()}, 
{large_binary(), large_binary()}); + + CheckDispatchBest(name, {decimal128(3, 2), decimal128(6, 3)}, + {decimal128(4, 3), decimal128(6, 3)}); + CheckDispatchBest(name, {decimal128(3, 2), decimal256(3, 2)}, + {decimal256(3, 2), decimal256(3, 2)}); + CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()}); + CheckDispatchBest(name, {decimal128(3, 2), int64()}, + {decimal128(3, 2), decimal128(3, 2)}); + CheckDispatchBest(name, {int64(), decimal128(3, 2)}, + {decimal128(3, 2), decimal128(3, 2)}); } } diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 21955575d61..b8321e29e1a 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -542,16 +542,17 @@ cast to the :ref:`common numeric type ` before comparison), or two inputs of Binary- or String-like types, or two inputs of Temporal types. If any input is dictionary encoded it will be expanded for the purposes of comparison. If any of the input elements in a pair is null, the corresponding -output element is null. - -+--------------------------+------------+---------------------------------------------+---------------------+ -| Function names | Arity | Input types | Output type | -+==========================+============+=============================================+=====================+ -| equal, not_equal | Binary | Numeric, Temporal, Binary- and String-like | Boolean | -+--------------------------+------------+---------------------------------------------+---------------------+ -| greater, greater_equal, | Binary | Numeric, Temporal, Binary- and String-like | Boolean | -| less, less_equal | | | | -+--------------------------+------------+---------------------------------------------+---------------------+ +output element is null. Decimal arguments will be promoted in the same way as +for ``add`` and ``subtract``. 
+ ++--------------------------+------------+------------------------------------------------------+---------------------+ +| Function names | Arity | Input types | Output type | ++==========================+============+======================================================+=====================+ +| equal, not_equal | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean | ++--------------------------+------------+------------------------------------------------------+---------------------+ +| greater, greater_equal, | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean | +| less, less_equal | | | | ++--------------------------+------------+------------------------------------------------------+---------------------+ These functions take any number of inputs of numeric type (in which case they will be cast to the :ref:`common numeric type ` before From 22a1825b8b08eab78db3cf97f39bbb5f27f8b045 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 10 Sep 2021 09:58:40 -0400 Subject: [PATCH 2/3] ARROW-13966: [C++] Fix int-to-decimal promotion --- .../arrow/compute/kernels/codegen_internal.cc | 4 +- .../compute/kernels/scalar_cast_numeric.cc | 7 +- .../compute/kernels/scalar_compare_test.cc | 67 +++++++++++++++++-- cpp/src/arrow/util/decimal.h | 23 +++++++ 4 files changed, 88 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index ac1a55b6433..5e04440e5b1 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -367,7 +367,7 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion, s1 = decimal->scale(); } else { DCHECK(is_integer(left_type->id())); - p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); + ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id())); s1 = 0; } if (is_decimal(right_type->id())) { @@ -376,7 +376,7 @@ Status 
CastBinaryDecimalArgs(DecimalPromotion promotion, s2 = decimal->scale(); } else { DCHECK(is_integer(right_type->id())); - p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); + ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id())); s2 = 0; } if (s1 < 0 || s2 < 0) { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index e9cf9284ceb..1ce6896deba 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -420,11 +420,8 @@ struct CastFunctor decimal_digits{3, 5, 10, 19}; - using ctype = typename I::c_type; - static_assert(sizeof(ctype) <= 8, ""); - const int precision = decimal_digits[BitUtil::Log2(sizeof(ctype))] + out_scale; + ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id)); + precision += out_scale; if (out_precision < precision) { return Status::Invalid( "Precision is not great enough for the result. 
" diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index da3f0c242ac..a5bc89d87f3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -457,9 +457,65 @@ template class TestCompareDecimal : public ::testing::Test {}; TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes); -TYPED_TEST(TestCompareDecimal, ArrayScalar) {} +TYPED_TEST(TestCompareDecimal, ArrayScalar) { + auto ty = std::make_shared(3, 2); + + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, null]"), + std::make_pair("not_equal", "[0, 1, 1, null]"), + std::make_pair("less", "[0, 0, 1, null]"), + std::make_pair("less_equal", "[1, 0, 1, null]"), + std::make_pair("greater", "[0, 1, 0, null]"), + std::make_pair("greater_equal", "[1, 1, 0, null]"), + }; + + auto lhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"); + auto lhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]"); + auto lhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])"); + auto rhs = ScalarFromJSON(ty, R"("1.23")"); + auto rhs_float = ScalarFromJSON(float64(), "1.23"); + auto rhs_int = ScalarFromJSON(int64(), "1"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected)); + } +} + +TYPED_TEST(TestCompareDecimal, ScalarArray) { + auto ty = std::make_shared(3, 2); -TYPED_TEST(TestCompareDecimal, ScalarArray) {} + std::vector> cases = { + std::make_pair("equal", "[1, 0, 0, null]"), + std::make_pair("not_equal", "[0, 1, 1, null]"), + std::make_pair("less", 
"[0, 1, 0, null]"), + std::make_pair("less_equal", "[1, 1, 0, null]"), + std::make_pair("greater", "[0, 0, 1, null]"), + std::make_pair("greater_equal", "[1, 0, 1, null]"), + }; + + auto lhs = ScalarFromJSON(ty, R"("1.23")"); + auto lhs_float = ScalarFromJSON(float64(), "1.23"); + auto lhs_int = ScalarFromJSON(int64(), "1"); + auto rhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"); + auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]"); + auto rhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])"); + for (const auto& op : cases) { + const auto& function = op.first; + const auto& expected = op.second; + + SCOPED_TRACE(function); + CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected)); + CheckScalarBinary(function, lhs_int, rhs_intlike, ArrayFromJSON(boolean(), expected)); + } +} TYPED_TEST(TestCompareDecimal, ArrayArray) { auto ty = std::make_shared(3, 2); @@ -477,19 +533,18 @@ TYPED_TEST(TestCompareDecimal, ArrayArray) { ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])"); auto lhs_float = ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]"); + auto lhs_intlike = ArrayFromJSON( + ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])"); auto rhs = ArrayFromJSON( ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])"); auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]"); - auto lhs_intlike = ArrayFromJSON( - ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])"); - auto rhs_int = ArrayFromJSON(ty, R"([1, 2, 1, -1, 1, -1, null, 1])"); + auto rhs_int = ArrayFromJSON(int64(), "[1, 2, 1, -1, 1, -1, null, 1]"); for (const auto& op : cases) { const auto& function = op.first; const auto& expected = 
op.second; SCOPED_TRACE(function); - CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"), ArrayFromJSON(boolean(), "[]")); CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"), diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h index 4a158728833..da88fbeb379 100644 --- a/cpp/src/arrow/util/decimal.h +++ b/cpp/src/arrow/util/decimal.h @@ -25,6 +25,7 @@ #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_fwd.h" #include "arrow/util/basic_decimal.h" #include "arrow/util/string_view.h" @@ -288,4 +289,31 @@ struct Decimal256::ToRealConversion { } }; +/// For an integer type, return the max number of decimal digits +/// (=minimal decimal precision) it can represent. +/// +/// Signed and unsigned types of the same width need the same number of +/// digits, except at 64 bits: INT64 (max 9223372036854775807) needs 19 +/// digits while UINT64 (max 18446744073709551615) needs 20. +/// Returns Status::Invalid for any type id that is not listed below. +inline Result<int32_t> MaxDecimalDigitsForInteger(Type::type type_id) { + switch (type_id) { + case Type::INT8: + case Type::UINT8: + return 3; + case Type::INT16: + case Type::UINT16: + return 5; + case Type::INT32: + case Type::UINT32: + return 10; + case Type::INT64: + return 19; + case Type::UINT64: + return 20; + default: + break; + } + return Status::Invalid("Not an integer type: ", type_id); +} + } // namespace arrow From 16e72372ec84635b97d47ad6761543e1e8df864a Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 10 Sep 2021 17:17:35 -0400 Subject: [PATCH 3/3] ARROW-13966: [C++] Update docs --- docs/source/cpp/compute.rst | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index b8321e29e1a..3b3bd336740 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -368,6 +368,12 @@ then typically wraps around). Most functions are also available in an overflow-checking variant, suffixed ``_checked``, which returns an ``Invalid`` :class:`Status` when overflow is detected. 
+For functions which support decimal inputs (currently ``add``, ``subtract``, +``multiply``, and ``divide`` and their checked variants), decimals of different +precisions/scales will be promoted appropriately. Mixed decimal and +floating-point arguments will cast all arguments to floating-point, while mixed +decimal and integer arguments will cast all arguments to decimals. + +------------------+--------+----------------+----------------------+-------+ | Function name | Arity | Input types | Output type | Notes | +==================+========+================+======================+=======+ @@ -545,14 +551,14 @@ comparison. If any of the input elements in a pair is null, the corresponding output element is null. Decimal arguments will be promoted in the same way as for ``add`` and ``subtract``. -+--------------------------+------------+------------------------------------------------------+---------------------+ -| Function names | Arity | Input types | Output type | -+==========================+============+======================================================+=====================+ -| equal, not_equal | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean | -+--------------------------+------------+------------------------------------------------------+---------------------+ -| greater, greater_equal, | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean | -| less, less_equal | | | | -+--------------------------+------------+------------------------------------------------------+---------------------+ ++--------------------------+------------+---------------------------------------------+---------------------+ +| Function names | Arity | Input types | Output type | ++==========================+============+=============================================+=====================+ +| equal, not_equal | Binary | Numeric, Temporal, Binary- and String-like | Boolean | 
++--------------------------+------------+---------------------------------------------+---------------------+ +| greater, greater_equal, | Binary | Numeric, Temporal, Binary- and String-like | Boolean | +| less, less_equal | | | | ++--------------------------+------------+---------------------------------------------+---------------------+ These functions take any number of inputs of numeric type (in which case they will be cast to the :ref:`common numeric type ` before