From ae740e3078ba6b1fabcdea48bd92b45e2d5ececb Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 25 Jun 2021 06:48:37 -0400 Subject: [PATCH] ARROW-12744: [C++] Add rounding kernels --- cpp/src/arrow/compute/api_scalar.cc | 79 +++++ cpp/src/arrow/compute/api_scalar.h | 92 ++++- cpp/src/arrow/compute/function.h | 2 +- cpp/src/arrow/compute/function_test.cc | 6 + .../arrow/compute/kernels/codegen_internal.h | 4 +- .../compute/kernels/scalar_arithmetic.cc | 316 +++++++++++++++++- .../compute/kernels/scalar_arithmetic_test.cc | 308 +++++++++++++++-- cpp/src/arrow/python/python_test.cc | 2 +- docs/source/cpp/compute.rst | 106 +++++- docs/source/python/api/compute.rst | 6 +- python/pyarrow/_compute.pyx | 51 +++ python/pyarrow/compute.py | 2 + python/pyarrow/includes/libarrow.pxd | 35 ++ python/pyarrow/tests/test_compute.py | 68 +++- 14 files changed, 1012 insertions(+), 65 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index d992d073d51..83aaee5f0fe 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -53,6 +53,7 @@ struct EnumTraits return ""; } }; + template <> struct EnumTraits : BasicEnumTraits return ""; } }; + template <> struct EnumTraits : BasicEnumTraits< @@ -136,6 +138,42 @@ struct EnumTraits return ""; } }; + +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "compute::RoundMode"; } + static std::string value_name(compute::RoundMode value) { + switch (value) { + case compute::RoundMode::DOWN: + return "DOWN"; + case compute::RoundMode::UP: + return "UP"; + case compute::RoundMode::TOWARDS_ZERO: + return "TOWARDS_ZERO"; + case compute::RoundMode::TOWARDS_INFINITY: + return "TOWARDS_INFINITY"; + case compute::RoundMode::HALF_DOWN: + return "HALF_DOWN"; + case compute::RoundMode::HALF_UP: + return "HALF_UP"; + case compute::RoundMode::HALF_TOWARDS_ZERO: + return "HALF_TOWARDS_ZERO"; + case 
compute::RoundMode::HALF_TOWARDS_INFINITY: + return "HALF_TOWARDS_INFINITY"; + case compute::RoundMode::HALF_TO_EVEN: + return "HALF_TO_EVEN"; + case compute::RoundMode::HALF_TO_ODD: + return "HALF_TO_ODD"; + } + return ""; + } +}; } // namespace internal namespace compute { @@ -153,6 +191,12 @@ static auto kArithmeticOptionsType = GetFunctionOptionsType( static auto kElementWiseAggregateOptionsType = GetFunctionOptionsType( DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); +static auto kRoundOptionsType = GetFunctionOptionsType( + DataMember("ndigits", &RoundOptions::ndigits), + DataMember("round_mode", &RoundOptions::round_mode)); +static auto kRoundToMultipleOptionsType = GetFunctionOptionsType( + DataMember("multiple", &RoundToMultipleOptions::multiple), + DataMember("round_mode", &RoundToMultipleOptions::round_mode)); static auto kJoinOptionsType = GetFunctionOptionsType( DataMember("null_handling", &JoinOptions::null_handling), DataMember("null_replacement", &JoinOptions::null_replacement)); @@ -217,6 +261,30 @@ ElementWiseAggregateOptions::ElementWiseAggregateOptions(bool skip_nulls) skip_nulls(skip_nulls) {} constexpr char ElementWiseAggregateOptions::kTypeName[]; +RoundOptions::RoundOptions(int64_t ndigits, RoundMode round_mode) + : FunctionOptions(internal::kRoundOptionsType), + ndigits(ndigits), + round_mode(round_mode) { + static_assert(RoundMode::HALF_DOWN > RoundMode::DOWN && + RoundMode::HALF_DOWN > RoundMode::UP && + RoundMode::HALF_DOWN > RoundMode::TOWARDS_ZERO && + RoundMode::HALF_DOWN > RoundMode::TOWARDS_INFINITY && + RoundMode::HALF_DOWN < RoundMode::HALF_UP && + RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_ZERO && + RoundMode::HALF_DOWN < RoundMode::HALF_TOWARDS_INFINITY && + RoundMode::HALF_DOWN < RoundMode::HALF_TO_EVEN && + RoundMode::HALF_DOWN < RoundMode::HALF_TO_ODD, + "Invalid order of round modes. 
Modes prefixed with HALF need to be " + "enumerated last with HALF_DOWN being the first among them."); +} +constexpr char RoundOptions::kTypeName[]; + +RoundToMultipleOptions::RoundToMultipleOptions(double multiple, RoundMode round_mode) + : FunctionOptions(internal::kRoundToMultipleOptionsType), + multiple(multiple), + round_mode(round_mode) {} +constexpr char RoundToMultipleOptions::kTypeName[]; + JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) : FunctionOptions(internal::kJoinOptionsType), null_handling(null_handling), @@ -352,6 +420,8 @@ namespace internal { void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kArithmeticOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRoundOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRoundToMultipleOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType)); @@ -405,6 +475,15 @@ SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked") SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked") SCALAR_ARITHMETIC_UNARY(Log1p, "log1p", "log1p_checked") +Result Round(const Datum& arg, RoundOptions options, ExecContext* ctx) { + return CallFunction("round", {arg}, &options, ctx); +} + +Result RoundToMultiple(const Datum& arg, RoundToMultipleOptions options, + ExecContext* ctx) { + return CallFunction("round_to_multiple", {arg}, &options, ctx); +} + #define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ Result NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ ExecContext* ctx) { \ diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index c50803a0456..9f9a2931398 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ 
b/cpp/src/arrow/compute/api_scalar.h @@ -50,10 +50,58 @@ class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { explicit ElementWiseAggregateOptions(bool skip_nulls = true); constexpr static char const kTypeName[] = "ElementWiseAggregateOptions"; static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } - bool skip_nulls; }; +/// Rounding and tie-breaking modes for round compute functions. +/// Additional details and examples are provided in compute.rst. +enum class RoundMode : int8_t { + /// Round to nearest integer less than or equal in magnitude (aka "floor") + DOWN, + /// Round to nearest integer greater than or equal in magnitude (aka "ceil") + UP, + /// Get the integral part without fractional digits (aka "trunc") + TOWARDS_ZERO, + /// Round negative values with DOWN rule and positive values with UP rule + TOWARDS_INFINITY, + /// Round ties with DOWN rule + HALF_DOWN, + /// Round ties with UP rule + HALF_UP, + /// Round ties with TOWARDS_ZERO rule + HALF_TOWARDS_ZERO, + /// Round ties with TOWARDS_INFINITY rule + HALF_TOWARDS_INFINITY, + /// Round ties to nearest even integer + HALF_TO_EVEN, + /// Round ties to nearest odd integer + HALF_TO_ODD, +}; + +class ARROW_EXPORT RoundOptions : public FunctionOptions { + public: + explicit RoundOptions(int64_t ndigits = 0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + constexpr static char const kTypeName[] = "RoundOptions"; + static RoundOptions Defaults() { return RoundOptions(); } + /// Rounding precision (number of digits to round to) + int64_t ndigits; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + +class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions { + public: + explicit RoundToMultipleOptions(double multiple = 1.0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + constexpr static char const kTypeName[] = "RoundToMultipleOptions"; + static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); } + /// 
Rounding scale (multiple to round to) + double multiple; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + /// Options for var_args_join. class ARROW_EXPORT JoinOptions : public FunctionOptions { public: @@ -559,8 +607,9 @@ Result Logb(const Datum& arg, const Datum& base, ExecContext* ctx = NULLPTR); /// \brief Round to the nearest integer less than or equal in magnitude to the -/// argument. Array values can be of arbitrary length. If argument is null the -/// result will be null. +/// argument. +/// +/// If argument is null the result will be null. /// /// \param[in] arg the value to round /// \param[in] ctx the function execution context, optional @@ -569,8 +618,9 @@ ARROW_EXPORT Result Floor(const Datum& arg, ExecContext* ctx = NULLPTR); /// \brief Round to the nearest integer greater than or equal in magnitude to the -/// argument. Array values can be of arbitrary length. If argument is null the -/// result will be null. +/// argument. +/// +/// If argument is null the result will be null. /// /// \param[in] arg the value to round /// \param[in] ctx the function execution context, optional @@ -578,8 +628,9 @@ Result Floor(const Datum& arg, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result Ceil(const Datum& arg, ExecContext* ctx = NULLPTR); -/// \brief Get the integral part without fractional digits. Array values can be -/// of arbitrary length. If argument is null the result will be null. +/// \brief Get the integral part without fractional digits. +/// +/// If argument is null the result will be null. 
/// /// \param[in] arg the value to truncate /// \param[in] ctx the function execution context, optional @@ -618,10 +669,35 @@ Result MinElementWise( /// /// \param[in] arg the value to extract sign from /// \param[in] ctx the function execution context, optional -/// \return the elementwise sign function +/// \return the element-wise sign function ARROW_EXPORT Result Sign(const Datum& arg, ExecContext* ctx = NULLPTR); +/// \brief Round a value to a given precision. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] options rounding options (rounding mode and number of digits), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Round a value to a given multiple. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] options rounding options (rounding mode and multiple), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result RoundToMultiple( + const Datum& arg, RoundToMultipleOptions options = RoundToMultipleOptions::Defaults(), + ExecContext* ctx = NULLPTR); + /// \brief Compare a numeric array with a scalar. /// /// \param[in] left datum to compare, must be an Array diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 6434d5090f6..f08b50699a5 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -227,7 +227,7 @@ class ARROW_EXPORT Function { virtual Result Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const; - /// \brief Returns a the default options for this function. + /// \brief Returns the default options for this function.
/// /// Whatever option semantics a Function has, implementations must guarantee /// that default_options() is valid to pass to Execute as options. diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index c08fdaca627..183167490b6 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -58,6 +58,12 @@ TEST(FunctionOptions, Equality) { options.emplace_back(new IndexOptions(ScalarFromJSON(boolean(), "null"))); options.emplace_back(new ArithmeticOptions()); options.emplace_back(new ArithmeticOptions(/*check_overflow=*/true)); + options.emplace_back(new RoundOptions()); + options.emplace_back( + new RoundOptions(/*ndigits=*/2, /*round_mode=*/RoundMode::TOWARDS_INFINITY)); + options.emplace_back(new RoundToMultipleOptions()); + options.emplace_back(new RoundToMultipleOptions( + /*multiple=*/100, /*round_mode=*/RoundMode::TOWARDS_INFINITY)); options.emplace_back(new ElementWiseAggregateOptions()); options.emplace_back(new ElementWiseAggregateOptions(/*skip_nulls=*/false)); options.emplace_back(new JoinOptions()); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 951b502b991..3fbc6caa27d 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -94,8 +94,8 @@ struct OptionsWrapper : public KernelState { /// KernelContext and the FunctionOptions as argument template struct KernelStateFromFunctionOptions : public KernelState { - explicit KernelStateFromFunctionOptions(KernelContext* ctx, OptionsType state) - : state(StateType(ctx, std::move(state))) {} + explicit KernelStateFromFunctionOptions(KernelContext* ctx, OptionsType options) + : state(StateType(ctx, std::move(options))) {} static Result> Init(KernelContext* ctx, const KernelInitArgs& args) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc 
b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 7fa5ac52f3f..fa78083effa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -21,7 +21,8 @@ #include #include -#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compare.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/type.h" @@ -862,11 +863,224 @@ struct LogbChecked { } }; +struct RoundUtil { + // Calculate powers of ten with arbitrary integer exponent + template + static enable_if_floating_point Pow10(int64_t power) { + static constexpr T lut[] = {1e0F, 1e1F, 1e2F, 1e3F, 1e4F, 1e5F, 1e6F, 1e7F, + 1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F}; + int64_t lut_size = (sizeof(lut) / sizeof(*lut)); + int64_t abs_power = std::abs(power); + auto pow10 = lut[std::min(abs_power, lut_size - 1)]; + while (abs_power-- >= lut_size) { + pow10 *= 1e1F; + } + return (power >= 0) ? pow10 : (1 / pow10); + } +}; + +// Specializations of rounding implementations for round kernels +template +struct RoundImpl; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::floor(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::ceil(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::trunc(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::signbit(val) ? std::floor(val) : std::ceil(val); + } +}; + +// NOTE: RoundImpl variants for the HALF_* rounding modes are only +// invoked when the fractional part is equal to 0.5 (std::round is invoked +// otherwise). 
+ +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return RoundImpl::Round(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return RoundImpl::Round(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return RoundImpl::Round(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return RoundImpl::Round(val); + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::round(val * T(0.5)) * 2; + } +}; + +template +struct RoundImpl { + static constexpr enable_if_floating_point Round(const T val) { + return std::floor(val * T(0.5)) + std::ceil(val * T(0.5)); + } +}; + +// Specializations of kernel state for round kernels +template +struct RoundOptionsWrapper; + +template <> +struct RoundOptionsWrapper : public OptionsWrapper { + using OptionsType = RoundOptions; + using State = RoundOptionsWrapper; + double pow10; + + explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) { + // Only positive exponents for powers of 10 are used because combining + // multiply and division operations produced more stable rounding than + // using multiply-only. 
Refer to NumPy's round implementation: + // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589 + pow10 = RoundUtil::Pow10(std::abs(options.ndigits)); + } + + static Result> Init(KernelContext* ctx, + const KernelInitArgs& args) { + if (auto options = static_cast(args.options)) { + return ::arrow::internal::make_unique(*options); + } + return Status::Invalid( + "Attempted to initialize KernelState from null FunctionOptions"); + } +}; + +template <> +struct RoundOptionsWrapper + : public OptionsWrapper { + using OptionsType = RoundToMultipleOptions; + + static Result> Init(KernelContext* ctx, + const KernelInitArgs& args) { + ARROW_ASSIGN_OR_RAISE(auto state, OptionsWrapper::Init(ctx, args)); + auto options = Get(*state); + if (options.multiple <= 0) { + return Status::Invalid("Rounding multiple has to be a positive value"); + } + return std::move(state); + } +}; + +template +struct Round { + using State = RoundOptionsWrapper; + + template + static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) { + static_assert(std::is_same::value, ""); + // Do not process Inf or NaN because they will trigger the overflow error at end of + // function. + if (!std::isfinite(arg)) { + return arg; + } + auto state = static_cast(ctx->state()); + auto options = state->options; + auto pow10 = T(state->pow10); + auto round_val = (options.ndigits >= 0) ? (arg * pow10) : (arg / pow10); + auto frac = round_val - std::floor(round_val); + if (frac != T(0)) { + // Use std::round() if in tie-breaking mode and scaled value is not 0.5. + if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { + round_val = std::round(round_val); + } else { + round_val = RoundImpl::Round(round_val); + } + // Equality check is omitted so that the common case of 10^0 (integer rounding) + // uses multiply-only + round_val = (options.ndigits > 0) ?
(round_val / pow10) : (round_val * pow10); + if (!std::isfinite(round_val)) { + *st = Status::Invalid("overflow occurred during rounding"); + return arg; + } + } else { + // If scaled value is an integer, then no rounding is needed. + round_val = arg; + } + return round_val; + } +}; + +template +struct RoundToMultiple { + using State = RoundOptionsWrapper; + + template + static enable_if_floating_point Call(KernelContext* ctx, Arg arg, Status* st) { + static_assert(std::is_same::value, ""); + // Do not process Inf or NaN because they will trigger the overflow error at end of + // function. + if (!std::isfinite(arg)) { + return arg; + } + auto options = State::Get(ctx); + auto round_val = arg / T(options.multiple); + auto frac = round_val - std::floor(round_val); + if (frac != T(0)) { + // Use std::round() if in tie-breaking mode and scaled value is not 0.5. + if ((RndMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { + round_val = std::round(round_val); + } else { + round_val = RoundImpl::Round(round_val); + } + round_val *= T(options.multiple); + if (!std::isfinite(round_val)) { + *st = Status::Invalid("overflow occurred during rounding"); + return arg; + } + } else { + // If scaled value is an integer, then no rounding is needed. 
+ round_val = arg; + } + return round_val; + } +}; + struct Floor { template static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status*) { - return std::floor(arg); + static_assert(std::is_same::value, ""); + return RoundImpl::Round(arg); } }; @@ -874,7 +1088,8 @@ struct Ceil { template static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status*) { - return std::ceil(arg); + static_assert(std::is_same::value, ""); + return RoundImpl::Round(arg); } }; @@ -882,7 +1097,8 @@ struct Trunc { template static constexpr enable_if_floating_point Call(KernelContext*, Arg arg, Status*) { - return std::trunc(arg); + static_assert(std::is_same::value, ""); + return RoundImpl::Round(arg); } }; @@ -1221,6 +1437,65 @@ std::shared_ptr MakeUnaryArithmeticFunctionNotNull( return func; } +// Generate a kernel given an arithmetic rounding functor +template