diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 9cdce7c1f16..381545fff8b 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -41,9 +41,16 @@ namespace compute { // ---------------------------------------------------------------------- // Arithmetic -SCALAR_EAGER_BINARY(Add, "add") -SCALAR_EAGER_BINARY(Subtract, "subtract") -SCALAR_EAGER_BINARY(Multiply, "multiply") +#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ + Result NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ + ExecContext* ctx) { \ + auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \ + return CallFunction(func_name, {left, right}, ctx); \ + } + +SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked") +SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked") +SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked") // ---------------------------------------------------------------------- // Set-related operations diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index bc502f7bcb9..28609618729 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -35,6 +35,11 @@ namespace compute { // ---------------------------------------------------------------------- +struct ArithmeticOptions : public FunctionOptions { + ArithmeticOptions() : check_overflow(false) {} + bool check_overflow; +}; + /// \brief Add two values together. Array values must be the same length. If /// either addend is null the result will be null. /// @@ -43,7 +48,9 @@ namespace compute { /// \param[in] ctx the function execution context, optional /// \return the elementwise sum ARROW_EXPORT -Result Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +Result Add(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); /// \brief Subtract two values. Array values must be the same length. If the /// minuend or subtrahend is null the result will be null. @@ -53,7 +60,9 @@ Result Add(const Datum& left, const Datum& right, ExecContext* ctx = NULL /// \param[in] ctx the function execution context, optional /// \return the elementwise difference ARROW_EXPORT -Result Subtract(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +Result Subtract(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); /// \brief Multiply two values. Array values must be the same length. If either /// factor is null the result will be null. @@ -63,7 +72,9 @@ Result Subtract(const Datum& left, const Datum& right, ExecContext* ctx = /// \param[in] ctx the function execution context, optional /// \return the elementwise product ARROW_EXPORT -Result Multiply(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +Result Multiply(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); enum CompareOperator { EQUAL, diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 4459e4ffa35..36f52a78bc4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -18,6 +18,10 @@ #include "arrow/compute/kernels/common.h" #include "arrow/util/int_util.h" +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif + namespace arrow { namespace compute { @@ -35,6 +39,10 @@ using enable_if_signed_integer = enable_if_t::value, T>; template using enable_if_unsigned_integer = enable_if_t::value, T>; +template +using enable_if_integer = + enable_if_t::value || is_unsigned_integer::value, T>; + template using enable_if_floating_point = enable_if_t::value, T>; @@ -60,6 +68,42 @@ struct Add { } }; +struct AddChecked { +#if __has_builtin(__builtin_add_overflow) + template + static enable_if_integer Call(KernelContext* ctx, T left, T right) { + T result; + if (__builtin_add_overflow(left, right, &result)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return result; + } +#else + template + static enable_if_unsigned_integer Call(KernelContext* ctx, T left, T right) { + if (arrow::internal::HasAdditionOverflow(left, right)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return left + right; + } + + template + static enable_if_signed_integer Call(KernelContext* ctx, T left, T right) { + auto unsigned_left = to_unsigned(left); + auto unsigned_right = to_unsigned(right); + if (arrow::internal::HasAdditionOverflow(unsigned_left, unsigned_right)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return unsigned_left + unsigned_right; + } +#endif + + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left + right; + } +}; + struct Subtract { template static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { @@ -77,6 +121,40 @@ struct Subtract { } }; +struct SubtractChecked { +#if __has_builtin(__builtin_sub_overflow) + template + static enable_if_integer Call(KernelContext* ctx, T left, T right) { + T result; + if (__builtin_sub_overflow(left, right, &result)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return result; + } +#else + template + static enable_if_unsigned_integer Call(KernelContext* ctx, T left, T right) { + if (arrow::internal::HasSubtractionOverflow(left, right)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return left - right; + } + + template + static enable_if_signed_integer Call(KernelContext* ctx, T left, T right) { + if (arrow::internal::HasSubtractionOverflow(left, right)) { + ctx->SetStatus(Status::Invalid("overflow")); + } + return to_unsigned(left) - to_unsigned(right); + } +#endif + + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left - right; + } +}; + struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); @@ -116,6 +194,29 @@ struct Multiply { } }; +struct MultiplyChecked { + template + static enable_if_integer Call(KernelContext* ctx, T left, T right) { + T result; +#if __has_builtin(__builtin_mul_overflow) + if (__builtin_mul_overflow(left, right, &result)) { + ctx->SetStatus(Status::Invalid("overflow")); + } +#else + result = Multiply::Call(ctx, left, right); + if (left != 0 && result / left != right) { + ctx->SetStatus(Status::Invalid("overflow")); + } +#endif + return result; + } + + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left * right; + } +}; + namespace codegen { // Generate a kernel given an arithmetic functor @@ -168,8 +269,11 @@ namespace internal { void RegisterScalarArithmetic(FunctionRegistry* registry) { codegen::AddBinaryFunction("add", registry); + codegen::AddBinaryFunction("add_checked", registry); codegen::AddBinaryFunction("subtract", registry); + codegen::AddBinaryFunction("subtract_checked", registry); codegen::AddBinaryFunction("multiply", registry); + codegen::AddBinaryFunction("multiply_checked", registry); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc index 34991b99819..b301c95c680 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc @@ -30,7 +30,8 @@ namespace compute { constexpr auto kSeed = 0x94378165; -using BinaryOp = Result(const Datum&, const Datum&, ExecContext*); +using BinaryOp = Result(const Datum&, const Datum&, ArithmeticOptions, + ExecContext*); template static void ArrayScalarKernel(benchmark::State& state) { @@ -46,7 +47,7 @@ static void ArrayScalarKernel(benchmark::State& state) { Datum fifteen(CType(15)); for (auto _ : state) { - ABORT_NOT_OK(Op(lhs, fifteen, nullptr).status()); + ABORT_NOT_OK(Op(lhs, fifteen, ArithmeticOptions(), nullptr).status()); } state.SetItemsProcessed(state.iterations() * array_size); } @@ -66,7 +67,7 @@ static void ArrayArrayKernel(benchmark::State& state) { rand.Numeric(array_size, min, max, args.null_proportion)); for (auto _ : state) { - ABORT_NOT_OK(Op(lhs, rhs, nullptr).status()); + ABORT_NOT_OK(Op(lhs, rhs, ArithmeticOptions(), nullptr).status()); } state.SetItemsProcessed(state.iterations() * array_size); } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 2f2159ef642..87c3750f123 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -47,8 +47,10 @@ class TestBinaryArithmetics : public TestBase { return TypeTraits::type_singleton(); } - using BinaryFunction = - std::function(const Datum&, const Datum&, ExecContext*)>; + using BinaryFunction = std::function(const Datum&, const Datum&, + ArithmeticOptions, ExecContext*)>; + + void SetUp() { options_.check_overflow = false; } // (Scalar, Scalar) void AssertBinop(BinaryFunction func, CType lhs, CType rhs, CType expected) { @@ -56,7 +58,7 @@ class TestBinaryArithmetics : public TestBase { ASSERT_OK_AND_ASSIGN(auto right, MakeScalar(type_singleton(), rhs)); ASSERT_OK_AND_ASSIGN(auto exp, MakeScalar(type_singleton(), expected)); - ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, nullptr)); + ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr)); AssertScalarsEqual(*exp, *actual.scalar(), true); } @@ -67,7 +69,7 @@ class TestBinaryArithmetics : public TestBase { auto right = ArrayFromJSON(type_singleton(), rhs); auto exp = ArrayFromJSON(type_singleton(), expected); - ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, nullptr)); + ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr)); ValidateAndAssertApproxEqual(actual.make_array(), expected); } @@ -77,16 +79,29 @@ class TestBinaryArithmetics : public TestBase { auto left = ArrayFromJSON(type_singleton(), lhs); auto right = ArrayFromJSON(type_singleton(), rhs); - ASSERT_OK_AND_ASSIGN(Datum actual, func(left, right, nullptr)); + ASSERT_OK_AND_ASSIGN(Datum actual, func(left, right, options_, nullptr)); ValidateAndAssertApproxEqual(actual.make_array(), expected); } + void AssertBinopRaises(BinaryFunction func, const std::string& lhs, + const std::string& rhs, const std::string& expected_msg) { + auto left = ArrayFromJSON(type_singleton(), lhs); + auto right = ArrayFromJSON(type_singleton(), rhs); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_msg), + func(left, right, options_, nullptr)); + } + void ValidateAndAssertApproxEqual(std::shared_ptr actual, const std::string& expected) { auto exp = ArrayFromJSON(type_singleton(), expected); ASSERT_OK(actual->ValidateFull()); AssertArraysApproxEqual(*exp, *actual); } + + void SetOverflowCheck(bool value = true) { options_.check_overflow = value; } + + ArithmeticOptions options_ = ArithmeticOptions(); }; template @@ -130,57 +145,58 @@ TYPED_TEST_SUITE(TestBinaryArithmeticsUnsigned, UnsignedIntegerTypes); TYPED_TEST_SUITE(TestBinaryArithmeticsFloating, FloatingTypes); TYPED_TEST(TestBinaryArithmeticsIntegral, Add) { - this->AssertBinop(arrow::compute::Add, "[]", "[]", "[]"); - this->AssertBinop(arrow::compute::Add, "[null]", "[null]", "[null]"); - this->AssertBinop(arrow::compute::Add, "[3, 2, 6]", "[1, 0, 2]", "[4, 2, 8]"); - - this->AssertBinop(arrow::compute::Add, "[1, 2, 3, 4, 5, 6, 7]", "[0, 1, 2, 3, 4, 5, 6]", - "[1, 3, 5, 7, 9, 11, 13]"); - - this->AssertBinop(arrow::compute::Add, "[10, 12, 4, 50, 50, 32, 11]", - "[2, 0, 6, 1, 5, 3, 4]", "[12, 12, 10, 51, 55, 35, 15]"); - - this->AssertBinop(arrow::compute::Add, "[null, 1, 3, null, 2, 5]", "[1, 4, 2, 5, 0, 3]", - "[null, 5, 5, null, 2, 8]"); - - this->AssertBinop(arrow::compute::Add, 10, "[null, 1, 3, null, 2, 5]", - "[null, 11, 13, null, 12, 15]"); - - this->AssertBinop(arrow::compute::Add, 17, 42, 59); + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + + this->AssertBinop(arrow::compute::Add, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Add, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Add, "[3, 2, 6]", "[1, 0, 2]", "[4, 2, 8]"); + + this->AssertBinop(arrow::compute::Add, "[1, 2, 3, 4, 5, 6, 7]", + "[0, 1, 2, 3, 4, 5, 6]", "[1, 3, 5, 7, 9, 11, 13]"); + + this->AssertBinop(arrow::compute::Add, "[10, 12, 4, 50, 50, 32, 11]", + "[2, 0, 6, 1, 5, 3, 4]", "[12, 12, 10, 51, 55, 35, 15]"); + this->AssertBinop(arrow::compute::Add, "[null, 1, 3, null, 2, 5]", + "[1, 4, 2, 5, 0, 3]", "[null, 5, 5, null, 2, 8]"); + this->AssertBinop(arrow::compute::Add, 10, "[null, 1, 3, null, 2, 5]", + "[null, 11, 13, null, 12, 15]"); + this->AssertBinop(arrow::compute::Add, 17, 42, 59); + } } TYPED_TEST(TestBinaryArithmeticsIntegral, Sub) { - this->AssertBinop(arrow::compute::Subtract, "[]", "[]", "[]"); - this->AssertBinop(arrow::compute::Subtract, "[null]", "[null]", "[null]"); - this->AssertBinop(arrow::compute::Subtract, "[3, 2, 6]", "[1, 0, 2]", "[2, 2, 4]"); - - this->AssertBinop(arrow::compute::Subtract, "[1, 2, 3, 4, 5, 6, 7]", - "[0, 1, 2, 3, 4, 5, 6]", "[1, 1, 1, 1, 1, 1, 1]"); - - this->AssertBinop(arrow::compute::Subtract, 10, "[null, 1, 3, null, 2, 5]", - "[null, 9, 7, null, 8, 5]"); - - this->AssertBinop(arrow::compute::Subtract, 20, 9, 11); + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + + this->AssertBinop(arrow::compute::Subtract, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Subtract, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Subtract, "[3, 2, 6]", "[1, 0, 2]", "[2, 2, 4]"); + this->AssertBinop(arrow::compute::Subtract, "[1, 2, 3, 4, 5, 6, 7]", + "[0, 1, 2, 3, 4, 5, 6]", "[1, 1, 1, 1, 1, 1, 1]"); + this->AssertBinop(arrow::compute::Subtract, 10, "[null, 1, 3, null, 2, 5]", + "[null, 9, 7, null, 8, 5]"); + this->AssertBinop(arrow::compute::Subtract, 20, 9, 11); + } } TYPED_TEST(TestBinaryArithmeticsIntegral, Mul) { - this->AssertBinop(arrow::compute::Multiply, "[]", "[]", "[]"); - this->AssertBinop(arrow::compute::Multiply, "[null]", "[null]", "[null]"); - this->AssertBinop(arrow::compute::Multiply, "[3, 2, 6]", "[1, 0, 2]", "[3, 0, 12]"); - - this->AssertBinop(arrow::compute::Multiply, "[1, 2, 3, 4, 5, 6, 7]", - "[0, 1, 2, 3, 4, 5, 6]", "[0, 2, 6, 12, 20, 30, 42]"); - - this->AssertBinop(arrow::compute::Multiply, "[7, 6, 5, 4, 3, 2, 1]", - "[6, 5, 4, 3, 2, 1, 0]", "[42, 30, 20, 12, 6, 2, 0]"); - - this->AssertBinop(arrow::compute::Multiply, "[null, 1, 3, null, 2, 5]", - "[1, 4, 2, 5, 0, 3]", "[null, 4, 6, null, 0, 15]"); - - this->AssertBinop(arrow::compute::Multiply, 3, "[null, 1, 3, null, 2, 5]", - "[null, 3, 9, null, 6, 15]"); - - this->AssertBinop(arrow::compute::Multiply, 6, 7, 42); + for (auto check_overflow : {false, true}) { + this->SetOverflowCheck(check_overflow); + + this->AssertBinop(arrow::compute::Multiply, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Multiply, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Multiply, "[3, 2, 6]", "[1, 0, 2]", "[3, 0, 12]"); + this->AssertBinop(arrow::compute::Multiply, "[1, 2, 3, 4, 5, 6, 7]", + "[0, 1, 2, 3, 4, 5, 6]", "[0, 2, 6, 12, 20, 30, 42]"); + this->AssertBinop(arrow::compute::Multiply, "[7, 6, 5, 4, 3, 2, 1]", + "[6, 5, 4, 3, 2, 1, 0]", "[42, 30, 20, 12, 6, 2, 0]"); + this->AssertBinop(arrow::compute::Multiply, "[null, 1, 3, null, 2, 5]", + "[1, 4, 2, 5, 0, 3]", "[null, 4, 6, null, 0, 15]"); + this->AssertBinop(arrow::compute::Multiply, 3, "[null, 1, 3, null, 2, 5]", + "[null, 3, 9, null, 6, 15]"); + this->AssertBinop(arrow::compute::Multiply, 6, 7, 42); + } } TYPED_TEST(TestBinaryArithmeticsSigned, Add) { @@ -202,17 +218,52 @@ TYPED_TEST(TestBinaryArithmeticsSigned, OverflowWraps) { this->AssertBinop(arrow::compute::Subtract, MakeArray(min, max, min), MakeArray(1, max, max), MakeArray(max, 0, 1)); - this->AssertBinop(arrow::compute::Multiply, MakeArray(min, max, max), MakeArray(max, 2, max), MakeArray(min, CType(-2), 1)); } +TYPED_TEST(TestBinaryArithmeticsIntegral, OverflowRaises) { + using CType = typename TestFixture::CType; + + auto min = std::numeric_limits::lowest(); + auto max = std::numeric_limits::max(); + + this->SetOverflowCheck(true); + + this->AssertBinopRaises(arrow::compute::Add, MakeArray(min, max, max), + MakeArray(CType(-1), 1, max), "overflow"); + this->AssertBinopRaises(arrow::compute::Subtract, MakeArray(min, max), + MakeArray(1, max), "overflow"); + this->AssertBinopRaises(arrow::compute::Subtract, MakeArray(min), MakeArray(max), + "overflow"); + + this->AssertBinopRaises(arrow::compute::Multiply, MakeArray(min, max, max), + MakeArray(max, 2, max), "overflow"); +} + +TYPED_TEST(TestBinaryArithmeticsSigned, OverflowRaises) { + using CType = typename TestFixture::CType; + + auto min = std::numeric_limits::lowest(); + auto max = std::numeric_limits::max(); + + this->SetOverflowCheck(true); + + this->AssertBinop(arrow::compute::Multiply, MakeArray(max), MakeArray(-1), + MakeArray(min + 1)); + this->AssertBinopRaises(arrow::compute::Multiply, MakeArray(max), MakeArray(2), + "overflow"); + this->AssertBinopRaises(arrow::compute::Multiply, MakeArray(min), MakeArray(-1), + "overflow"); +} + TYPED_TEST(TestBinaryArithmeticsUnsigned, OverflowWraps) { using CType = typename TestFixture::CType; auto min = std::numeric_limits::lowest(); auto max = std::numeric_limits::max(); + this->SetOverflowCheck(false); this->AssertBinop(arrow::compute::Add, MakeArray(min, max, max), MakeArray(CType(-1), 1, max), MakeArray(max, min, CType(-2))); diff --git a/cpp/src/arrow/util/int_util.h b/cpp/src/arrow/util/int_util.h index 0478c55e623..3131476bfec 100644 --- a/cpp/src/arrow/util/int_util.h +++ b/cpp/src/arrow/util/int_util.h @@ -100,6 +100,12 @@ bool HasAdditionOverflow(Integer value, Integer addend) { return (value > std::numeric_limits::max() - addend); } +/// Detect addition overflow between integers +template +bool HasSubtractionOverflow(Integer value, Integer minuend) { + return (value < minuend); +} + /// Upcast an integer to the largest possible width (currently 64 bits) template