diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 3a6e99342ef..e3fe1bdf73d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -283,8 +283,11 @@ RoundOptions::RoundOptions(int64_t ndigits, RoundMode round_mode) constexpr char RoundOptions::kTypeName[]; RoundToMultipleOptions::RoundToMultipleOptions(double multiple, RoundMode round_mode) + : RoundToMultipleOptions(std::make_shared(multiple), round_mode) {} +RoundToMultipleOptions::RoundToMultipleOptions(std::shared_ptr multiple, + RoundMode round_mode) : FunctionOptions(internal::kRoundToMultipleOptionsType), - multiple(multiple), + multiple(std::move(multiple)), round_mode(round_mode) {} constexpr char RoundToMultipleOptions::kTypeName[]; diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 8e1cd4875d5..65ec4d6e18b 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -93,10 +93,17 @@ class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions { public: explicit RoundToMultipleOptions(double multiple = 1.0, RoundMode round_mode = RoundMode::HALF_TO_EVEN); + explicit RoundToMultipleOptions(std::shared_ptr multiple, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); constexpr static char const kTypeName[] = "RoundToMultipleOptions"; static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); } - /// Rounding scale (multiple to round to) - double multiple; + /// Rounding scale (multiple to round to). + /// + /// Should be a scalar of a type compatible with the argument to be rounded. + /// For example, rounding a decimal value means a decimal multiple is + /// required. Rounding a floating point or integer value means a floating + /// point scalar is required. + std::shared_ptr multiple; /// Rounding and tie-breaking mode RoundMode round_mode; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index bf2d27f1f72..48a5e815b67 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -46,6 +46,7 @@ using applicator::ScalarBinaryEqualTypes; using applicator::ScalarBinaryNotNullEqualTypes; using applicator::ScalarUnary; using applicator::ScalarUnaryNotNull; +using applicator::ScalarUnaryNotNullStateful; namespace { @@ -73,10 +74,10 @@ using enable_if_c_integer = template using enable_if_floating_point = enable_if_t::value, R>; -template +template using enable_if_decimal_value = enable_if_t::value || std::is_same::value, - T>; + R>; struct AbsoluteValue { template @@ -880,81 +881,170 @@ struct RoundUtil { }; // Specializations of rounding implementations for round kernels -template +template struct RoundImpl; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::floor(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + (*val) -= remainder; + if (remainder.Sign() < 0) { + (*val) -= pow10; + } + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::ceil(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + (*val) -= remainder; + if (remainder.Sign() > 0 && remainder != 0) { + (*val) += pow10; + } + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::trunc(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + (*val) -= remainder; + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::signbit(val) ? std::floor(val) : std::ceil(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + (*val) -= remainder; + if (remainder.Sign() < 0) { + (*val) -= pow10; + } else if (remainder.Sign() > 0 && remainder != 0) { + (*val) += pow10; + } + } }; // NOTE: RoundImpl variants for the HALF_* rounding modes are only // invoked when the fractional part is equal to 0.5 (std::round is invoked // otherwise). -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return RoundImpl::Round(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + RoundImpl::Round(val, remainder, pow10, scale); + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return RoundImpl::Round(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + RoundImpl::Round(val, remainder, pow10, scale); + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return RoundImpl::Round(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + RoundImpl::Round(val, remainder, pow10, scale); + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return RoundImpl::Round(val); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + RoundImpl::Round(val, remainder, pow10, scale); + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::round(val * T(0.5)) * 2; } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + auto scaled = val->ReduceScaleBy(scale, /*round=*/false); + if (scaled.low_bits() % 2 != 0) { + scaled += remainder.Sign() >= 0 ? 1 : -1; + } + *val = scaled.IncreaseScaleBy(scale); + } }; -template -struct RoundImpl { +template +struct RoundImpl { + template static constexpr enable_if_floating_point Round(const T val) { return std::floor(val * T(0.5)) + std::ceil(val * T(0.5)); } + + template + static enable_if_decimal_value Round(T* val, const T& remainder, + const T& pow10, const int32_t scale) { + auto scaled = val->ReduceScaleBy(scale, /*round=*/false); + if (scaled.low_bits() % 2 == 0) { + scaled += remainder.Sign() ? 1 : -1; + } + *val = scaled.IncreaseScaleBy(scale); + } }; // Specializations of kernel state for round kernels @@ -994,40 +1084,89 @@ struct RoundOptionsWrapper const KernelInitArgs& args) { ARROW_ASSIGN_OR_RAISE(auto state, OptionsWrapper::Init(ctx, args)); auto options = Get(*state); - if (options.multiple <= 0) { - return Status::Invalid("Rounding multiple has to be a positive value"); + const auto& type = *args.inputs[0].type; + if (!options.multiple || !options.multiple->is_valid) { + return Status::Invalid("Rounding multiple must be non-null and valid"); + } + if (is_floating(type.id())) { + switch (options.multiple->type->id()) { + case Type::FLOAT: { + if (UnboxScalar::Unbox(*options.multiple) < 0) { + return Status::Invalid("Rounding multiple must be positive"); + } + break; + } + case Type::DOUBLE: { + if (UnboxScalar::Unbox(*options.multiple) < 0) { + return Status::Invalid("Rounding multiple must be positive"); + } + break; + } + case Type::HALF_FLOAT: + return Status::NotImplemented("Half-float values are not supported"); + default: + return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ", + *options.multiple->type); + } + } else { + DCHECK(is_decimal(type.id())); + if (!type.Equals(*options.multiple->type)) { + return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ", + *options.multiple->type); + } + switch (options.multiple->type->id()) { + case Type::DECIMAL128: { + if (UnboxScalar::Unbox(*options.multiple) <= 0) { + return Status::Invalid("Rounding multiple must be positive"); + } + break; + } + case Type::DECIMAL256: { + if (UnboxScalar::Unbox(*options.multiple) <= 0) { + return Status::Invalid("Rounding multiple must be positive"); + } + break; + } + default: + // This shouldn't happen + return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ", + *options.multiple->type); + } } return std::move(state); } }; -template +template struct Round { + using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; - template - static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) { - static_assert(std::is_same::value, ""); + CType pow10; + int64_t ndigits; + + explicit Round(const State& state, const DataType& out_ty) + : pow10(static_cast(state.pow10)), ndigits(state.options.ndigits) {} + + template ::CType> + enable_if_floating_point Call(KernelContext* ctx, CType arg, Status* st) const { // Do not process Inf or NaN because they will trigger the overflow error at end of // function. if (!std::isfinite(arg)) { return arg; } - auto state = static_cast(ctx->state()); - auto options = state->options; - auto pow10 = T(state->pow10); - auto round_val = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10); + auto round_val = ndigits >= 0 ? (arg * pow10) : (arg / pow10); auto frac = round_val - std::floor(round_val); if (frac != T(0)) { // Use std::round() if in tie-breaking mode and scaled value is not 0.5. if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { round_val = std::round(round_val); } else { - round_val = RoundImpl::Round(round_val); + round_val = RoundImpl::Round(round_val); } // Equality check is ommitted so that the common case of 10^0 (integer rounding) // uses multiply-only - round_val = (options.ndigits > 0) ? (round_val / pow10) : (round_val * pow10); + round_val = ndigits > 0 ? (round_val / pow10) : (round_val * pow10); if (!std::isfinite(round_val)) { *st = Status::Invalid("overflow occurred during rounding"); return arg; @@ -1040,29 +1179,128 @@ struct Round { } }; -template +template +struct Round> { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper; + + const ArrowType& ty; + int64_t ndigits; + int32_t pow; + // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5". + CType pow10, half_pow10, neg_half_pow10; + + explicit Round(const State& state, const DataType& out_ty) + : Round(state.options.ndigits, out_ty) {} + + explicit Round(int64_t ndigits, const DataType& out_ty) + : ty(checked_cast(out_ty)), + ndigits(ndigits), + pow(static_cast(ty.scale() - ndigits)) { + if (pow >= ty.precision() || pow < 0) { + pow10 = half_pow10 = neg_half_pow10 = 0; + } else { + pow10 = CType::GetScaleMultiplier(pow); + half_pow10 = CType::GetHalfScaleMultiplier(pow); + neg_half_pow10 = -half_pow10; + } + } + + template ::CType> + enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const { + if (pow >= ty.precision()) { + *st = Status::Invalid("Rounding to ", ndigits, + " digits will not fit in precision of ", ty); + return arg; + } else if (pow < 0) { + // no-op, copy output to input + return arg; + } + + std::pair pair; + *st = arg.Divide(pow10).Value(&pair); + if (!st->ok()) return arg; + // The remainder is effectively the scaled fractional part after division. + const auto& remainder = pair.second; + if (remainder == 0) return arg; + if (kRoundMode >= RoundMode::HALF_DOWN) { + if (remainder == half_pow10 || remainder == neg_half_pow10) { + // On the halfway point, use tiebreaker + RoundImpl::Round(&arg, remainder, pow10, pow); + } else if (remainder.Sign() >= 0) { + // Positive, round up/down + arg -= remainder; + if (remainder > half_pow10) { + arg += pow10; + } + } else { + // Negative, round up/down + arg -= remainder; + if (remainder < neg_half_pow10) { + arg -= pow10; + } + } + } else { + RoundImpl::Round(&arg, remainder, pow10, pow); + } + if (!arg.FitsInPrecision(ty.precision())) { + *st = Status::Invalid("Rounded value ", arg.ToString(ty.scale()), + " does not fit in precision of ", ty); + return 0; + } + return arg; + } +}; + +template +Status FixedRoundDecimalExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + using Op = Round; + return ScalarUnaryNotNullStateful( + Op(kDigits, *out->type())) + .Exec(ctx, batch, out); +} + +template struct RoundToMultiple { + using CType = typename TypeTraits::CType; using State = RoundOptionsWrapper; - template - static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) { - static_assert(std::is_same::value, ""); + CType multiple; + + explicit RoundToMultiple(const State& state, const DataType& out_ty) { + const auto& options = state.options; + DCHECK(options.multiple); + DCHECK(options.multiple->is_valid); + DCHECK(is_floating(options.multiple->type->id())); + switch (options.multiple->type->id()) { + case Type::FLOAT: + multiple = static_cast(UnboxScalar::Unbox(*options.multiple)); + break; + case Type::DOUBLE: + multiple = static_cast(UnboxScalar::Unbox(*options.multiple)); + break; + default: + DCHECK(false); + } + } + + template ::CType> + enable_if_floating_point Call(KernelContext* ctx, CType arg, Status* st) const { // Do not process Inf or NaN because they will trigger the overflow error at end of // function. if (!std::isfinite(arg)) { return arg; } - auto options = State::Get(ctx); - auto round_val = arg / T(options.multiple); + auto round_val = arg / multiple; auto frac = round_val - std::floor(round_val); if (frac != T(0)) { // Use std::round() if in tie-breaking mode and scaled value is not 0.5. - if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { + if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { round_val = std::round(round_val); } else { - round_val = RoundImpl::Round(round_val); + round_val = RoundImpl::Round(round_val); } - round_val *= T(options.multiple); + round_val *= multiple; if (!std::isfinite(round_val)) { *st = Status::Invalid("overflow occurred during rounding"); return arg; @@ -1075,6 +1313,116 @@ struct RoundToMultiple { } }; +template +struct RoundToMultiple> { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper; + + const ArrowType& ty; + CType multiple, half_multiple, neg_half_multiple; + bool has_halfway_point; + + explicit RoundToMultiple(const State& state, const DataType& out_ty) + : ty(checked_cast(out_ty)) { + const auto& options = state.options; + DCHECK(options.multiple); + DCHECK(options.multiple->is_valid); + DCHECK(options.multiple->type->Equals(out_ty)); + multiple = UnboxScalar::Unbox(*options.multiple); + half_multiple = multiple; + half_multiple /= 2; + neg_half_multiple = -half_multiple; + has_halfway_point = multiple.low_bits() % 2 == 0; + } + + template ::CType> + enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const { + std::pair pair; + *st = arg.Divide(multiple).Value(&pair); + if (!st->ok()) return arg; + const auto& remainder = pair.second; + if (remainder == 0) return arg; + if (kRoundMode >= RoundMode::HALF_DOWN) { + if (has_halfway_point && + (remainder == half_multiple || remainder == neg_half_multiple)) { + // On the halfway point, use tiebreaker + // Manually implement rounding since we're not actually rounding a + // decimal value, but rather manipulating the multiple + switch (kRoundMode) { + case RoundMode::HALF_DOWN: + if (remainder.Sign() < 0) pair.first -= 1; + break; + case RoundMode::HALF_UP: + if (remainder.Sign() >= 0) pair.first += 1; + break; + case RoundMode::HALF_TOWARDS_ZERO: + // Do nothing + break; + case RoundMode::HALF_TOWARDS_INFINITY: + if (remainder.Sign() >= 0) { + pair.first += 1; + } else { + pair.first -= 1; + } + break; + case RoundMode::HALF_TO_EVEN: + if (pair.first.low_bits() % 2 != 0) { + pair.first += remainder.Sign() >= 0 ? 1 : -1; + } + break; + case RoundMode::HALF_TO_ODD: + if (pair.first.low_bits() % 2 == 0) { + pair.first += remainder.Sign() >= 0 ? 1 : -1; + } + break; + default: + DCHECK(false); + } + } else if (remainder.Sign() >= 0) { + // Positive, round up/down + if (remainder > half_multiple) { + pair.first += 1; + } + } else { + // Negative, round up/down + if (remainder < neg_half_multiple) { + pair.first -= 1; + } + } + } else { + // Manually implement rounding since we're not actually rounding a + // decimal value, but rather manipulating the multiple + switch (kRoundMode) { + case RoundMode::DOWN: + if (remainder.Sign() < 0) pair.first -= 1; + break; + case RoundMode::UP: + if (remainder.Sign() >= 0) pair.first += 1; + break; + case RoundMode::TOWARDS_ZERO: + // Do nothing + break; + case RoundMode::TOWARDS_INFINITY: + if (remainder.Sign() >= 0) { + pair.first += 1; + } else { + pair.first -= 1; + } + break; + default: + DCHECK(false); + } + } + CType round_val = pair.first * multiple; + if (!round_val.FitsInPrecision(ty.precision())) { + *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()), + " does not fit in precision of ", ty); + return 0; + } + return round_val; + } +}; + struct Floor { template static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, @@ -1343,6 +1691,37 @@ struct ArithmeticFunction : ScalarFunction { } }; +/// An ArithmeticFunction that promotes only integer arguments to double. +struct ArithmeticIntegerToFloatingPointFunction : public ArithmeticFunction { + using ArithmeticFunction::ArithmeticFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + RETURN_NOT_OK(CheckDecimals(values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); + + if (values->size() == 2) { + ReplaceNullWithOtherType(values); + } + + for (auto& descr : *values) { + if (is_integer(descr.type->id())) { + descr.type = float64(); + } + } + if (auto type = CommonNumeric(*values)) { + ReplaceTypes(type, values); + } + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + /// An ArithmeticFunction that promotes integer arguments to double. struct ArithmeticFloatingPointFunction : public ArithmeticFunction { using ArithmeticFunction::ArithmeticFunction; @@ -1452,61 +1831,100 @@ std::shared_ptr MakeUnaryArithmeticFunctionNotNull( return func; } -// Generate a kernel given an arithmetic rounding functor -template