diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 690c51a4a62..c4bb771b3c8 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -438,6 +438,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_validity.cc compute/kernels/util_internal.cc compute/kernels/vector_array_sort.cc + compute/kernels/vector_cumulative_ops.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc compute/kernels/vector_replace.cc diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index a5cb61d6b55..e3db022536c 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -135,6 +135,10 @@ static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("k", &SelectKOptions::k), DataMember("sort_keys", &SelectKOptions::sort_keys)); +static auto kCumulativeSumOptionsType = GetFunctionOptionsType( + DataMember("start", &CumulativeSumOptions::start), + DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls), + DataMember("check_overflow", &CumulativeSumOptions::check_overflow)); } // namespace } // namespace internal @@ -176,6 +180,18 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) sort_keys(std::move(sort_keys)) {} constexpr char SelectKOptions::kTypeName[]; +CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls, + bool check_overflow) + : CumulativeSumOptions(std::make_shared(start), skip_nulls, + check_overflow) {} +CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr start, bool skip_nulls, + bool check_overflow) + : FunctionOptions(internal::kCumulativeSumOptionsType), + start(std::move(start)), + skip_nulls(skip_nulls), + check_overflow(check_overflow) {} +constexpr char CumulativeSumOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -185,6 +201,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeSumOptionsType)); } } // namespace internal @@ -325,6 +342,15 @@ Result> DropNull(const Array& values, ExecContext* ctx) { return out.make_array(); } +// ---------------------------------------------------------------------- +// Cumulative functions + +Result CumulativeSum(const Datum& values, const CumulativeSumOptions& options, + ExecContext* ctx) { + auto func_name = (options.check_overflow) ? "cumulative_sum_checked" : "cumulative_sum"; + return CallFunction(func_name, {Datum(values)}, &options, ctx); +} + // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 9e53cfcf640..b5daddb17b9 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -188,6 +188,27 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { NullPlacement null_placement; }; +/// \brief Options for cumulative sum function +class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions { + public: + explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false, + bool check_overflow = false); + explicit CumulativeSumOptions(std::shared_ptr start, bool skip_nulls = false, + bool check_overflow = false); + static constexpr char const kTypeName[] = "CumulativeSumOptions"; + static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); } + + /// Optional starting value for cumulative operation computation + std::shared_ptr start; + + /// If true, nulls in the input are ignored and produce a corresponding null output. + /// When false, the first null encountered is propagated through the remaining output. + bool skip_nulls = false; + + /// When true, returns an Invalid Status when overflow is detected + bool check_overflow = false; +}; + /// @} /// \brief Filter with a boolean selection filter @@ -522,6 +543,12 @@ Result DictionaryEncode( const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(), ExecContext* ctx = NULLPTR); +ARROW_EXPORT +Result CumulativeSum( + const Datum& values, + const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(), + ExecContext* ctx = NULLPTR); + // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 0a7f6191120..780699886d2 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -48,6 +48,7 @@ add_arrow_benchmark(scalar_temporal_benchmark PREFIX "arrow-compute") add_arrow_compute_test(vector_test SOURCES + vector_cumulative_ops_test.cc vector_hash_test.cc vector_nested_test.cc vector_replace_test.cc diff --git a/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h new file mode 100644 index 00000000000..1707ed7c137 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/base_arithmetic_internal.h @@ -0,0 +1,602 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/decimal.h" +#include "arrow/util/int_util_internal.h" +#include "arrow/util/macros.h" + +namespace arrow { + +using internal::AddWithOverflow; +using internal::DivideWithOverflow; +using internal::MultiplyWithOverflow; +using internal::NegateWithOverflow; +using internal::SubtractWithOverflow; + +namespace compute { +namespace internal { + +struct Add { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return left + right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return left + right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return arrow::internal::SafeSignedAdd(left, right); + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } +}; + +struct AddChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { + *st = Status::Invalid("overflow"); + } + return result; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return left + right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } +}; + +template +struct AddTimeDuration { + template + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = + arrow::internal::SafeSignedAdd(static_cast(left), static_cast(right)); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template +struct AddTimeDurationChecked { + template + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE( + AddWithOverflow(static_cast(left), static_cast(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +struct AbsoluteValue { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, + Status*) { + return std::fabs(arg); + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg arg, + Status*) { + return arg; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg arg, + Status* st) { + return (arg < 0) ? arrow::internal::SafeSignedNegate(arg) : arg; + } + + template + static constexpr enable_if_decimal_value Call(KernelContext*, Arg arg, + Status*) { + return arg.Abs(); + } +}; + +struct AbsoluteValueChecked { + template + static enable_if_signed_integer_value Call(KernelContext*, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + if (arg == std::numeric_limits::min()) { + *st = Status::Invalid("overflow"); + return arg; + } + return std::abs(arg); + } + + template + static enable_if_unsigned_integer_value Call(KernelContext* ctx, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + return arg; + } + + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + return std::fabs(arg); + } + + template + static constexpr enable_if_decimal_value Call(KernelContext*, Arg arg, + Status*) { + return arg.Abs(); + } +}; + +struct Subtract { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return left - right; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return left - right; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + return arrow::internal::SafeSignedSubtract(left, right); + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } +}; + +struct SubtractChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { + *st = Status::Invalid("overflow"); + } + return result; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return left - right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } +}; + +struct SubtractDate32 { + static constexpr int64_t kSecondsInDay = 86400; + + template + static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return arrow::internal::SafeSignedSubtract(left, right) * kSecondsInDay; + } +}; + +struct SubtractCheckedDate32 { + static constexpr int64_t kSecondsInDay = 86400; + + template + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { + *st = Status::Invalid("overflow"); + } + if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(result, kSecondsInDay, &result))) { + *st = Status::Invalid("overflow"); + } + return result; + } +}; + +template +struct SubtractTimeDuration { + template + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = arrow::internal::SafeSignedSubtract(left, static_cast(right)); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template +struct SubtractTimeDurationChecked { + template + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +struct Multiply { + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + + template + static constexpr enable_if_floating_value Call(KernelContext*, T left, T right, + Status*) { + return left * right; + } + + template + static constexpr enable_if_t< + is_unsigned_integer_value::value && !std::is_same::value, T> + Call(KernelContext*, T left, T right, Status*) { + return left * right; + } + + template + static constexpr enable_if_t< + is_signed_integer_value::value && !std::is_same::value, T> + Call(KernelContext*, T left, T right, Status*) { + return to_unsigned(left) * to_unsigned(right); + } + + // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit + // integer. However, some inputs may nevertheless overflow (which triggers undefined + // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is + // well defined. + template + static constexpr enable_if_same Call(KernelContext*, int16_t left, + int16_t right, Status*) { + return static_cast(left) * static_cast(right); + } + template + static constexpr enable_if_same Call(KernelContext*, uint16_t left, + uint16_t right, Status*) { + return static_cast(left) * static_cast(right); + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } +}; + +struct MultiplyChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) { + *st = Status::Invalid("overflow"); + } + return result; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return left * right; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } +}; + +struct Divide { + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { + return left / right; + } + + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result; + if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + result = 0; + } + } + return result; + } + + template + static enable_if_decimal_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + if (right == Arg1()) { + *st = Status::Invalid("Divide by zero"); + return T(); + } else { + return left / right; + } + } +}; + +struct DivideChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + T result; + if (ARROW_PREDICT_FALSE(DivideWithOverflow(left, right, &result))) { + if (right == 0) { + *st = Status::Invalid("divide by zero"); + } else { + *st = Status::Invalid("overflow"); + } + } + return result; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + static_assert(std::is_same::value && std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(right == 0)) { + *st = Status::Invalid("divide by zero"); + return 0; + } + return left / right; + } + + template + static enable_if_decimal_value Call(KernelContext* ctx, Arg0 left, Arg1 right, + Status* st) { + return Divide::Call(ctx, left, right, st); + } +}; + +struct Negate { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { + return -arg; + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg arg, + Status*) { + return ~arg + 1; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg arg, + Status*) { + return arrow::internal::SafeSignedNegate(arg); + } + + template + static constexpr enable_if_decimal_value Call(KernelContext*, Arg arg, + Status*) { + return arg.Negate(); + } +}; + +struct NegateChecked { + template + static enable_if_signed_integer_value Call(KernelContext*, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + T result = 0; + if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) { + *st = Status::Invalid("overflow"); + } + return result; + } + + template + static enable_if_unsigned_integer_value Call(KernelContext* ctx, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + DCHECK(false) << "This is included only for the purposes of instantiability from the " + "arithmetic kernel generator"; + return 0; + } + + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, + Status* st) { + static_assert(std::is_same::value, ""); + return -arg; + } + + template + static constexpr enable_if_decimal_value Call(KernelContext*, Arg arg, + Status*) { + return arg.Negate(); + } +}; + +struct Power { + ARROW_NOINLINE + static uint64_t IntegerPower(uint64_t base, uint64_t exp) { + // right to left O(logn) power + uint64_t pow = 1; + while (exp) { + pow *= (exp & 1) ? base : 1; + base *= base; + exp >>= 1; + } + return pow; + } + + template + static enable_if_integer_value Call(KernelContext*, T base, T exp, Status* st) { + if (exp < 0) { + *st = Status::Invalid("integers to negative integer powers are not allowed"); + return 0; + } + return static_cast(IntegerPower(base, exp)); + } + + template + static enable_if_floating_value Call(KernelContext*, T base, T exp, Status*) { + return std::pow(base, exp); + } +}; + +struct PowerChecked { + template + static enable_if_integer_value Call(KernelContext*, Arg0 base, Arg1 exp, + Status* st) { + if (exp < 0) { + *st = Status::Invalid("integers to negative integer powers are not allowed"); + return 0; + } else if (exp == 0) { + return 1; + } + // left to right O(logn) power with overflow checks + bool overflow = false; + uint64_t bitmask = + 1ULL << (63 - bit_util::CountLeadingZeros(static_cast(exp))); + T pow = 1; + while (bitmask) { + overflow |= MultiplyWithOverflow(pow, pow, &pow); + if (exp & bitmask) { + overflow |= MultiplyWithOverflow(pow, base, &pow); + } + bitmask >>= 1; + } + if (overflow) { + *st = Status::Invalid("overflow"); + } + return pow; + } + + template + static enable_if_floating_value Call(KernelContext*, Arg0 base, Arg1 exp, Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); + return std::pow(base, exp); + } +}; + +struct SquareRoot { + template + static enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { + static_assert(std::is_same::value, ""); + if (arg < 0.0) { + return std::numeric_limits::quiet_NaN(); + } + return std::sqrt(arg); + } +}; + +struct SquareRootChecked { + template + static enable_if_floating_value Call(KernelContext*, Arg arg, Status* st) { + static_assert(std::is_same::value, ""); + if (arg < 0.0) { + *st = Status::Invalid("square root of negative number"); + return arg; + } + return std::sqrt(arg); + } +}; + +struct Sign { + template + static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, + Status*) { + return std::isnan(arg) ? arg : ((arg == 0) ? 0 : (std::signbit(arg) ? -1 : 1)); + } + + template + static constexpr enable_if_unsigned_integer_value Call(KernelContext*, Arg arg, + Status*) { + return (arg > 0) ? 1 : 0; + } + + template + static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg arg, + Status*) { + return (arg > 0) ? 1 : ((arg == 0) ? 0 : -1); + } + + template + static constexpr enable_if_decimal_value Call(KernelContext*, Arg arg, + Status*) { + return (arg == 0) ? 0 : arg.Sign(); + } +}; + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index fa50427bc3e..6d31c1fe246 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1134,6 +1134,37 @@ ArrayKernelExec GeneratePhysicalInteger(detail::GetTypeId get_id) { } } +template