diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index f4696fbe02a..d169fd2ebde 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -52,6 +52,7 @@ SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked") SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked") SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") +SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked") // ---------------------------------------------------------------------- // Set-related operations diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f59426d8f1b..6032f656c4a 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -204,6 +204,20 @@ Result Divide(const Datum& left, const Datum& right, ArithmeticOptions options = ArithmeticOptions(), ExecContext* ctx = NULLPTR); +/// \brief Raise the values of base array to the power of the exponent array values. +/// Array values must be the same length. If either base or exponent is null the result +/// will be null. +/// +/// \param[in] left the base +/// \param[in] right the exponent +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise base value raised to the power of exponent +ARROW_EXPORT +Result Power(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + /// \brief Compare a numeric array with a scalar. /// /// \param[in] left datum to compare, must be an Array diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 7abaa1c1a59..260721b08d9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include + #include "arrow/compute/kernels/common.h" #include "arrow/util/int_util_internal.h" #include "arrow/util/macros.h" @@ -233,6 +235,70 @@ struct DivideChecked { } }; +struct Power { + ARROW_NOINLINE + static uint64_t IntegerPower(uint64_t base, uint64_t exp) { + // right to left O(logn) power + uint64_t pow = 1; + while (exp) { + pow *= (exp & 1) ? base : 1; + base *= base; + exp >>= 1; + } + return pow; + } + + template + static enable_if_integer Call(KernelContext* ctx, T base, T exp) { + if (exp < 0) { + ctx->SetStatus( + Status::Invalid("integers to negative integer powers are not allowed")); + return 0; + } + return static_cast(IntegerPower(base, exp)); + } + + template + static enable_if_floating_point Call(KernelContext* ctx, T base, T exp) { + return std::pow(base, exp); + } +}; + +struct PowerChecked { + template + static enable_if_integer Call(KernelContext* ctx, Arg0 base, Arg1 exp) { + if (exp < 0) { + ctx->SetStatus( + Status::Invalid("integers to negative integer powers are not allowed")); + return 0; + } else if (exp == 0) { + return 1; + } + // left to right O(logn) power with overflow checks + bool overflow = false; + uint64_t bitmask = + 1ULL << (63 - BitUtil::CountLeadingZeros(static_cast(exp))); + T pow = 1; + while (bitmask) { + overflow |= MultiplyWithOverflow(pow, pow, &pow); + if (exp & bitmask) { + overflow |= MultiplyWithOverflow(pow, base, &pow); + } + bitmask >>= 1; + } + if (overflow) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return pow; + } + + template + static enable_if_floating_point Call(KernelContext* ctx, Arg0 base, Arg1 exp) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::pow(base, exp); + } +}; + // Generate a kernel given an arithmetic functor template