From d97fd9faf0da2bb5d3049672fd28c9a034dbc2a4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 1 Oct 2021 14:31:21 -0400
Subject: [PATCH 1/3] ARROW-13975: [C++] Implement decimal round
---
cpp/src/arrow/compute/api_scalar.cc | 5 +-
cpp/src/arrow/compute/api_scalar.h | 11 +-
.../compute/kernels/scalar_arithmetic.cc | 634 +++++++++++++++---
.../compute/kernels/scalar_arithmetic_test.cc | 542 ++++++++++++++-
cpp/src/arrow/util/basic_decimal.cc | 14 +
cpp/src/arrow/util/basic_decimal.h | 4 +
python/pyarrow/tests/test_compute.py | 3 +-
r/src/compute.cpp | 3 +-
8 files changed, 1112 insertions(+), 104 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index 3a6e99342ef..e3fe1bdf73d 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -283,8 +283,11 @@ RoundOptions::RoundOptions(int64_t ndigits, RoundMode round_mode)
constexpr char RoundOptions::kTypeName[];
RoundToMultipleOptions::RoundToMultipleOptions(double multiple, RoundMode round_mode)
+ : RoundToMultipleOptions(std::make_shared(multiple), round_mode) {}  // wrap the double and delegate
+RoundToMultipleOptions::RoundToMultipleOptions(std::shared_ptr multiple,
+ RoundMode round_mode)
: FunctionOptions(internal::kRoundToMultipleOptionsType),
- multiple(multiple),
+ multiple(std::move(multiple)),  // not validated here; RoundOptionsWrapper's Init rejects null/invalid
round_mode(round_mode) {}
constexpr char RoundToMultipleOptions::kTypeName[];
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 8e1cd4875d5..65ec4d6e18b 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -93,10 +93,17 @@ class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions {
public:
explicit RoundToMultipleOptions(double multiple = 1.0,
RoundMode round_mode = RoundMode::HALF_TO_EVEN);
+ explicit RoundToMultipleOptions(std::shared_ptr multiple,
+ RoundMode round_mode = RoundMode::HALF_TO_EVEN);
constexpr static char const kTypeName[] = "RoundToMultipleOptions";
static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); }
- /// Rounding scale (multiple to round to)
- double multiple;
+ /// Rounding scale (multiple to round to).
+ ///
+ /// Should be a scalar of a type compatible with the argument to be rounded.
+ /// For example, rounding a decimal value means a decimal multiple is
+ /// required. Rounding a floating point or integer value means a floating
+ /// point scalar is required.
+ std::shared_ptr multiple;
/// Rounding and tie-breaking mode
RoundMode round_mode;
};
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index bf2d27f1f72..48a5e815b67 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -46,6 +46,7 @@ using applicator::ScalarBinaryEqualTypes;
using applicator::ScalarBinaryNotNullEqualTypes;
using applicator::ScalarUnary;
using applicator::ScalarUnaryNotNull;
+using applicator::ScalarUnaryNotNullStateful;
namespace {
@@ -73,10 +74,10 @@ using enable_if_c_integer =
template
using enable_if_floating_point = enable_if_t::value, R>;
-template
+template
using enable_if_decimal_value =
enable_if_t::value || std::is_same::value,
- T>;
+ R>;
struct AbsoluteValue {
template
@@ -880,81 +881,170 @@ struct RoundUtil {
};
// Specializations of rounding implementations for round kernels
-template
+template
struct RoundImpl;
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // floor-style rounding (std::floor for floating point)
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::floor(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // scale is unused by this mode
+ (*val) -= remainder;  // drop the remainder the caller computed from *val / pow10
+ if (remainder.Sign() < 0) {
+ (*val) -= pow10;  // negative remainder: go one more unit down
+ }
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // ceil-style rounding (std::ceil for floating point)
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::ceil(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // scale is unused by this mode
+ (*val) -= remainder;  // drop the remainder the caller computed from *val / pow10
+ if (remainder.Sign() > 0 && remainder != 0) {  // NOTE(review): the != 0 test suggests Sign() is never 0 - confirm
+ (*val) += pow10;  // strictly positive remainder: go one unit up
+ }
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // truncating rounding (std::trunc for floating point)
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::trunc(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // pow10 and scale are unused by this mode
+ (*val) -= remainder;  // dropping the remainder is the whole operation
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // rounding away from zero (floor for negatives, ceil otherwise)
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::signbit(val) ? std::floor(val) : std::ceil(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // scale is unused by this mode
+ (*val) -= remainder;  // drop the remainder the caller computed from *val / pow10
+ if (remainder.Sign() < 0) {
+ (*val) -= pow10;  // negative remainder: one unit further from zero
+ } else if (remainder.Sign() > 0 && remainder != 0) {
+ (*val) += pow10;  // positive remainder: one unit further from zero
+ }
+ }
};
// NOTE: RoundImpl variants for the HALF_* rounding modes are only
// invoked when the fractional part is equal to 0.5 (std::round is invoked
// otherwise).
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-breaking variant that forwards to another RoundImpl specialization
+ template
static constexpr enable_if_floating_point Round(const T val) {
return RoundImpl::Round(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl::Round(val, remainder, pow10, scale);  // all arguments are passed through unchanged
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-breaking variant that forwards to another RoundImpl specialization
+ template
static constexpr enable_if_floating_point Round(const T val) {
return RoundImpl::Round(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl::Round(val, remainder, pow10, scale);  // all arguments are passed through unchanged
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-breaking variant that forwards to another RoundImpl specialization
+ template
static constexpr enable_if_floating_point Round(const T val) {
return RoundImpl::Round(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl::Round(val, remainder, pow10, scale);  // all arguments are passed through unchanged
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-breaking variant that forwards to another RoundImpl specialization
+ template
static constexpr enable_if_floating_point Round(const T val) {
return RoundImpl::Round(val);
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {
+ RoundImpl::Round(val, remainder, pow10, scale);  // all arguments are passed through unchanged
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-break to the even neighbour
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::round(val * T(0.5)) * 2;
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // pow10 is unused by this mode
+ auto scaled = val->ReduceScaleBy(scale, /*round=*/false);  // drop `scale` digits without rounding
+ if (scaled.low_bits() % 2 != 0) {  // odd: needs adjusting
+ scaled += remainder.Sign() >= 0 ? 1 : -1;  // step in the direction of the remainder's sign
+ }
+ *val = scaled.IncreaseScaleBy(scale);  // restore the original scale
+ }
};
-template
-struct RoundImpl {
+template
+struct RoundImpl {  // tie-break to the odd neighbour
+ template
static constexpr enable_if_floating_point Round(const T val) {
return std::floor(val * T(0.5)) + std::ceil(val * T(0.5));
}
+
+ template
+ static enable_if_decimal_value Round(T* val, const T& remainder,
+ const T& pow10, const int32_t scale) {  // pow10 is unused by this mode
+ auto scaled = val->ReduceScaleBy(scale, /*round=*/false);  // drop `scale` digits without rounding
+ if (scaled.low_bits() % 2 == 0) {  // even: needs adjusting
+ scaled += remainder.Sign() >= 0 ? 1 : -1;  // compare the sign (as HALF_TO_EVEN does), do not test it for truthiness
+ }
+ *val = scaled.IncreaseScaleBy(scale);  // restore the original scale
+ }
};
// Specializations of kernel state for round kernels
@@ -994,40 +1084,89 @@ struct RoundOptionsWrapper
const KernelInitArgs& args) {
ARROW_ASSIGN_OR_RAISE(auto state, OptionsWrapper::Init(ctx, args));
auto options = Get(*state);
- if (options.multiple <= 0) {
- return Status::Invalid("Rounding multiple has to be a positive value");
+ const auto& type = *args.inputs[0].type;  // type of the value being rounded
+ if (!options.multiple || !options.multiple->is_valid) {
+ return Status::Invalid("Rounding multiple must be non-null and valid");
+ }
+ if (is_floating(type.id())) {  // floating-point input: accept a float or double multiple
+ switch (options.multiple->type->id()) {
+ case Type::FLOAT: {
+ if (UnboxScalar::Unbox(*options.multiple) <= 0) {  // zero is rejected too: the kernel divides by it
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::DOUBLE: {
+ if (UnboxScalar::Unbox(*options.multiple) <= 0) {  // zero is rejected too: the kernel divides by it
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::HALF_FLOAT:
+ return Status::NotImplemented("Half-float values are not supported");
+ default:
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ } else {
+ DCHECK(is_decimal(type.id()));  // only floating-point and decimal inputs are expected here
+ if (!type.Equals(*options.multiple->type)) {  // decimal multiple must match the input type exactly
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ switch (options.multiple->type->id()) {
+ case Type::DECIMAL128: {
+ if (UnboxScalar::Unbox(*options.multiple) <= 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ case Type::DECIMAL256: {
+ if (UnboxScalar::Unbox(*options.multiple) <= 0) {
+ return Status::Invalid("Rounding multiple must be positive");
+ }
+ break;
+ }
+ default:
+ // This shouldn't happen
+ return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
+ *options.multiple->type);
+ }
+ }
return std::move(state);
}
};
-template
+template
struct Round {
+ using CType = typename TypeTraits::CType;
using State = RoundOptionsWrapper;
- template
- static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) {
- static_assert(std::is_same::value, "");
+ CType pow10;
+ int64_t ndigits;
+
+ explicit Round(const State& state, const DataType& out_ty)
+ : pow10(static_cast(state.pow10)), ndigits(state.options.ndigits) {}
+
+ template ::CType>
+ enable_if_floating_point Call(KernelContext* ctx, CType arg, Status* st) const {
// Do not process Inf or NaN because they will trigger the overflow error at end of
// function.
if (!std::isfinite(arg)) {
return arg;
}
- auto state = static_cast(ctx->state());
- auto options = state->options;
- auto pow10 = T(state->pow10);
- auto round_val = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10);
+ auto round_val = ndigits >= 0 ? (arg * pow10) : (arg / pow10);
auto frac = round_val - std::floor(round_val);
if (frac != T(0)) {
// Use std::round() if in tie-breaking mode and scaled value is not 0.5.
if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
round_val = std::round(round_val);
} else {
- round_val = RoundImpl::Round(round_val);
+ round_val = RoundImpl::Round(round_val);
}
// Equality check is ommitted so that the common case of 10^0 (integer rounding)
// uses multiply-only
- round_val = (options.ndigits > 0) ? (round_val / pow10) : (round_val * pow10);
+ round_val = ndigits > 0 ? (round_val / pow10) : (round_val * pow10);
if (!std::isfinite(round_val)) {
*st = Status::Invalid("overflow occurred during rounding");
return arg;
@@ -1040,29 +1179,128 @@ struct Round {
}
};
-template
+template
+struct Round> {  // Round op for decimal value types (see enable_if_decimal_value on Call)
+ using CType = typename TypeTraits::CType;
+ using State = RoundOptionsWrapper;
+
+ const ArrowType& ty;  // refers to the out_ty given to the constructor; not owned
+ int64_t ndigits;
+ int32_t pow;  // ty.scale() - ndigits
+ // pow10 is "1" at the digit being rounded to, in unscaled units; half_pow10 is "0.5" likewise.
+ CType pow10, half_pow10, neg_half_pow10;
+
+ explicit Round(const State& state, const DataType& out_ty)
+ : Round(state.options.ndigits, out_ty) {}
+
+ explicit Round(int64_t ndigits, const DataType& out_ty)
+ : ty(checked_cast(out_ty)),
+ ndigits(ndigits),
+ pow(static_cast(ty.scale() - ndigits)) {  // NOTE(review): int64 -> int32 narrowing; extreme ndigits could wrap - confirm
+ if (pow >= ty.precision() || pow < 0) {
+ pow10 = half_pow10 = neg_half_pow10 = 0;  // placeholders: Call() returns before using them in these cases
+ } else {
+ pow10 = CType::GetScaleMultiplier(pow);
+ half_pow10 = CType::GetHalfScaleMultiplier(pow);
+ neg_half_pow10 = -half_pow10;
+ }
+ }
+
+ template ::CType>
+ enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const {
+ if (pow >= ty.precision()) {
+ *st = Status::Invalid("Rounding to ", ndigits,
+ " digits will not fit in precision of ", ty);
+ return arg;
+ } else if (pow < 0) {
+ // no-op: ndigits exceeds the scale, so return the input unchanged
+ return arg;
+ }
+
+ std::pair pair;  // receives the result of arg.Divide(pow10); .second is the remainder
+ *st = arg.Divide(pow10).Value(&pair);
+ if (!st->ok()) return arg;
+ // The remainder is effectively the scaled fractional part after division.
+ const auto& remainder = pair.second;
+ if (remainder == 0) return arg;  // already an exact multiple of pow10
+ if (kRoundMode >= RoundMode::HALF_DOWN) {  // tie-breaking (HALF_*) modes
+ if (remainder == half_pow10 || remainder == neg_half_pow10) {
+ // On the halfway point, use tiebreaker
+ RoundImpl::Round(&arg, remainder, pow10, pow);
+ } else if (remainder.Sign() >= 0) {
+ // Positive remainder: round up only when past the halfway point
+ arg -= remainder;
+ if (remainder > half_pow10) {
+ arg += pow10;
+ }
+ } else {
+ // Negative remainder: round down only when past the halfway point
+ arg -= remainder;
+ if (remainder < neg_half_pow10) {
+ arg -= pow10;
+ }
+ }
+ } else {
+ RoundImpl::Round(&arg, remainder, pow10, pow);  // non-tie-breaking modes apply their rule directly
+ }
+ if (!arg.FitsInPrecision(ty.precision())) {  // rounding may have pushed the value past the declared precision
+ *st = Status::Invalid("Rounded value ", arg.ToString(ty.scale()),
+ " does not fit in precision of ", ty);
+ return 0;
+ }
+ return arg;
+ }
+};
+
+template
+Status FixedRoundDecimalExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {  // decimal rounding exec that takes no kernel state
+ using Op = Round;
+ return ScalarUnaryNotNullStateful(
+ Op(kDigits, *out->type()))  // uses the (ndigits, out_ty) constructor; kDigits is presumably a template parameter
+ .Exec(ctx, batch, out);
+}
+
+template
struct RoundToMultiple {
+ using CType = typename TypeTraits::CType;
using State = RoundOptionsWrapper;
- template
- static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) {
- static_assert(std::is_same::value, "");
+ CType multiple;
+
+ explicit RoundToMultiple(const State& state, const DataType& out_ty) {
+ const auto& options = state.options;
+ DCHECK(options.multiple);
+ DCHECK(options.multiple->is_valid);
+ DCHECK(is_floating(options.multiple->type->id()));
+ switch (options.multiple->type->id()) {
+ case Type::FLOAT:
+ multiple = static_cast(UnboxScalar::Unbox(*options.multiple));
+ break;
+ case Type::DOUBLE:
+ multiple = static_cast(UnboxScalar::Unbox(*options.multiple));
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+
+ template ::CType>
+ enable_if_floating_point Call(KernelContext* ctx, CType arg, Status* st) const {
// Do not process Inf or NaN because they will trigger the overflow error at end of
// function.
if (!std::isfinite(arg)) {
return arg;
}
- auto options = State::Get(ctx);
- auto round_val = arg / T(options.multiple);
+ auto round_val = arg / multiple;
auto frac = round_val - std::floor(round_val);
if (frac != T(0)) {
// Use std::round() if in tie-breaking mode and scaled value is not 0.5.
- if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
+ if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) {
round_val = std::round(round_val);
} else {
- round_val = RoundImpl::Round(round_val);
+ round_val = RoundImpl::Round(round_val);
}
- round_val *= T(options.multiple);
+ round_val *= multiple;
if (!std::isfinite(round_val)) {
*st = Status::Invalid("overflow occurred during rounding");
return arg;
@@ -1075,6 +1313,116 @@ struct RoundToMultiple {
}
};
+template
+struct RoundToMultiple> {  // RoundToMultiple op for decimal value types (see enable_if_decimal_value on Call)
+ using CType = typename TypeTraits::CType;
+ using State = RoundOptionsWrapper;
+
+ const ArrowType& ty;  // refers to the out_ty given to the constructor; not owned
+ CType multiple, half_multiple, neg_half_multiple;
+ bool has_halfway_point;  // true when the multiple's low bits are even, i.e. multiple / 2 is exact
+
+ explicit RoundToMultiple(const State& state, const DataType& out_ty)
+ : ty(checked_cast(out_ty)) {
+ const auto& options = state.options;
+ DCHECK(options.multiple);  // these three conditions are also checked with real errors in RoundOptionsWrapper's Init
+ DCHECK(options.multiple->is_valid);
+ DCHECK(options.multiple->type->Equals(out_ty));
+ multiple = UnboxScalar::Unbox(*options.multiple);
+ half_multiple = multiple;
+ half_multiple /= 2;
+ neg_half_multiple = -half_multiple;
+ has_halfway_point = multiple.low_bits() % 2 == 0;
+ }
+
+ template ::CType>
+ enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const {
+ std::pair pair;  // receives the result of arg.Divide(multiple): .first quotient, .second remainder
+ *st = arg.Divide(multiple).Value(&pair);
+ if (!st->ok()) return arg;
+ const auto& remainder = pair.second;
+ if (remainder == 0) return arg;  // already an exact multiple
+ if (kRoundMode >= RoundMode::HALF_DOWN) {  // tie-breaking (HALF_*) modes
+ if (has_halfway_point &&
+ (remainder == half_multiple || remainder == neg_half_multiple)) {
+ // On the halfway point, use tiebreaker
+ // Manually implement rounding since we're not actually rounding a
+ // decimal value, but rather adjusting the quotient (the number of multiples)
+ switch (kRoundMode) {
+ case RoundMode::HALF_DOWN:
+ if (remainder.Sign() < 0) pair.first -= 1;
+ break;
+ case RoundMode::HALF_UP:
+ if (remainder.Sign() >= 0) pair.first += 1;
+ break;
+ case RoundMode::HALF_TOWARDS_ZERO:
+ // Do nothing: the quotient is used as computed by Divide
+ break;
+ case RoundMode::HALF_TOWARDS_INFINITY:
+ if (remainder.Sign() >= 0) {
+ pair.first += 1;
+ } else {
+ pair.first -= 1;
+ }
+ break;
+ case RoundMode::HALF_TO_EVEN:
+ if (pair.first.low_bits() % 2 != 0) {  // odd quotient: step in the direction of the remainder's sign
+ pair.first += remainder.Sign() >= 0 ? 1 : -1;
+ }
+ break;
+ case RoundMode::HALF_TO_ODD:
+ if (pair.first.low_bits() % 2 == 0) {  // even quotient: step in the direction of the remainder's sign
+ pair.first += remainder.Sign() >= 0 ? 1 : -1;
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ } else if (remainder.Sign() >= 0) {
+ // Positive remainder: round up only when past the halfway point
+ if (remainder > half_multiple) {
+ pair.first += 1;
+ }
+ } else {
+ // Negative remainder: round down only when past the halfway point
+ if (remainder < neg_half_multiple) {
+ pair.first -= 1;
+ }
+ }
+ } else {
+ // Manually implement rounding since we're not actually rounding a
+ // decimal value, but rather adjusting the quotient (the number of multiples)
+ switch (kRoundMode) {
+ case RoundMode::DOWN:
+ if (remainder.Sign() < 0) pair.first -= 1;
+ break;
+ case RoundMode::UP:
+ if (remainder.Sign() >= 0) pair.first += 1;
+ break;
+ case RoundMode::TOWARDS_ZERO:
+ // Do nothing: the quotient is used as computed by Divide
+ break;
+ case RoundMode::TOWARDS_INFINITY:
+ if (remainder.Sign() >= 0) {
+ pair.first += 1;
+ } else {
+ pair.first -= 1;
+ }
+ break;
+ default:
+ DCHECK(false);
+ }
+ }
+ CType round_val = pair.first * multiple;  // turn the adjusted quotient back into a value
+ if (!round_val.FitsInPrecision(ty.precision())) {
+ *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()),
+ " does not fit in precision of ", ty);
+ return 0;
+ }
+ return round_val;
+ }
+};
+
struct Floor {
template
static constexpr enable_if_floating_point Call(KernelContext*, Arg arg,
@@ -1343,6 +1691,37 @@ struct ArithmeticFunction : ScalarFunction {
}
};
+/// An ArithmeticFunction that promotes only integer arguments to double (float64).
+struct ArithmeticIntegerToFloatingPointFunction : public ArithmeticFunction {
+ using ArithmeticFunction::ArithmeticFunction;
+
+ Result DispatchBest(std::vector* values) const override {  // may rewrite the types in *values
+ RETURN_NOT_OK(CheckArity(*values));
+ RETURN_NOT_OK(CheckDecimals(values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;  // exact match: no promotion needed
+
+ EnsureDictionaryDecoded(values);
+
+ if (values->size() == 2) {
+ ReplaceNullWithOtherType(values);
+ }
+
+ for (auto& descr : *values) {
+ if (is_integer(descr.type->id())) {
+ descr.type = float64();  // only integer arguments are replaced here
+ }
+ }
+ if (auto type = CommonNumeric(*values)) {
+ ReplaceTypes(type, values);
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;  // retry with the rewritten types
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
/// An ArithmeticFunction that promotes integer arguments to double.
struct ArithmeticFloatingPointFunction : public ArithmeticFunction {
using ArithmeticFunction::ArithmeticFunction;
@@ -1452,61 +1831,100 @@ std::shared_ptr MakeUnaryArithmeticFunctionNotNull(
return func;
}
-// Generate a kernel given an arithmetic rounding functor
-template class Op>
-ArrayKernelExec GenerateExecForRound(RoundMode rmode, detail::GetTypeId ty) {
- switch (rmode) {
- case RoundMode::DOWN:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::UP:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::TOWARDS_ZERO:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::TOWARDS_INFINITY:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::HALF_DOWN:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::HALF_UP:
- return GenerateArithmeticFloatingPoint>(
- ty);
- case RoundMode::HALF_TOWARDS_ZERO:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::HALF_TOWARDS_INFINITY:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::HALF_TO_EVEN:
- return GenerateArithmeticFloatingPoint>(ty);
- case RoundMode::HALF_TO_ODD:
- return GenerateArithmeticFloatingPoint>(ty);
- default:
- DCHECK(false);
- return ExecFail;
+// Exec the round kernel for the given types, selecting the op by the round mode held in the kernel state
+template class OpImpl>
+Status ExecRound(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using State = RoundOptionsWrapper;
+ const auto& state = static_cast(*ctx->state());
+ switch (state.options.round_mode) {  // every case builds its Op from (state, output type) and runs it
+ case RoundMode::DOWN: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::UP: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::TOWARDS_ZERO: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::TOWARDS_INFINITY: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_DOWN: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_UP: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TOWARDS_ZERO: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TOWARDS_INFINITY: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TO_EVEN: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
+ case RoundMode::HALF_TO_ODD: {
+ using Op = OpImpl;
+ return ScalarUnaryNotNullStateful(Op(state, *out->type()))
+ .Exec(ctx, batch, out);
+ }
}
+ DCHECK(false);  // reached only if round_mode matches none of the cases above
+ return Status::NotImplemented(
+ "Internal implementation error: round mode not implemented: ",
+ state.options.ToString());
}
// Like MakeUnaryArithmeticFunction, but for unary rounding functions that control
// kernel dispatch based on RoundMode, only on non-null output.
-template class Op, typename OptionsType>
+template class Op, typename OptionsType>
std::shared_ptr MakeUnaryRoundFunction(std::string name,
const FunctionDoc* doc) {
using State = RoundOptionsWrapper;
-
static const OptionsType kDefaultOptions = OptionsType::Defaults();
- auto func = std::make_shared(name, Arity::Unary(), doc,
- &kDefaultOptions);
- for (const auto& ty : FloatingPointTypes()) {
- auto exec = [&](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- auto options = State::Get(ctx);
- auto exec_ = GenerateExecForRound(options.round_mode, ty);
- return exec_(ctx, batch, out);
+ auto func = std::make_shared(
+ name, Arity::Unary(), doc, &kDefaultOptions);
+ for (const auto& ty : {float32(), float64(), decimal128(1, 0), decimal256(1, 0)}) {
+ auto type_id = ty->id();
+ auto exec = [type_id](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ switch (type_id) {
+ case Type::FLOAT:
+ return ExecRound(ctx, batch, out);
+ case Type::DOUBLE:
+ return ExecRound(ctx, batch, out);
+ case Type::DECIMAL128:
+ return ExecRound(ctx, batch, out);
+ case Type::DECIMAL256:
+ return ExecRound(ctx, batch, out);
+ default: {
+ DCHECK(false);
+ return ExecFail(ctx, batch, out);
+ }
+ }
};
- DCHECK_OK(func->AddKernel({ty}, ty, exec, State::Init));
+ DCHECK_OK(func->AddKernel(
+ {InputType(type_id)},
+ is_decimal(type_id) ? OutputType(FirstType) : OutputType(ty), exec, State::Init));
}
AddNullExec(func.get());
return func;
@@ -1552,11 +1970,10 @@ std::shared_ptr MakeShiftFunctionNotNull(std::string name,
return func;
}
-template
+template
std::shared_ptr MakeUnaryArithmeticFunctionFloatingPoint(
std::string name, const FunctionDoc* doc) {
- auto func =
- std::make_shared(name, Arity::Unary(), doc);
+ auto func = std::make_shared(name, Arity::Unary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto exec = GenerateArithmeticFloatingPoint(ty);
DCHECK_OK(func->AddKernel({ty}, ty, exec));
@@ -2133,13 +2550,40 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
// Rounding functions
- auto floor = MakeUnaryArithmeticFunctionFloatingPoint("floor", &floor_doc);
+ auto floor =
+ MakeUnaryArithmeticFunctionFloatingPoint(
+ "floor", &floor_doc);
+ DCHECK_OK(floor->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
+ DCHECK_OK(floor->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
DCHECK_OK(registry->AddFunction(std::move(floor)));
- auto ceil = MakeUnaryArithmeticFunctionFloatingPoint("ceil", &ceil_doc);
+ auto ceil =
+ MakeUnaryArithmeticFunctionFloatingPoint(
+ "ceil", &ceil_doc);
+ DCHECK_OK(ceil->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
+ DCHECK_OK(ceil->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
DCHECK_OK(registry->AddFunction(std::move(ceil)));
- auto trunc = MakeUnaryArithmeticFunctionFloatingPoint("trunc", &trunc_doc);
+ auto trunc =
+ MakeUnaryArithmeticFunctionFloatingPoint(
+ "trunc", &trunc_doc);
+ DCHECK_OK(trunc->AddKernel(
+ {InputType(Type::DECIMAL128)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
+ DCHECK_OK(trunc->AddKernel(
+ {InputType(Type::DECIMAL256)}, OutputType(FirstType),
+ FixedRoundDecimalExec));
DCHECK_OK(registry->AddFunction(std::move(trunc)));
auto round = MakeUnaryRoundFunction("round", &round_doc);
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index bef439e3e36..09681b2763b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -72,6 +72,19 @@ void AssertNullToNull(const std::string& func_name) {
}
}
+// Construct an array of decimals from JSON, where a negative scale is allowed.
+//
+// Works around DecimalXXX::FromString intentionally not inferring
+// negative scales: parse at scale 0 with a widened precision, then cast.
+std::shared_ptr DecimalArrayFromJSON(const std::shared_ptr& type,
+ const std::string& json) {
+ const auto& ty = checked_cast(*type);
+ if (ty.scale() >= 0) return ArrayFromJSON(type, json);  // non-negative scale: parse directly
+ auto p = ty.precision() - ty.scale();  // scale is negative here, so p is larger than the precision
+ auto adjusted_ty = ty.id() == Type::DECIMAL128 ? decimal128(p, 0) : decimal256(p, 0);
+ return Cast(ArrayFromJSON(adjusted_ty, json), type).ValueOrDie().make_array();  // ValueOrDie: a failed cast aborts
+}
+
template
class TestBaseUnaryArithmetic : public TestBase {
protected:
@@ -236,7 +249,9 @@ class TestUnaryRoundToMultiple
using Base = TestBaseUnaryArithmetic;
using Base::options_;
void SetRoundMode(RoundMode value) { options_.round_mode = value; }
- void SetRoundMultiple(double value) { options_.multiple = value; }
+ void SetRoundMultiple(double value) {
+ options_.multiple = std::make_shared(value);
+ }
};
template
@@ -251,6 +266,40 @@ class TestUnaryRoundToMultipleUnsigned : public TestUnaryRoundToMultipleIntegral
template
class TestUnaryRoundToMultipleFloating : public TestUnaryRoundToMultiple {};
+class TestArithmeticDecimal : public ::testing::Test {  // shared fixture for decimal arithmetic tests
+ protected:
+ std::vector> PositiveScaleTypes() {  // 128- and 256-bit decimal types, all with scale 2
+ return {decimal128(4, 2), decimal256(4, 2), decimal128(38, 2), decimal256(76, 2)};
+ }
+ std::vector> NegativeScaleTypes() {  // 128- and 256-bit decimal types with scale -2
+ return {decimal128(2, -2), decimal256(2, -2)};
+ }
+
+ // Validate that func(*args) is approximately the same as
+ // func(*[cast(x, float64) if x is decimal else x for x in args])
+ void CheckDecimalToFloat(const std::string& func, const DatumVector& args) {
+ DatumVector floating_args;
+ for (const auto& arg : args) {
+ if (is_decimal(arg.type()->id())) {
+ ASSERT_OK_AND_ASSIGN(auto casted, Cast(arg, float64()));
+ floating_args.push_back(casted);
+ } else {
+ floating_args.push_back(arg);  // non-decimal arguments are passed through untouched
+ }
+ }
+ ASSERT_OK_AND_ASSIGN(auto expected, CallFunction(func, floating_args));
+ ASSERT_OK_AND_ASSIGN(auto actual, CallFunction(func, args));
+ auto equal_options = EqualOptions::Defaults().nans_equal(true);  // NaN compares equal to NaN
+ AssertDatumsApproxEqual(actual, expected, /*verbose=*/true, equal_options);
+ }
+
+ void CheckRaises(const std::string& func, const DatumVector& args,
+ const std::string& substr, FunctionOptions* options = nullptr) {  // expects an Invalid error containing substr
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(substr),
+ CallFunction(func, args, options));
+ }
+};
+
template
class TestBinaryArithmetic : public TestBase {
protected:
@@ -1436,6 +1485,494 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) {
}
}
+class TestUnaryArithmeticDecimal : public TestArithmeticDecimal {};
+
+// Check two modes exhaustively, give all modes a simple test
+TEST_F(TestUnaryArithmeticDecimal, Round) {
+ const auto func = "round";
+ RoundOptions options(2, RoundMode::DOWN);
+ for (const auto& ty : {decimal128(4, 3), decimal256(4, 3)}) {
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.012", "1.015", "1.019", "-1.010", "-1.012", "-1.015", "-1.019", null])");
+ options.round_mode = RoundMode::DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.010", "-1.010", "-1.020", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.020", "1.020", "1.020", "-1.010", "-1.010", "-1.010", "-1.010", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.010", "-1.010", "-1.010", "-1.010", "-1.010", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.020", "1.020", "1.020", "-1.010", "-1.020", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_EVEN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.020", "1.020", "-1.010", "-1.010", "-1.020", "-1.020", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_ODD;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.010", "1.010", "1.010", "1.020", "-1.010", "-1.010", "-1.010", "-1.020", null])"),
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundTowardsInfinity) {
+ const auto func = "round";
+ RoundOptions options(0, RoundMode::TOWARDS_INFINITY);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])");
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ options.ndigits = 0;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "2.00", "-42.00", "-43.00", "-43.00", null])"),
+ &options);
+ options.ndigits = 1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "1.10", "-42.00", "-43.00", "-42.20", null])"),
+ &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 4;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty, R"(["10.00", "10.00", "10.00", "-50.00", "-50.00", "-50.00", null])"),
+ &options);
+ options.ndigits = -2;
+ CheckRaises(func, {values}, "Rounding to -2 digits will not fit in precision",
+ &options);
+ options.ndigits = -1;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ options.ndigits = 0;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -3;
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["10E2", "20E2", "20E2", "-10E2", "-20E2", "-20E2", null])"),
+ &options);
+ options.ndigits = -4;
+ CheckRaises(func, {values}, "Rounding to -4 digits will not fit in precision",
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundHalfToEven) {
+ const auto func = "round";
+ RoundOptions options(0, RoundMode::HALF_TO_EVEN);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["1.00", "5.99", "1.01", "-42.00", "-42.99", "-42.15", "1.50", "2.50", "-5.50", "-2.55", null])");
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ options.ndigits = 0;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.00", "6.00", "1.00", "-42.00", "-43.00", "-42.00", "2.00", "2.00", "-6.00", "-3.00", null])"),
+ &options);
+ options.ndigits = 1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["1.00", "6.00", "1.00", "-42.00", "-43.00", "-42.20", "1.50", "2.50", "-5.50", "-2.60", null])"),
+ &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 4;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["0.00", "10.00", "0.00", "-40.00", "-40.00", "-40.00", "0.00", "0.00", "-10.00", "0.00", null])"),
+ &options);
+ options.ndigits = -2;
+ CheckRaises(func, {values}, "Rounding to -2 digits will not fit in precision",
+ &options);
+ options.ndigits = -1;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty,
+ R"(["5E2", "10E2", "12E2", "15E2", "18E2", "-10E2", "-12E2", "-15E2", "-18E2", null])");
+ options.ndigits = 0;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = 100;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -1;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -2;
+ CheckScalar(func, {values}, values, &options);
+ options.ndigits = -3;
+ CheckScalar(
+ func, {values},
+ DecimalArrayFromJSON(
+ ty,
+ R"(["0", "10E2", "10E2", "20E2", "20E2", "-10E2", "-10E2", "-20E2", "-20E2", null])"),
+ &options);
+ options.ndigits = -4;
+ CheckRaises(func, {values}, "Rounding to -4 digits will not fit in precision",
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundCeil) {
+ const auto func = "ceil";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "2.00", "-42.00", "-42.00", "-42.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckRaises(func, {ScalarFromJSON(ty, R"("99.99")")},
+ "Rounded value 100.00 does not fit in precision of decimal");
+ CheckScalar(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ ScalarFromJSON(ty, R"("-99.00")"));
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundFloor) {
+ const auto func = "floor";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "1.00", "1.00", "-42.00", "-43.00", "-43.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckScalar(func, {ScalarFromJSON(ty, R"("99.99")")},
+ ScalarFromJSON(ty, R"("99.00")"));
+ CheckRaises(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ "Rounded value -100.00 does not fit in precision of decimal");
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundTrunc) {
+ const auto func = "trunc";
+ for (const auto& ty : PositiveScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(
+ func,
+ {ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])")},
+ ArrayFromJSON(ty,
+ R"(["1.00", "1.00", "1.00", "-42.00", "-42.00", "-42.00", null])"));
+ }
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ CheckScalar(func, {ScalarFromJSON(ty, R"("99.99")")},
+ ScalarFromJSON(ty, R"("99.00")"));
+ CheckScalar(func, {ScalarFromJSON(ty, R"("-99.99")")},
+ ScalarFromJSON(ty, R"("-99.00")"));
+ }
+ for (const auto& ty : NegativeScaleTypes()) {
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"));
+ CheckScalar(func, {DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])")},
+ DecimalArrayFromJSON(ty, R"(["12E2", "-42E2", null])"));
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultiple) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::DOWN);
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(200), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(200), ty);
+ }
+ auto values = ArrayFromJSON(
+ ty,
+ R"(["-3.50", "-3.00", "-2.50", "-2.00", "-1.50", "-1.00", "-0.50", "0.00",
+ "0.50", "1.00", "1.50", "2.00", "2.50", "3.00", "3.50", null])");
+ options.round_mode = RoundMode::DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "0.00",
+ "0.00", "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", null])"),
+ &options);
+ options.round_mode = RoundMode::UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "-0.00", "0.00",
+ "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", null])"),
+ &options);
+ options.round_mode = RoundMode::TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "0.00",
+ "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_DOWN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_UP;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_ZERO;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TOWARDS_INFINITY;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_EVEN;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-4.00", "-2.00", "-2.00", "-2.00", "-0.00", "-0.00", "0.00",
+ "0.00", "0.00", "2.00", "2.00", "2.00", "4.00", "4.00", null])"),
+ &options);
+ options.round_mode = RoundMode::HALF_TO_ODD;
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(
+ ty,
+ R"(["-4.00", "-2.00", "-2.00", "-2.00", "-2.00", "-2.00", "-0.00", "0.00",
+ "0.00", "2.00", "2.00", "2.00", "2.00", "2.00", "4.00", null])"),
+ &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultipleTowardsInfinity) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::TOWARDS_INFINITY);
+ auto set_multiple = [&](const std::shared_ptr<DataType>& ty, int64_t value) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(value), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(value), ty);
+ }
+ };
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values = ArrayFromJSON(
+ ty, R"(["1.00", "1.99", "1.01", "-42.00", "-42.99", "-42.15", null])");
+ set_multiple(ty, 25);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(
+ func, {values},
+ ArrayFromJSON(ty,
+ R"(["1.00", "2.00", "1.25", "-42.00", "-43.00", "-42.25", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ set_multiple(ty, 0);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be positive", &options);
+ set_multiple(ty, -10);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be positive", &options);
+ set_multiple(ty, 100);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounded value 100.00 does not fit in precision", &options);
+ options.multiple = std::make_shared<DoubleScalar>(1.0);
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")}, "scalar, not double",
+ &options);
+ options.multiple =
+ std::make_shared<Decimal128Scalar>(Decimal128(0), decimal128(3, 0));
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")}, "scalar, not decimal128(3, 0)",
+ &options);
+ options.multiple = std::make_shared<Decimal128Scalar>(decimal128(3, 0));
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be non-null and valid", &options);
+ options.multiple = nullptr;
+ CheckRaises(func, {ArrayFromJSON(ty, R"(["99.99"])")},
+ "Rounding multiple must be non-null and valid", &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ set_multiple(ty, 4);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["12E2", "12E2", "20E2", "-12E2", "-12E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ }
+}
+
+TEST_F(TestUnaryArithmeticDecimal, RoundToMultipleHalfToOdd) {
+ const auto func = "round_to_multiple";
+ RoundToMultipleOptions options(0, RoundMode::HALF_TO_ODD);
+ auto set_multiple = [&](const std::shared_ptr<DataType>& ty, int64_t value) {
+ if (ty->id() == Type::DECIMAL128) {
+ options.multiple = std::make_shared<Decimal128Scalar>(Decimal128(value), ty);
+ } else {
+ options.multiple = std::make_shared<Decimal256Scalar>(Decimal256(value), ty);
+ }
+ };
+ for (const auto& ty : {decimal128(4, 2), decimal256(4, 2)}) {
+ auto values =
+ ArrayFromJSON(ty, R"(["-0.38", "-0.37", "-0.25", "-0.13", "-0.12", "0.00",
+ "0.12", "0.13", "0.25", "0.37", "0.38", null])");
+ // Half of 0.25 (0.125) is not representable at scale 2: no input is an exact tie
+ set_multiple(ty, 25);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(func, {values},
+ ArrayFromJSON(ty, R"(["-0.50", "-0.25", "-0.25", "-0.25", "-0.00", "0.00",
+ "0.00", "0.25", "0.25", "0.25", "0.50", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ set_multiple(ty, 24);
+ CheckScalar(func, {ArrayFromJSON(ty, R"([])")}, ArrayFromJSON(ty, R"([])"), &options);
+ CheckScalar(func, {values},
+ ArrayFromJSON(ty, R"(["-0.48", "-0.48", "-0.24", "-0.24", "-0.24", "0.00",
+ "0.24", "0.24", "0.24", "0.48", "0.48", null])"),
+ &options);
+ }
+ for (const auto& ty : {decimal128(2, -2), decimal256(2, -2)}) {
+ auto values = DecimalArrayFromJSON(
+ ty, R"(["10E2", "12E2", "18E2", "-10E2", "-12E2", "-18E2", null])");
+ set_multiple(ty, 4);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["12E2", "12E2", "20E2", "-12E2", "-12E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 5);
+ CheckScalar(func, {values},
+ DecimalArrayFromJSON(
+ ty, R"(["10E2", "10E2", "20E2", "-10E2", "-10E2", "-20E2", null])"),
+ &options);
+ set_multiple(ty, 1);
+ CheckScalar(func, {values}, values, &options);
+ }
+}
+
TYPED_TEST_SUITE(TestUnaryRoundIntegral, IntegralTypes);
TYPED_TEST_SUITE(TestUnaryRoundSigned, SignedIntegerTypes);
TYPED_TEST_SUITE(TestUnaryRoundUnsigned, UnsignedIntegerTypes);
@@ -1639,8 +2176,7 @@ TYPED_TEST(TestUnaryRoundToMultipleFloating, RoundToMultiple) {
}
this->SetRoundMultiple(-2);
- this->AssertUnaryOpRaises(RoundToMultiple, values,
- "multiple has to be a positive value");
+ this->AssertUnaryOpRaises(RoundToMultiple, values, "multiple must be positive");
}
TEST(TestBinaryDecimalArithmetic, DispatchBest) {
diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc
index 24c193dff9f..7417997df49 100644
--- a/cpp/src/arrow/util/basic_decimal.cc
+++ b/cpp/src/arrow/util/basic_decimal.cc
@@ -1106,6 +1106,13 @@ const BasicDecimal128& BasicDecimal128::GetScaleMultiplier(int32_t scale) {
return ScaleMultipliers[scale];
}
+const BasicDecimal128& BasicDecimal128::GetHalfScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 38);
+
+ return ScaleMultipliersHalf[scale];
+}
+
const BasicDecimal128& BasicDecimal128::GetMaxValue() { return kMaxValue; }
BasicDecimal128 BasicDecimal128::IncreaseScaleBy(int32_t increase_by) const {
@@ -1318,6 +1325,13 @@ const BasicDecimal256& BasicDecimal256::GetScaleMultiplier(int32_t scale) {
return ScaleMultipliersDecimal256[scale];
}
+const BasicDecimal256& BasicDecimal256::GetHalfScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 76);
+
+ return ScaleMultipliersHalfDecimal256[scale];
+}
+
BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) {
BasicDecimal256 result = left;
result *= right;
diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h
index 22e8a9bf255..a4df3285596 100644
--- a/cpp/src/arrow/util/basic_decimal.h
+++ b/cpp/src/arrow/util/basic_decimal.h
@@ -196,6 +196,8 @@ class ARROW_EXPORT BasicDecimal128 {
/// \brief Scale multiplier for given scale value.
static const BasicDecimal128& GetScaleMultiplier(int32_t scale);
+ /// \brief Half-scale multiplier for given scale value.
+ static const BasicDecimal128& GetHalfScaleMultiplier(int32_t scale);
/// \brief Convert BasicDecimal128 from one scale to another
DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
@@ -372,6 +374,8 @@ class ARROW_EXPORT BasicDecimal256 {
/// \brief Scale multiplier for given scale value.
static const BasicDecimal256& GetScaleMultiplier(int32_t scale);
+ /// \brief Half-scale multiplier for given scale value.
+ static const BasicDecimal256& GetHalfScaleMultiplier(int32_t scale);
/// \brief Convert BasicDecimal256 from one scale to another
DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 24cf2e9570e..be2da31b9d1 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1409,8 +1409,7 @@ def test_round_to_multiple():
result = pc.round_to_multiple(values, options=options)
np.testing.assert_allclose(result, pa.array(expected), equal_nan=True)
- with pytest.raises(pa.ArrowInvalid,
- match="multiple has to be a positive value"):
+ with pytest.raises(pa.ArrowInvalid, match="multiple must be positive"):
pc.round_to_multiple(values, multiple=-2)
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index f3cf514b885..4462c4033fa 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -515,7 +515,8 @@ std::shared_ptr make_compute_options(
using Options = arrow::compute::RoundToMultipleOptions;
auto out = std::make_shared