diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 86ef74946df..9cdce7c1f16 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -42,6 +42,8 @@ namespace compute { // Arithmetic SCALAR_EAGER_BINARY(Add, "add") +SCALAR_EAGER_BINARY(Subtract, "subtract") +SCALAR_EAGER_BINARY(Multiply, "multiply") // ---------------------------------------------------------------------- // Set-related operations diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index e3210d3ac91..bc502f7bcb9 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -35,16 +35,36 @@ namespace compute { // ---------------------------------------------------------------------- -/// \brief Add two values together. Array values must be the same length. If a -/// value is null in either addend, the result is null +/// \brief Add two values together. Array values must be the same length. If +/// either addend is null the result will be null. /// -/// \param[in] left the first value -/// \param[in] right the second value +/// \param[in] left the first addend +/// \param[in] right the second addend /// \param[in] ctx the function execution context, optional -/// \return the elementwise addition of the values +/// \return the elementwise sum ARROW_EXPORT Result Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +/// \brief Subtract two values. Array values must be the same length. If the +/// minuend or subtrahend is null the result will be null. +/// +/// \param[in] left the value subtracted from (minuend) +/// \param[in] right the value by which the minuend is reduced (subtrahend) +/// \param[in] ctx the function execution context, optional +/// \return the elementwise difference +ARROW_EXPORT +Result Subtract(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Multiply two values. Array values must be the same length. If either +/// factor is null the result will be null. +/// +/// \param[in] left the first factor +/// \param[in] right the second factor +/// \param[in] ctx the function execution context, optional +/// \return the elementwise product +ARROW_EXPORT +Result Multiply(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + enum CompareOperator { EQUAL, NOT_EQUAL, diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index de420017ec5..f4e2a8f9ca8 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -337,25 +337,6 @@ void ScalarPrimitiveExecUnary(KernelContext* ctx, const ExecBatch& batch, Datum* } } -template -void ScalarPrimitiveExecBinary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - using OUT = typename OutType::c_type; - using ARG0 = typename Arg0Type::c_type; - using ARG1 = typename Arg1Type::c_type; - - if (batch[0].kind() == Datum::SCALAR || batch[1].kind() == Datum::SCALAR) { - ctx->SetStatus(Status::NotImplemented("NYI")); - } else { - ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues(1); - auto arg0_data = batch[0].array()->GetValues(1); - auto arg1_data = batch[1].array()->GetValues(1); - for (int64_t i = 0; i < batch.length; ++i) { - *out_data++ = Op::template Call(ctx, *arg0_data++, *arg1_data++); - } - } -} - // OutputAdapter allows passing an inlineable lambda that provides a sequence // of output values to write into output memory. Boolean and primitive outputs // are currently implemented, and the validity bitmap is presumed to be handled @@ -610,53 +591,56 @@ struct ScalarUnaryNotNull { // // implementation // } // }; -template +template struct ScalarBinary { using OUT = typename GetOutputType::T; using ARG0 = typename GetViewType::T; using ARG1 = typename GetViewType::T; - template static void ArrayArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ArrayIterator arg0(*batch[0].array()); ArrayIterator arg1(*batch[1].array()); OutputAdapter::Write(ctx, out, [&]() -> OUT { - return ChosenOp::template Call(ctx, arg0(), arg1()); + return Op::template Call(ctx, arg0(), arg1()); }); } - template static void ArrayScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ArrayIterator arg0(*batch[0].array()); auto arg1 = UnboxScalar::Unbox(batch[1]); OutputAdapter::Write(ctx, out, [&]() -> OUT { - return ChosenOp::template Call(ctx, arg0(), arg1); + return Op::template Call(ctx, arg0(), arg1); + }); + } + + static void ScalarArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + auto arg0 = UnboxScalar::Unbox(batch[0]); + ArrayIterator arg1(*batch[1].array()); + OutputAdapter::Write(ctx, out, [&]() -> OUT { + return Op::template Call(ctx, arg0, arg1()); }); } - template static void ScalarScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { auto arg0 = UnboxScalar::Unbox(batch[0]); auto arg1 = UnboxScalar::Unbox(batch[1]); - out->value = BoxScalar::Box(ChosenOp::template Call(ctx, arg0, arg1), + out->value = BoxScalar::Box(Op::template Call(ctx, arg0, arg1), out->type()); } static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { if (batch[1].kind() == Datum::ARRAY) { - return ArrayArray(ctx, batch, out); + return ArrayArray(ctx, batch, out); } else { - return ArrayScalar(ctx, batch, out); + return ArrayScalar(ctx, batch, out); } } else { if (batch[1].kind() == Datum::ARRAY) { // e.g. if we were doing scalar < array, we flip and do array >= scalar - return BinaryExecFlipped(ctx, ArrayScalar, batch, out); + return ScalarArray(ctx, batch, out); } else { - return ScalarScalar(ctx, batch, out); + return ScalarScalar(ctx, batch, out); } } } @@ -664,9 +648,8 @@ struct ScalarBinary { // A kernel exec generator for binary kernels where both input types are the // same -template -using ScalarBinaryEqualTypes = ScalarBinary; +template +using ScalarBinaryEqualTypes = ScalarBinary; // ---------------------------------------------------------------------- // Dynamic kernel selectors. These functors allow a kernel implementation to be @@ -726,43 +709,6 @@ ArrayKernelExec NumericEqualTypesUnary(detail::GetTypeId get_id) { } } -// Generate a kernel given a functor of type -// -// struct OPERATOR_NAME { -// template -// static OUT Call(KernelContext*, ARG0 left, ARG1 right) { -// // IMPLEMENTATION -// } -// }; -template -ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return ScalarPrimitiveExecBinary; - case Type::UINT8: - return ScalarPrimitiveExecBinary; - case Type::INT16: - return ScalarPrimitiveExecBinary; - case Type::UINT16: - return ScalarPrimitiveExecBinary; - case Type::INT32: - return ScalarPrimitiveExecBinary; - case Type::UINT32: - return ScalarPrimitiveExecBinary; - case Type::INT64: - return ScalarPrimitiveExecBinary; - case Type::UINT64: - return ScalarPrimitiveExecBinary; - case Type::FLOAT: - return ScalarPrimitiveExecBinary; - case Type::DOUBLE: - return ScalarPrimitiveExecBinary; - default: - DCHECK(false); - return ExecFail; - } -} - // Generate a kernel given a templated functor. This template effectively // "curries" the first type argument. The functor must be of the form: // diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 57129bcb6a4..de05900f7c1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -16,25 +16,150 @@ // under the License. #include "arrow/compute/kernels/common.h" +#include "arrow/util/int_util.h" namespace arrow { namespace compute { +template +using is_unsigned_integer = std::integral_constant::value && + std::is_unsigned::value>; + +template +using is_signed_integer = + std::integral_constant::value && std::is_signed::value>; + +template +using enable_if_signed_integer = enable_if_t::value, T>; + +template +using enable_if_unsigned_integer = enable_if_t::value, T>; + +template +using enable_if_floating_point = enable_if_t::value, T>; + +template ::type> +constexpr Unsigned to_unsigned(T signed_) { + return static_cast(signed_); +} + struct Add { - template - static constexpr OUT Call(KernelContext*, ARG0 left, ARG1 right) { + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left + right; + } + + template + static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right) { return left + right; } + + template + static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right) { + return to_unsigned(left) + to_unsigned(right); + } +}; + +struct Subtract { + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left - right; + } + + template + static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right) { + return left - right; + } + + template + static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right) { + return to_unsigned(left) - to_unsigned(right); + } +}; + +struct Multiply { + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + + static_assert(std::is_same::value, ""); + + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + + template + static constexpr enable_if_floating_point Call(KernelContext*, T left, T right) { + return left * right; + } + + template + static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right) { + return left * right; + } + + template + static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right) { + return to_unsigned(left) * to_unsigned(right); + } + + // Multiplication of 16 bit integer types implicitly promotes to signed 32 bit + // integer. However, some inputs may nevertheless overflow (which triggers undefined + // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is + // well defined. + template + static constexpr int16_t Call(KernelContext*, int16_t left, int16_t right) { + return static_cast(left) * static_cast(right); + } + template + static constexpr uint16_t Call(KernelContext*, uint16_t left, uint16_t right) { + return static_cast(left) * static_cast(right); + } }; namespace codegen { +// Generate a kernel given an arithmetic functor +// +// To avoid undefined behaviour of signed integer overflow treat the signed +// input argument values as unsigned then cast them to signed making them wrap +// around. +template +ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) { + switch (get_id.id) { + case Type::INT8: + return ScalarBinaryEqualTypes::Exec; + case Type::UINT8: + return ScalarBinaryEqualTypes::Exec; + case Type::INT16: + return ScalarBinaryEqualTypes::Exec; + case Type::UINT16: + return ScalarBinaryEqualTypes::Exec; + case Type::INT32: + return ScalarBinaryEqualTypes::Exec; + case Type::UINT32: + return ScalarBinaryEqualTypes::Exec; + case Type::INT64: + return ScalarBinaryEqualTypes::Exec; + case Type::UINT64: + return ScalarBinaryEqualTypes::Exec; + case Type::FLOAT: + return ScalarBinaryEqualTypes::Exec; + case Type::DOUBLE: + return ScalarBinaryEqualTypes::Exec; + default: + DCHECK(false); + return ExecFail; + } +} + template -void MakeBinaryFunction(std::string name, FunctionRegistry* registry) { +void AddBinaryFunction(std::string name, FunctionRegistry* registry) { auto func = std::make_shared(name, Arity::Binary()); - for (const std::shared_ptr& ty : NumericTypes()) { - DCHECK_OK(func->AddKernel({InputType::Array(ty), InputType::Array(ty)}, ty, - NumericEqualTypesBinary(*ty))); + for (const auto& ty : NumericTypes()) { + auto exec = codegen::NumericEqualTypesBinary(ty); + DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -44,7 +169,9 @@ void MakeBinaryFunction(std::string name, FunctionRegistry* registry) { namespace internal { void RegisterScalarArithmetic(FunctionRegistry* registry) { - codegen::MakeBinaryFunction("add", registry); + codegen::AddBinaryFunction("add", registry); + codegen::AddBinaryFunction("subtract", registry); + codegen::AddBinaryFunction("multiply", registry); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 017c9f5f034..2f2159ef642 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -29,6 +29,7 @@ #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/string.h" #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" @@ -38,67 +39,255 @@ namespace arrow { namespace compute { template -class TestArithmeticKernel : public TestBase { - private: - void AssertAddArrays(const std::shared_ptr lhs, const std::shared_ptr rhs, - const std::shared_ptr expected) { - ASSERT_OK_AND_ASSIGN(Datum out, arrow::compute::Add(lhs, rhs)); - std::shared_ptr actual = out.make_array(); - ASSERT_OK(actual->ValidateFull()); - AssertArraysEqual(*expected, *actual); +class TestBinaryArithmetics : public TestBase { + protected: + using CType = typename ArrowType::c_type; + + static std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); } - protected: - virtual void AssertAdd(const std::string& lhs, const std::string& rhs, - const std::string& expected) { - auto type = TypeTraits::type_singleton(); - AssertAddArrays(ArrayFromJSON(type, lhs), ArrayFromJSON(type, rhs), - ArrayFromJSON(type, expected)); + using BinaryFunction = + std::function(const Datum&, const Datum&, ExecContext*)>; + + // (Scalar, Scalar) + void AssertBinop(BinaryFunction func, CType lhs, CType rhs, CType expected) { + ASSERT_OK_AND_ASSIGN(auto left, MakeScalar(type_singleton(), lhs)); + ASSERT_OK_AND_ASSIGN(auto right, MakeScalar(type_singleton(), rhs)); + ASSERT_OK_AND_ASSIGN(auto exp, MakeScalar(type_singleton(), expected)); + + ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, nullptr)); + AssertScalarsEqual(*exp, *actual.scalar(), true); + } + + // (Scalar, Array) + void AssertBinop(BinaryFunction func, CType lhs, const std::string& rhs, + const std::string& expected) { + ASSERT_OK_AND_ASSIGN(auto left, MakeScalar(type_singleton(), lhs)); + auto right = ArrayFromJSON(type_singleton(), rhs); + auto exp = ArrayFromJSON(type_singleton(), expected); + + ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, nullptr)); + ValidateAndAssertApproxEqual(actual.make_array(), expected); + } + + // (Array, Array) + void AssertBinop(BinaryFunction func, const std::string& lhs, const std::string& rhs, + const std::string& expected) { + auto left = ArrayFromJSON(type_singleton(), lhs); + auto right = ArrayFromJSON(type_singleton(), rhs); + + ASSERT_OK_AND_ASSIGN(Datum actual, func(left, right, nullptr)); + ValidateAndAssertApproxEqual(actual.make_array(), expected); + } + + void ValidateAndAssertApproxEqual(std::shared_ptr actual, + const std::string& expected) { + auto exp = ArrayFromJSON(type_singleton(), expected); + ASSERT_OK(actual->ValidateFull()); + AssertArraysApproxEqual(*exp, *actual); } }; -template -class TestArithmeticKernelFloating : public TestArithmeticKernel {}; -TYPED_TEST_SUITE(TestArithmeticKernelFloating, RealArrowTypes); +template +std::string MakeArray(Elements... elements) { + std::vector elements_as_strings = {std::to_string(elements)...}; -template -class TestArithmeticKernelIntegral : public TestArithmeticKernel {}; -TYPED_TEST_SUITE(TestArithmeticKernelIntegral, IntegralArrowTypes); + std::vector elements_as_views(sizeof...(Elements)); + std::copy(elements_as_strings.begin(), elements_as_strings.end(), + elements_as_views.begin()); + + return "[" + internal::JoinStrings(elements_as_views, ",") + "]"; +} + +template +class TestBinaryArithmeticsIntegral : public TestBinaryArithmetics {}; + +template +class TestBinaryArithmeticsSigned : public TestBinaryArithmeticsIntegral {}; + +template +class TestBinaryArithmeticsUnsigned : public TestBinaryArithmeticsIntegral {}; + +template +class TestBinaryArithmeticsFloating : public TestBinaryArithmetics {}; + +// InputType - OutputType pairs +using IntegralTypes = testing::Types; + +using SignedIntegerTypes = testing::Types; + +using UnsignedIntegerTypes = + testing::Types; + +// TODO(kszucs): add half-float +using FloatingTypes = testing::Types; + +TYPED_TEST_SUITE(TestBinaryArithmeticsIntegral, IntegralTypes); +TYPED_TEST_SUITE(TestBinaryArithmeticsSigned, SignedIntegerTypes); +TYPED_TEST_SUITE(TestBinaryArithmeticsUnsigned, UnsignedIntegerTypes); +TYPED_TEST_SUITE(TestBinaryArithmeticsFloating, FloatingTypes); + +TYPED_TEST(TestBinaryArithmeticsIntegral, Add) { + this->AssertBinop(arrow::compute::Add, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Add, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Add, "[3, 2, 6]", "[1, 0, 2]", "[4, 2, 8]"); + + this->AssertBinop(arrow::compute::Add, "[1, 2, 3, 4, 5, 6, 7]", "[0, 1, 2, 3, 4, 5, 6]", + "[1, 3, 5, 7, 9, 11, 13]"); + + this->AssertBinop(arrow::compute::Add, "[10, 12, 4, 50, 50, 32, 11]", + "[2, 0, 6, 1, 5, 3, 4]", "[12, 12, 10, 51, 55, 35, 15]"); + + this->AssertBinop(arrow::compute::Add, "[null, 1, 3, null, 2, 5]", "[1, 4, 2, 5, 0, 3]", + "[null, 5, 5, null, 2, 8]"); + + this->AssertBinop(arrow::compute::Add, 10, "[null, 1, 3, null, 2, 5]", + "[null, 11, 13, null, 12, 15]"); + + this->AssertBinop(arrow::compute::Add, 17, 42, 59); +} -TYPED_TEST(TestArithmeticKernelFloating, Add) { - this->AssertAdd("[]", "[]", "[]"); +TYPED_TEST(TestBinaryArithmeticsIntegral, Sub) { + this->AssertBinop(arrow::compute::Subtract, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Subtract, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Subtract, "[3, 2, 6]", "[1, 0, 2]", "[2, 2, 4]"); - this->AssertAdd("[3.4, 2.6, 6.3]", "[1, 0, 2]", "[4.4, 2.6, 8.3]"); + this->AssertBinop(arrow::compute::Subtract, "[1, 2, 3, 4, 5, 6, 7]", + "[0, 1, 2, 3, 4, 5, 6]", "[1, 1, 1, 1, 1, 1, 1]"); - this->AssertAdd("[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", "[0, 1, 2, 3, 4, 5, 6]", - "[1.1, 3.4, 5.5, 7.3, 9.1, 11.8, 13.3]"); + this->AssertBinop(arrow::compute::Subtract, 10, "[null, 1, 3, null, 2, 5]", + "[null, 9, 7, null, 8, 5]"); - this->AssertAdd("[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]", - "[13, 11, 9, 7, 5, 3, 1]"); + this->AssertBinop(arrow::compute::Subtract, 20, 9, 11); +} + +TYPED_TEST(TestBinaryArithmeticsIntegral, Mul) { + this->AssertBinop(arrow::compute::Multiply, "[]", "[]", "[]"); + this->AssertBinop(arrow::compute::Multiply, "[null]", "[null]", "[null]"); + this->AssertBinop(arrow::compute::Multiply, "[3, 2, 6]", "[1, 0, 2]", "[3, 0, 12]"); + + this->AssertBinop(arrow::compute::Multiply, "[1, 2, 3, 4, 5, 6, 7]", + "[0, 1, 2, 3, 4, 5, 6]", "[0, 2, 6, 12, 20, 30, 42]"); + + this->AssertBinop(arrow::compute::Multiply, "[7, 6, 5, 4, 3, 2, 1]", + "[6, 5, 4, 3, 2, 1, 0]", "[42, 30, 20, 12, 6, 2, 0]"); + + this->AssertBinop(arrow::compute::Multiply, "[null, 1, 3, null, 2, 5]", + "[1, 4, 2, 5, 0, 3]", "[null, 4, 6, null, 0, 15]"); + + this->AssertBinop(arrow::compute::Multiply, 3, "[null, 1, 3, null, 2, 5]", + "[null, 3, 9, null, 6, 15]"); + + this->AssertBinop(arrow::compute::Multiply, 6, 7, 42); +} + +TYPED_TEST(TestBinaryArithmeticsSigned, Add) { + this->AssertBinop(arrow::compute::Add, "[-7, 6, 5, 4, 3, 2, 1]", + "[-6, 5, -4, 3, -2, 1, 0]", "[-13, 11, 1, 7, 1, 3, 1]"); + this->AssertBinop(arrow::compute::Add, -1, "[-6, 5, -4, 3, -2, 1, 0]", + "[-7, 4, -5, 2, -3, 0, -1]"); + this->AssertBinop(arrow::compute::Add, -10, 5, -5); +} + +TYPED_TEST(TestBinaryArithmeticsSigned, OverflowWraps) { + using CType = typename TestFixture::CType; + + auto min = std::numeric_limits::lowest(); + auto max = std::numeric_limits::max(); - this->AssertAdd("[10.4, 12, 4.2, 50, 50.3, 32, 11]", "[2, 0, 6, 1, 5, 3, 4]", - "[12.4, 12, 10.2, 51, 55.3, 35, 15]"); + this->AssertBinop(arrow::compute::Add, MakeArray(min, max, max), + MakeArray(CType(-1), 1, max), MakeArray(max, min, CType(-2))); - this->AssertAdd("[null, 1, 3.3, null, 2, 5.3]", "[1, 4, 2, 5, 0, 3]", - "[null, 5, 5.3, null, 2, 8.3]"); + this->AssertBinop(arrow::compute::Subtract, MakeArray(min, max, min), + MakeArray(1, max, max), MakeArray(max, 0, 1)); + + this->AssertBinop(arrow::compute::Multiply, MakeArray(min, max, max), + MakeArray(max, 2, max), MakeArray(min, CType(-2), 1)); } -TYPED_TEST(TestArithmeticKernelIntegral, Add) { - this->AssertAdd("[]", "[]", "[]"); +TYPED_TEST(TestBinaryArithmeticsUnsigned, OverflowWraps) { + using CType = typename TestFixture::CType; + + auto min = std::numeric_limits::lowest(); + auto max = std::numeric_limits::max(); + + this->AssertBinop(arrow::compute::Add, MakeArray(min, max, max), + MakeArray(CType(-1), 1, max), MakeArray(max, min, CType(-2))); + + this->AssertBinop(arrow::compute::Subtract, MakeArray(min, max, min), + MakeArray(1, max, max), MakeArray(max, 0, 1)); + + this->AssertBinop(arrow::compute::Multiply, MakeArray(min, max, max), + MakeArray(max, 2, max), MakeArray(min, CType(-2), 1)); +} + +TYPED_TEST(TestBinaryArithmeticsSigned, Sub) { + this->AssertBinop(arrow::compute::Subtract, "[0, 1, 2, 3, 4, 5, 6]", + "[1, 2, 3, 4, 5, 6, 7]", "[-1, -1, -1, -1, -1, -1, -1]"); + + this->AssertBinop(arrow::compute::Subtract, "[0, 0, 0, 0, 0, 0, 0]", + "[6, 5, 4, 3, 2, 1, 0]", "[-6, -5, -4, -3, -2, -1, 0]"); + + this->AssertBinop(arrow::compute::Subtract, "[10, 12, 4, 50, 50, 32, 11]", + "[2, 0, 6, 1, 5, 3, 4]", "[8, 12, -2, 49, 45, 29, 7]"); + + this->AssertBinop(arrow::compute::Subtract, "[null, 1, 3, null, 2, 5]", + "[1, 4, 2, 5, 0, 3]", "[null, -3, 1, null, 2, 2]"); +} + +TYPED_TEST(TestBinaryArithmeticsSigned, Mul) { + this->AssertBinop(arrow::compute::Multiply, "[-10, 12, 4, 50, -5, 32, 11]", + "[-2, 0, -6, 1, 5, 3, 4]", "[20, 0, -24, 50, -25, 96, 44]"); + this->AssertBinop(arrow::compute::Multiply, -2, "[-10, 12, 4, 50, -5, 32, 11]", + "[20, -24, -8, -100, 10, -64, -22]"); + this->AssertBinop(arrow::compute::Multiply, -5, -5, 25); +} + +TYPED_TEST(TestBinaryArithmeticsFloating, Add) { + this->AssertBinop(arrow::compute::Add, "[]", "[]", "[]"); + + this->AssertBinop(arrow::compute::Add, "[3.4, 2.6, 6.3]", "[1, 0, 2]", + "[4.4, 2.6, 8.3]"); + + this->AssertBinop(arrow::compute::Add, "[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", + "[0, 1, 2, 3, 4, 5, 6]", "[1.1, 3.4, 5.5, 7.3, 9.1, 11.8, 13.3]"); + + this->AssertBinop(arrow::compute::Add, "[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]", + "[13, 11, 9, 7, 5, 3, 1]"); + + this->AssertBinop(arrow::compute::Add, "[10.4, 12, 4.2, 50, 50.3, 32, 11]", + "[2, 0, 6, 1, 5, 3, 4]", "[12.4, 12, 10.2, 51, 55.3, 35, 15]"); + + this->AssertBinop(arrow::compute::Add, "[null, 1, 3.3, null, 2, 5.3]", + "[1, 4, 2, 5, 0, 3]", "[null, 5, 5.3, null, 2, 8.3]"); + + this->AssertBinop(arrow::compute::Add, 1.1F, "[null, 1, 3.3, null, 2, 5.3]", + "[null, 2.1, 4.4, null, 3.1, 6.4]"); +} + +TYPED_TEST(TestBinaryArithmeticsFloating, Sub) { + this->AssertBinop(arrow::compute::Subtract, "[]", "[]", "[]"); + + this->AssertBinop(arrow::compute::Subtract, "[3.4, 2.6, 6.3]", "[1, 0, 2]", + "[2.4, 2.6, 4.3]"); - this->AssertAdd("[3, 2, 6]", "[1, 0, 2]", "[4, 2, 8]"); + this->AssertBinop(arrow::compute::Subtract, "[1.1, 2.4, 3.5, 4.3, 5.1, 6.8, 7.3]", + "[0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]", + "[1.0, 1.2, 1.2, 0.9, 0.6, 1.2, 0.6]"); - this->AssertAdd("[1, 2, 3, 4, 5, 6, 7]", "[0, 1, 2, 3, 4, 5, 6]", - "[1, 3, 5, 7, 9, 11, 13]"); + this->AssertBinop(arrow::compute::Subtract, "[7, 6, 5, 4, 3, 2, 1]", + "[6, 5, 4, 3, 2, 1, 0]", "[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]"); - this->AssertAdd("[7, 6, 5, 4, 3, 2, 1]", "[6, 5, 4, 3, 2, 1, 0]", - "[13, 11, 9, 7, 5, 3, 1]"); + this->AssertBinop(arrow::compute::Subtract, "[10.4, 12, 4.2, 50, 50.3, 32, 11]", + "[2, 0, 6, 1, 5, 3, 4]", "[8.4, 12, -1.8, 49, 45.3, 29, 7]"); - this->AssertAdd("[10, 12, 4, 50, 50, 32, 11]", "[2, 0, 6, 1, 5, 3, 4]", - "[12, 12, 10, 51, 55, 35, 15]"); + this->AssertBinop(arrow::compute::Subtract, "[null, 1, 3.3, null, 2, 5.3]", + "[1, 4, 2, 5, 0, 3]", "[null, -3, 1.3, null, 2, 2.3]"); - this->AssertAdd("[null, 1, 3, null, 2, 5]", "[1, 4, 2, 5, 0, 3]", - "[null, 5, 5, null, 2, 8]"); + this->AssertBinop(arrow::compute::Subtract, 0.1F, "[null, 1, 3.3, null, 2, 5.3]", + "[null, -0.9, -3.2, null, -1.9, -5.2]"); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 42a911a6e09..d782a3d41e7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -68,49 +68,45 @@ struct LessEqual { } }; -template +template void AddCompare(const std::shared_ptr& ty, ScalarFunction* func) { - ArrayKernelExec exec = - codegen::ScalarBinaryEqualTypes::Exec; + ArrayKernelExec exec = codegen::ScalarBinaryEqualTypes::Exec; DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec)); } -template +template void AddTimestampComparisons(ScalarFunction* func) { ArrayKernelExec exec = - codegen::ScalarBinaryEqualTypes::Exec; + codegen::ScalarBinaryEqualTypes::Exec; for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI, TimeUnit::MICRO, TimeUnit::NANO}) { InputType in_type(match::TimestampUnit(unit)); DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), exec)); } } -template +template void MakeCompareFunction(std::string name, FunctionRegistry* registry) { auto func = std::make_shared(name, Arity::Binary()); DCHECK_OK(func->AddKernel( {boolean(), boolean()}, boolean(), - codegen::ScalarBinary::Exec)); + codegen::ScalarBinary::Exec)); for (const std::shared_ptr& ty : NumericTypes()) { - auto exec = - codegen::Numeric( - *ty); + auto exec = codegen::Numeric(*ty); DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec)); } for (const std::shared_ptr& ty : BaseBinaryTypes()) { auto exec = - codegen::BaseBinary( - *ty); + codegen::BaseBinary(*ty); DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec)); } // Temporal types requires some care because cross-unit comparisons with // everything but DATE32 and DATE64 are not implemented yet - AddCompare(date32(), func.get()); - AddCompare(date64(), func.get()); - AddTimestampComparisons(func.get()); + AddCompare(date32(), func.get()); + AddCompare(date64(), func.get()); + AddTimestampComparisons(func.get()); // TODO: Leave time32, time64, and duration for follow up work @@ -120,10 +116,10 @@ void MakeCompareFunction(std::string name, FunctionRegistry* registry) { void RegisterScalarComparison(FunctionRegistry* registry) { MakeCompareFunction("equal", registry); MakeCompareFunction("not_equal", registry); - MakeCompareFunction("less", registry); - MakeCompareFunction("less_equal", registry); - MakeCompareFunction("greater", registry); - MakeCompareFunction("greater_equal", registry); + MakeCompareFunction("less", registry); + MakeCompareFunction("less_equal", registry); + MakeCompareFunction("greater", registry); + MakeCompareFunction("greater_equal", registry); } } // namespace internal diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index a91fc834a64..280a6dd56a2 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -82,6 +82,21 @@ void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose) } } +void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose) { + std::stringstream diff; + if (!expected.ApproxEquals(actual, EqualOptions().diff_sink(&diff))) { + if (verbose) { + ::arrow::PrettyPrintOptions options(/*indent=*/2); + options.window = 50; + diff << "Expected:\n"; + ARROW_EXPECT_OK(PrettyPrint(expected, options, &diff)); + diff << "\nActual:\n"; + ARROW_EXPECT_OK(PrettyPrint(actual, options, &diff)); + } + FAIL() << diff.str(); + } +} + void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose) { std::stringstream diff; // ARROW-8956, ScalarEquals returns false when both are null diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index d84db73fdd4..32c338ab538 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -41,13 +41,6 @@ #include "arrow/util/macros.h" #include "arrow/util/visibility.h" -namespace arrow { - -template -class Result; - -} // namespace arrow - // NOTE: failing must be inline in the macros below, to get correct file / line number // reporting on test failures. @@ -167,6 +160,8 @@ struct Datum; // If verbose is true, then the arrays will be pretty printed ARROW_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose = false); +ARROW_EXPORT void AssertArraysApproxEqual(const Array& expected, const Array& actual, + bool verbose = false); // Returns true when values are both null ARROW_EXPORT void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose = false);