Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ namespace compute {
// ----------------------------------------------------------------------
// Arithmetic

SCALAR_EAGER_BINARY(Add, "add")
SCALAR_EAGER_BINARY(Subtract, "subtract")
SCALAR_EAGER_BINARY(Multiply, "multiply")
// Defines the eager wrapper NAME(left, right, options, ctx) for a binary
// arithmetic function. The wrapper calls the function registered under
// REGISTRY_CHECKED_NAME when options.check_overflow is set, and the one
// registered under REGISTRY_NAME otherwise.
#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME)           \
  Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
                     ExecContext* ctx) {                                               \
    if (options.check_overflow) {                                                      \
      return CallFunction(REGISTRY_CHECKED_NAME, {left, right}, ctx);                  \
    }                                                                                  \
    return CallFunction(REGISTRY_NAME, {left, right}, ctx);                            \
  }

SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked")
SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked")
SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")

// ----------------------------------------------------------------------
// Set-related operations
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ namespace compute {

// ----------------------------------------------------------------------

/// \brief Options for the arithmetic functions Add, Subtract and Multiply.
struct ArithmeticOptions : public FunctionOptions {
  /// Whether integer overflow should be detected. When true, the "_checked"
  /// variant of the arithmetic function is called instead of the plain one.
  bool check_overflow;

  /// Overflow checking is off by default.
  ArithmeticOptions() : check_overflow(false) {}
};

/// \brief Add two values together. Array values must be the same length. If
/// either addend is null the result will be null.
///
Expand All @@ -43,7 +48,9 @@ namespace compute {
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise sum
ARROW_EXPORT
Result<Datum> Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
Result<Datum> Add(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to document this.
This breaks CI: https://github.com/apache/arrow/runs/786249626

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it in #7492

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

ExecContext* ctx = NULLPTR);

/// \brief Subtract two values. Array values must be the same length. If the
/// minuend or subtrahend is null the result will be null.
Expand All @@ -53,7 +60,9 @@ Result<Datum> Add(const Datum& left, const Datum& right, ExecContext* ctx = NULL
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise difference
ARROW_EXPORT
Result<Datum> Subtract(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
Result<Datum> Subtract(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Multiply two values. Array values must be the same length. If either
/// factor is null the result will be null.
Expand All @@ -63,7 +72,9 @@ Result<Datum> Subtract(const Datum& left, const Datum& right, ExecContext* ctx =
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise product
ARROW_EXPORT
Result<Datum> Multiply(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
Result<Datum> Multiply(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

enum CompareOperator {
EQUAL,
Expand Down
104 changes: 104 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
#include <limits>

#include "arrow/compute/kernels/common.h"
#include "arrow/util/int_util.h"

// Compilers that do not provide __has_builtin get a stub that reports every
// builtin as unavailable, so the `#if __has_builtin(...)` checks in this file
// select the hand-written fallback code on those compilers.
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

namespace arrow {
namespace compute {

Expand All @@ -35,6 +39,10 @@ using enable_if_signed_integer = enable_if_t<is_signed_integer<T>::value, T>;
template <typename T>
using enable_if_unsigned_integer = enable_if_t<is_unsigned_integer<T>::value, T>;

// SFINAE helper: names T when T satisfies either is_signed_integer or
// is_unsigned_integer, and is ill-formed otherwise.
template <typename T>
using enable_if_integer =
enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, T>;

template <typename T>
using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, T>;

Expand All @@ -60,6 +68,42 @@ struct Add {
}
};

/// Addition functor that reports integer overflow through the kernel context.
///
/// The integer overloads call ctx->SetStatus(Status::Invalid("overflow")) when
/// the sum of `left` and `right` does not fit in T; a value is still returned
/// in that case. The floating-point overload performs a plain addition and
/// never touches the context.
struct AddChecked {
#if __has_builtin(__builtin_add_overflow)
  template <typename T>
  static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
    T result;
    if (__builtin_add_overflow(left, right, &result)) {
      // Review note: a more specific error kind (e.g. an execution error) was
      // considered; Status::Invalid was agreed to be fine for now.
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    return result;
  }
#else
  template <typename T>
  static enable_if_unsigned_integer<T> Call(KernelContext* ctx, T left, T right) {
    if (arrow::internal::HasAdditionOverflow(left, right)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    return left + right;
  }

  template <typename T>
  static enable_if_signed_integer<T> Call(KernelContext* ctx, T left, T right) {
    // Test the signed operands against the range of T. Checking their unsigned
    // reinterpretations instead would flag valid sums such as -1 + 1 (the bit
    // patterns carry out) and miss real overflows such as max() + 1.
    // Neither subtraction below can itself overflow: each branch moves the
    // limit towards zero.
    const bool overflows =
        (right > 0 && left > std::numeric_limits<T>::max() - right) ||
        (right < 0 && left < std::numeric_limits<T>::min() - right);
    if (overflows) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    // Add the unsigned counterparts so the overflowing case is not undefined
    // behaviour on the signed type.
    return to_unsigned(left) + to_unsigned(right);
  }
#endif

  template <typename T>
  static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
    return left + right;
  }
};

struct Subtract {
template <typename T>
static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
Expand All @@ -77,6 +121,40 @@ struct Subtract {
}
};

/// Subtraction functor that reports integer overflow through the kernel context.
///
/// The integer overloads call ctx->SetStatus(Status::Invalid("overflow")) when
/// `left - right` does not fit in T; a value is still returned in that case.
/// The floating-point overload performs a plain subtraction and never touches
/// the context.
struct SubtractChecked {
#if __has_builtin(__builtin_sub_overflow)
  template <typename T>
  static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
    T result;
    if (__builtin_sub_overflow(left, right, &result)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    return result;
  }
#else
  template <typename T>
  static enable_if_unsigned_integer<T> Call(KernelContext* ctx, T left, T right) {
    if (arrow::internal::HasSubtractionOverflow(left, right)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    return left - right;
  }

  template <typename T>
  static enable_if_signed_integer<T> Call(KernelContext* ctx, T left, T right) {
    // The signed range check is spelled out here rather than delegated, so it
    // does not depend on how the shared helper treats negative operands (a
    // plain `left < right` borrow test would misreport e.g. 1 - 2).
    // Neither addition below can itself overflow: each branch moves the limit
    // towards zero.
    const bool overflows =
        (right > 0 && left < std::numeric_limits<T>::min() + right) ||
        (right < 0 && left > std::numeric_limits<T>::max() + right);
    if (overflows) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    // Subtract the unsigned counterparts so the overflowing case is not
    // undefined behaviour on the signed type.
    return to_unsigned(left) - to_unsigned(right);
  }
#endif

  template <typename T>
  static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
    return left - right;
  }
};

struct Multiply {
static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
Expand Down Expand Up @@ -116,6 +194,29 @@ struct Multiply {
}
};

/// Multiplication functor that reports integer overflow through the kernel
/// context.
///
/// The integer overload calls ctx->SetStatus(Status::Invalid("overflow")) when
/// `left * right` does not fit in T; a value is still returned in that case.
/// The floating-point overload performs a plain multiplication and never
/// touches the context.
struct MultiplyChecked {
  template <typename T>
  static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
    T result;
#if __has_builtin(__builtin_mul_overflow)
    if (__builtin_mul_overflow(left, right, &result)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
#else
    result = Multiply::Call(ctx, left, right);
    // -1 * min() overflows, and verifying it with `result / left` would divide
    // min() by -1, which is itself undefined (and traps on common hardware).
    // Flag that pair up front so the division is never evaluated for it.
    const bool negates_min = std::is_signed<T>::value &&
                             left == static_cast<T>(-1) &&
                             right == std::numeric_limits<T>::min();
    // Otherwise the product overflowed iff dividing it back by `left` does not
    // recover `right` (left == 0 can never overflow).
    if (negates_min || (left != 0 && result / left != right)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
#endif
    return result;
  }

  template <typename T>
  static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
    return left * right;
  }
};

namespace codegen {

// Generate a kernel given an arithmetic functor
Expand Down Expand Up @@ -168,8 +269,11 @@ namespace internal {

// Registers the binary arithmetic functions with `registry`. Each operation is
// added twice: once under its plain name using the unchecked functor, and once
// under a "_checked" name using the functor that reports integer overflow.
void RegisterScalarArithmetic(FunctionRegistry* registry) {
codegen::AddBinaryFunction<Add>("add", registry);
codegen::AddBinaryFunction<AddChecked>("add_checked", registry);
codegen::AddBinaryFunction<Subtract>("subtract", registry);
codegen::AddBinaryFunction<SubtractChecked>("subtract_checked", registry);
codegen::AddBinaryFunction<Multiply>("multiply", registry);
codegen::AddBinaryFunction<MultiplyChecked>("multiply_checked", registry);
}

} // namespace internal
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace compute {

constexpr auto kSeed = 0x94378165;

using BinaryOp = Result<Datum>(const Datum&, const Datum&, ExecContext*);
using BinaryOp = Result<Datum>(const Datum&, const Datum&, ArithmeticOptions,
ExecContext*);

template <BinaryOp& Op, typename ArrowType, typename CType = typename ArrowType::c_type>
static void ArrayScalarKernel(benchmark::State& state) {
Expand All @@ -46,7 +47,7 @@ static void ArrayScalarKernel(benchmark::State& state) {

Datum fifteen(CType(15));
for (auto _ : state) {
ABORT_NOT_OK(Op(lhs, fifteen, nullptr).status());
ABORT_NOT_OK(Op(lhs, fifteen, ArithmeticOptions(), nullptr).status());
}
state.SetItemsProcessed(state.iterations() * array_size);
}
Expand All @@ -66,7 +67,7 @@ static void ArrayArrayKernel(benchmark::State& state) {
rand.Numeric<ArrowType>(array_size, min, max, args.null_proportion));

for (auto _ : state) {
ABORT_NOT_OK(Op(lhs, rhs, nullptr).status());
ABORT_NOT_OK(Op(lhs, rhs, ArithmeticOptions(), nullptr).status());
}
state.SetItemsProcessed(state.iterations() * array_size);
}
Expand Down
Loading