From 332699bafdb9c35833c33c364b212a11c209b672 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 10 Sep 2021 09:11:14 -0400
Subject: [PATCH 1/3] ARROW-13966: [C++] Support decimals in comparisons
---
.../arrow/compute/kernels/codegen_internal.cc | 86 +++++++++++++++++
.../arrow/compute/kernels/codegen_internal.h | 13 +++
.../compute/kernels/scalar_arithmetic.cc | 94 +++----------------
.../arrow/compute/kernels/scalar_compare.cc | 8 ++
.../compute/kernels/scalar_compare_test.cc | 59 ++++++++++++
docs/source/cpp/compute.rst | 21 +++--
6 files changed, 190 insertions(+), 91 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index f8b90085010..ac1a55b6433 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -17,6 +17,7 @@
#include "arrow/compute/kernels/codegen_internal.h"
+#include
#include
#include
#include
@@ -341,6 +342,91 @@ std::shared_ptr CommonBinary(const std::vector& descrs) {
return large_binary();
}
+Status CastBinaryDecimalArgs(DecimalPromotion promotion,
+ std::vector* descrs) {
+ auto& left_type = (*descrs)[0].type;
+ auto& right_type = (*descrs)[1].type;
+ DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
+
+ // decimal + float = float
+ if (is_floating(left_type->id())) {
+ right_type = left_type;
+ return Status::OK();
+ } else if (is_floating(right_type->id())) {
+ left_type = right_type;
+ return Status::OK();
+ }
+
+ // precision, scale of left and right args
+ int32_t p1, s1, p2, s2;
+
+ // decimal + integer = decimal
+ if (is_decimal(left_type->id())) {
+ auto decimal = checked_cast(left_type.get());
+ p1 = decimal->precision();
+ s1 = decimal->scale();
+ } else {
+ DCHECK(is_integer(left_type->id()));
+ p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id()))));
+ s1 = 0;
+ }
+ if (is_decimal(right_type->id())) {
+ auto decimal = checked_cast(right_type.get());
+ p2 = decimal->precision();
+ s2 = decimal->scale();
+ } else {
+ DCHECK(is_integer(right_type->id()));
+ p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id()))));
+ s2 = 0;
+ }
+ if (s1 < 0 || s2 < 0) {
+ return Status::NotImplemented("Decimals with negative scales not supported");
+ }
+
+ // decimal128 + decimal256 = decimal256
+ Type::type casted_type_id = Type::DECIMAL128;
+ if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
+ casted_type_id = Type::DECIMAL256;
+ }
+
+ // decimal promotion rules compatible with amazon redshift
+ // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
+ int32_t left_scaleup, right_scaleup;
+
+ switch (promotion) {
+ case DecimalPromotion::kAdd: {
+ left_scaleup = std::max(s1, s2) - s1;
+ right_scaleup = std::max(s1, s2) - s2;
+ break;
+ }
+ case DecimalPromotion::kMultiply: {
+ left_scaleup = right_scaleup = 0;
+ break;
+ }
+ case DecimalPromotion::kDivide: {
+ left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
+ right_scaleup = 0;
+ break;
+ }
+ default:
+ DCHECK(false) << "Invalid DecimalPromotion value " << static_cast(promotion);
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
+ ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
+ s2 + right_scaleup));
+ return Status::OK();
+}
+
+bool HasDecimal(const std::vector& descrs) {
+ for (const auto& descr : descrs) {
+ if (is_decimal(descr.type->id())) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace internal
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 9c8b2cef198..951b502b991 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1313,6 +1313,19 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs)
ARROW_EXPORT
std::shared_ptr CommonBinary(const std::vector& descrs);
+/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
+enum class DecimalPromotion : uint8_t {
+ kAdd,
+ kMultiply,
+ kDivide,
+};
+
+ARROW_EXPORT
+Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector* descrs);
+
+ARROW_EXPORT
+bool HasDecimal(const std::vector& descrs);
+
} // namespace internal
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index 7692f037124..51b79ab78de 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -965,78 +965,6 @@ ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
}
}
-Status CastBinaryDecimalArgs(const std::string& func_name,
- std::vector* values) {
- auto& left_type = (*values)[0].type;
- auto& right_type = (*values)[1].type;
- DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));
-
- // decimal + float = float
- if (is_floating(left_type->id())) {
- right_type = left_type;
- return Status::OK();
- } else if (is_floating(right_type->id())) {
- left_type = right_type;
- return Status::OK();
- }
-
- // precision, scale of left and right args
- int32_t p1, s1, p2, s2;
-
- // decimal + integer = decimal
- if (is_decimal(left_type->id())) {
- auto decimal = checked_cast(left_type.get());
- p1 = decimal->precision();
- s1 = decimal->scale();
- } else {
- DCHECK(is_integer(left_type->id()));
- p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id()))));
- s1 = 0;
- }
- if (is_decimal(right_type->id())) {
- auto decimal = checked_cast(right_type.get());
- p2 = decimal->precision();
- s2 = decimal->scale();
- } else {
- DCHECK(is_integer(right_type->id()));
- p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id()))));
- s2 = 0;
- }
- if (s1 < 0 || s2 < 0) {
- return Status::NotImplemented("Decimals with negative scales not supported");
- }
-
- // decimal128 + decimal256 = decimal256
- Type::type casted_type_id = Type::DECIMAL128;
- if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
- casted_type_id = Type::DECIMAL256;
- }
-
- // decimal promotion rules compatible with amazon redshift
- // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
- int32_t left_scaleup, right_scaleup;
-
- // "add_checked" -> "add"
- const std::string op = func_name.substr(0, func_name.find("_"));
- if (op == "add" || op == "subtract") {
- left_scaleup = std::max(s1, s2) - s1;
- right_scaleup = std::max(s1, s2) - s2;
- } else if (op == "multiply") {
- left_scaleup = right_scaleup = 0;
- } else if (op == "divide") {
- left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
- right_scaleup = 0;
- } else {
- return Status::Invalid("Invalid decimal function: ", func_name);
- }
-
- ARROW_ASSIGN_OR_RAISE(
- left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
- ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
- s2 + right_scaleup));
- return Status::OK();
-}
-
// resolve decimal binary operation output type per *casted* args
template
Result ResolveDecimalBinaryOperationOutput(
@@ -1166,17 +1094,21 @@ struct ArithmeticFunction : ScalarFunction {
}
Status CheckDecimals(std::vector* values) const {
- bool has_decimal = false;
- for (const auto& value : *values) {
- if (is_decimal(value.type->id())) {
- has_decimal = true;
- break;
- }
- }
- if (!has_decimal) return Status::OK();
+ if (!HasDecimal(*values)) return Status::OK();
if (values->size() == 2) {
- return CastBinaryDecimalArgs(name(), values);
+ // "add_checked" -> "add"
+ const auto func_name = name();
+ const std::string op = func_name.substr(0, func_name.find("_"));
+ if (op == "add" || op == "subtract") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values);
+ } else if (op == "multiply") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values);
+ } else if (op == "divide") {
+ return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values);
+ } else {
+ return Status::Invalid("Invalid decimal function: ", func_name);
+ }
}
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index 4342d776c38..6d571867ec5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -172,6 +172,8 @@ struct CompareFunction : ScalarFunction {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
+ } else if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values));
}
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
@@ -259,6 +261,12 @@ std::shared_ptr MakeCompareFunction(std::string name,
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
}
+ for (const auto id : DecimalTypeIds()) {
+ auto exec = GenerateDecimal(id);
+ DCHECK_OK(
+ func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
+ }
+
return func;
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index 37680945a3e..da3f0c242ac 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -453,6 +453,54 @@ TEST(TestCompareTimestamps, Basics) {
CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]");
}
+template
+class TestCompareDecimal : public ::testing::Test {};
+TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestCompareDecimal, ArrayScalar) {}
+
+TYPED_TEST(TestCompareDecimal, ScalarArray) {}
+
+TYPED_TEST(TestCompareDecimal, ArrayArray) {
+ auto ty = std::make_shared(3, 2);
+
+ std::vector> cases = {
+ std::make_pair("equal", "[1, 0, 0, 1, 0, 0, null, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, 0, 1, 1, null, null]"),
+ std::make_pair("less", "[0, 1, 0, 0, 1, 0, null, null]"),
+ std::make_pair("less_equal", "[1, 1, 0, 1, 1, 0, null, null]"),
+ std::make_pair("greater", "[0, 0, 1, 0, 0, 1, null, null]"),
+ std::make_pair("greater_equal", "[1, 0, 1, 1, 0, 1, null, null]"),
+ };
+
+ auto lhs = ArrayFromJSON(
+ ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])");
+ auto lhs_float =
+ ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]");
+ auto rhs = ArrayFromJSON(
+ ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])");
+ auto rhs_float =
+ ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]");
+ auto lhs_intlike = ArrayFromJSON(
+ ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])");
+ auto rhs_int = ArrayFromJSON(ty, R"([1, 2, 1, -1, 1, -1, null, 1])");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+
+ CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"),
+ ArrayFromJSON(boolean(), "[]"));
+ CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"),
+ ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]"));
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
+ }
+}
+
TEST(TestCompareKernel, DispatchBest) {
for (std::string name :
{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
@@ -490,6 +538,17 @@ TEST(TestCompareKernel, DispatchBest) {
CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()});
CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()});
+
+ CheckDispatchBest(name, {decimal128(3, 2), decimal128(6, 3)},
+ {decimal128(4, 3), decimal128(6, 3)});
+ CheckDispatchBest(name, {decimal128(3, 2), decimal256(3, 2)},
+ {decimal256(3, 2), decimal256(3, 2)});
+ CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()});
+ CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()});
+ CheckDispatchBest(name, {decimal128(3, 2), int64()},
+ {decimal128(3, 2), decimal128(3, 2)});
+ CheckDispatchBest(name, {int64(), decimal128(3, 2)},
+ {decimal128(3, 2), decimal128(3, 2)});
}
}
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 21955575d61..b8321e29e1a 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -542,16 +542,17 @@ cast to the :ref:`common numeric type ` before comparison),
or two inputs of Binary- or String-like types, or two inputs of Temporal types.
If any input is dictionary encoded it will be expanded for the purposes of
comparison. If any of the input elements in a pair is null, the corresponding
-output element is null.
-
-+--------------------------+------------+---------------------------------------------+---------------------+
-| Function names | Arity | Input types | Output type |
-+==========================+============+=============================================+=====================+
-| equal, not_equal | Binary | Numeric, Temporal, Binary- and String-like | Boolean |
-+--------------------------+------------+---------------------------------------------+---------------------+
-| greater, greater_equal, | Binary | Numeric, Temporal, Binary- and String-like | Boolean |
-| less, less_equal | | | |
-+--------------------------+------------+---------------------------------------------+---------------------+
+output element is null. Decimal arguments will be promoted in the same way as
+for ``add`` and ``subtract``.
+
++--------------------------+------------+------------------------------------------------------+---------------------+
+| Function names | Arity | Input types | Output type |
++==========================+============+======================================================+=====================+
+| equal, not_equal | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean |
++--------------------------+------------+------------------------------------------------------+---------------------+
+| greater, greater_equal, | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean |
+| less, less_equal | | | |
++--------------------------+------------+------------------------------------------------------+---------------------+
These functions take any number of inputs of numeric type (in which case they
will be cast to the :ref:`common numeric type ` before
From 22a1825b8b08eab78db3cf97f39bbb5f27f8b045 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 10 Sep 2021 09:58:40 -0400
Subject: [PATCH 2/3] ARROW-13966: [C++] Fix int-to-decimal promotion
---
.../arrow/compute/kernels/codegen_internal.cc | 4 +-
.../compute/kernels/scalar_cast_numeric.cc | 7 +-
.../compute/kernels/scalar_compare_test.cc | 67 +++++++++++++++++--
cpp/src/arrow/util/decimal.h | 23 +++++++
4 files changed, 88 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index ac1a55b6433..5e04440e5b1 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -367,7 +367,7 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
s1 = decimal->scale();
} else {
DCHECK(is_integer(left_type->id()));
- p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id()))));
+ ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id()));
s1 = 0;
}
if (is_decimal(right_type->id())) {
@@ -376,7 +376,7 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
s2 = decimal->scale();
} else {
DCHECK(is_integer(right_type->id()));
- p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id()))));
+ ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id()));
s2 = 0;
}
if (s1 < 0 || s2 < 0) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
index e9cf9284ceb..1ce6896deba 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
@@ -420,11 +420,8 @@ struct CastFunctor decimal_digits{3, 5, 10, 19};
- using ctype = typename I::c_type;
- static_assert(sizeof(ctype) <= 8, "");
- const int precision = decimal_digits[BitUtil::Log2(sizeof(ctype))] + out_scale;
+ ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id));
+ precision += out_scale;
if (out_precision < precision) {
return Status::Invalid(
"Precision is not great enough for the result. "
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index da3f0c242ac..a5bc89d87f3 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -457,9 +457,65 @@ template
class TestCompareDecimal : public ::testing::Test {};
TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes);
-TYPED_TEST(TestCompareDecimal, ArrayScalar) {}
+TYPED_TEST(TestCompareDecimal, ArrayScalar) {
+ auto ty = std::make_shared(3, 2);
+
+ std::vector> cases = {
+ std::make_pair("equal", "[1, 0, 0, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, null]"),
+ std::make_pair("less", "[0, 0, 1, null]"),
+ std::make_pair("less_equal", "[1, 0, 1, null]"),
+ std::make_pair("greater", "[0, 1, 0, null]"),
+ std::make_pair("greater_equal", "[1, 1, 0, null]"),
+ };
+
+ auto lhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
+ auto lhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
+ auto lhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
+ auto rhs = ScalarFromJSON(ty, R"("1.23")");
+ auto rhs_float = ScalarFromJSON(float64(), "1.23");
+ auto rhs_int = ScalarFromJSON(int64(), "1");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
+ }
+}
+
+TYPED_TEST(TestCompareDecimal, ScalarArray) {
+ auto ty = std::make_shared(3, 2);
-TYPED_TEST(TestCompareDecimal, ScalarArray) {}
+ std::vector> cases = {
+ std::make_pair("equal", "[1, 0, 0, null]"),
+ std::make_pair("not_equal", "[0, 1, 1, null]"),
+ std::make_pair("less", "[0, 1, 0, null]"),
+ std::make_pair("less_equal", "[1, 1, 0, null]"),
+ std::make_pair("greater", "[0, 0, 1, null]"),
+ std::make_pair("greater_equal", "[1, 0, 1, null]"),
+ };
+
+ auto lhs = ScalarFromJSON(ty, R"("1.23")");
+ auto lhs_float = ScalarFromJSON(float64(), "1.23");
+ auto lhs_int = ScalarFromJSON(int64(), "1");
+ auto rhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
+ auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
+ auto rhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
+ for (const auto& op : cases) {
+ const auto& function = op.first;
+ const auto& expected = op.second;
+
+ SCOPED_TRACE(function);
+ CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
+ CheckScalarBinary(function, lhs_int, rhs_intlike, ArrayFromJSON(boolean(), expected));
+ }
+}
TYPED_TEST(TestCompareDecimal, ArrayArray) {
auto ty = std::make_shared(3, 2);
@@ -477,19 +533,18 @@ TYPED_TEST(TestCompareDecimal, ArrayArray) {
ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])");
auto lhs_float =
ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]");
+ auto lhs_intlike = ArrayFromJSON(
+ ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])");
auto rhs = ArrayFromJSON(
ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])");
auto rhs_float =
ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]");
- auto lhs_intlike = ArrayFromJSON(
- ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])");
- auto rhs_int = ArrayFromJSON(ty, R"([1, 2, 1, -1, 1, -1, null, 1])");
+ auto rhs_int = ArrayFromJSON(int64(), "[1, 2, 1, -1, 1, -1, null, 1]");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;
SCOPED_TRACE(function);
-
CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"),
ArrayFromJSON(boolean(), "[]"));
CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"),
diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h
index 4a158728833..da88fbeb379 100644
--- a/cpp/src/arrow/util/decimal.h
+++ b/cpp/src/arrow/util/decimal.h
@@ -25,6 +25,7 @@
#include "arrow/result.h"
#include "arrow/status.h"
+#include "arrow/type_fwd.h"
#include "arrow/util/basic_decimal.h"
#include "arrow/util/string_view.h"
@@ -288,4 +289,26 @@ struct Decimal256::ToRealConversion {
}
};
+/// For an integer type, return the max number of decimal digits
+/// (=minimal decimal precision) it can represent.
+inline Result MaxDecimalDigitsForInteger(Type::type type_id) {
+ switch (type_id) {
+ case Type::INT8:
+ case Type::UINT8:
+ return 3;
+ case Type::INT16:
+ case Type::UINT16:
+ return 5;
+ case Type::INT32:
+ case Type::UINT32:
+ return 10;
+ case Type::INT64:
+ case Type::UINT64:
+ return 19;
+ default:
+ break;
+ }
+ return Status::Invalid("Not an integer type: ", type_id);
+}
+
} // namespace arrow
From 16e72372ec84635b97d47ad6761543e1e8df864a Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 10 Sep 2021 17:17:35 -0400
Subject: [PATCH 3/3] ARROW-13966: [C++] Update docs
---
docs/source/cpp/compute.rst | 22 ++++++++++++++--------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index b8321e29e1a..3b3bd336740 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -368,6 +368,12 @@ then typically wraps around). Most functions are also available in an
overflow-checking variant, suffixed ``_checked``, which returns
an ``Invalid`` :class:`Status` when overflow is detected.
+For functions which support decimal inputs (currently ``add``, ``subtract``,
+``multiply``, and ``divide`` and their checked variants), decimals of different
+precisions/scales will be promoted appropriately. Mixed decimal and
+floating-point arguments will cast all arguments to floating-point, while mixed
+decimal and integer arguments will cast all arguments to decimals.
+
+------------------+--------+----------------+----------------------+-------+
| Function name | Arity | Input types | Output type | Notes |
+==================+========+================+======================+=======+
@@ -545,14 +551,14 @@ comparison. If any of the input elements in a pair is null, the corresponding
output element is null. Decimal arguments will be promoted in the same way as
for ``add`` and ``subtract``.
-+--------------------------+------------+------------------------------------------------------+---------------------+
-| Function names | Arity | Input types | Output type |
-+==========================+============+======================================================+=====================+
-| equal, not_equal | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean |
-+--------------------------+------------+------------------------------------------------------+---------------------+
-| greater, greater_equal, | Binary | Numeric, Temporal, Decimal, Binary- and String-like | Boolean |
-| less, less_equal | | | |
-+--------------------------+------------+------------------------------------------------------+---------------------+
++--------------------------+------------+---------------------------------------------+---------------------+
+| Function names | Arity | Input types | Output type |
++==========================+============+=============================================+=====================+
+| equal, not_equal | Binary | Numeric, Temporal, Binary- and String-like | Boolean |
++--------------------------+------------+---------------------------------------------+---------------------+
+| greater, greater_equal, | Binary | Numeric, Temporal, Binary- and String-like | Boolean |
+| less, less_equal | | | |
++--------------------------+------------+---------------------------------------------+---------------------+
These functions take any number of inputs of numeric type (in which case they
will be cast to the :ref:`common numeric type ` before