Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {

SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked")
SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
SCALAR_EAGER_UNARY(Sign, "sign")
SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked")
SCALAR_ARITHMETIC_UNARY(Cos, "cos", "cos_checked")
SCALAR_ARITHMETIC_UNARY(Asin, "asin", "asin_checked")
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,15 @@ Result<Datum> MinElementWise(
ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Get the sign of a value, element-wise.
///
/// Array values can be of arbitrary length. If argument is null the result
/// will be null. As described in the documentation of the "sign" function, the
/// result is -1 or 1 for nonzero inputs, 0 for zero inputs and NaN for NaN
/// inputs; integral inputs produce Int8 output and floating-point inputs
/// produce output of the same type as the input.
///
/// \param[in] arg the value to extract sign from
/// \param[in] ctx the function execution context, optional
/// \return the element-wise sign of the input
ARROW_EXPORT
Result<Datum> Sign(const Datum& arg, ExecContext* ctx = NULLPTR);

/// \brief Compare a numeric array with a scalar.
///
/// \param[in] left datum to compare, must be an Array
Expand Down
77 changes: 77 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#include <limits>
#include <utility>

#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/decimal.h"
#include "arrow/util/int_util_internal.h"
Expand Down Expand Up @@ -462,6 +464,23 @@ struct PowerChecked {
}
};

// Element-wise sign functor.
//
// Each overload maps positive inputs to 1, zero to 0 and negative inputs to -1.
// Which overload is used is decided by the template parameter T through the
// enable_if_* aliases. Every body is kept as a single return expression.
struct Sign {
  // Floating-point: NaN is returned unchanged. For every other value the two
  // comparisons yield 1, 0 or -1; +0.0 and -0.0 both compare equal to zero and
  // therefore give 0.
  template <typename T, typename Arg>
  static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
    return std::isnan(arg) ? arg : static_cast<Arg>((arg > 0) - (arg < 0));
  }

  // Unsigned integers: only 0 and 1 are possible.
  template <typename T, typename Arg>
  static constexpr enable_if_unsigned_integer<T> Call(KernelContext*, Arg arg, Status*) {
    return (arg > 0) ? 1 : 0;
  }

  // Signed integers: -1, 0 or 1.
  template <typename T, typename Arg>
  static constexpr enable_if_signed_integer<T> Call(KernelContext*, Arg arg, Status*) {
    return (arg < 0) ? -1 : ((arg > 0) ? 1 : 0);
  }
};

// Bitwise operations

struct BitWiseNot {
Expand Down Expand Up @@ -1033,6 +1052,37 @@ void AddDecimalBinaryKernels(const std::string& name,
DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256));
}

// Pick the exec function of a kernel generator for the given input type id.
//
// Integer inputs (and TIMESTAMP, which is handled as Int64Type) instantiate
// KernelGenerator with the caller-supplied OutType as first template argument.
// FLOAT and DOUBLE inputs ignore OutType and use their own type in that
// position instead. Any other type id triggers a DCHECK failure and yields
// ExecFail.
template <template <typename...> class KernelGenerator, typename OutType, typename Op>
ArrayKernelExec GenerateArithmeticWithFixedIntOutType(detail::GetTypeId get_id) {
  switch (get_id.id) {
    // Signed integers
    case Type::INT8:
      return KernelGenerator<OutType, Int8Type, Op>::Exec;
    case Type::INT16:
      return KernelGenerator<OutType, Int16Type, Op>::Exec;
    case Type::INT32:
      return KernelGenerator<OutType, Int32Type, Op>::Exec;
    case Type::INT64:
    case Type::TIMESTAMP:
      return KernelGenerator<OutType, Int64Type, Op>::Exec;

    // Unsigned integers
    case Type::UINT8:
      return KernelGenerator<OutType, UInt8Type, Op>::Exec;
    case Type::UINT16:
      return KernelGenerator<OutType, UInt16Type, Op>::Exec;
    case Type::UINT32:
      return KernelGenerator<OutType, UInt32Type, Op>::Exec;
    case Type::UINT64:
      return KernelGenerator<OutType, UInt64Type, Op>::Exec;

    // Floating point: output type follows the input type
    case Type::FLOAT:
      return KernelGenerator<FloatType, FloatType, Op>::Exec;
    case Type::DOUBLE:
      return KernelGenerator<DoubleType, DoubleType, Op>::Exec;

    default:
      DCHECK(false);
      return ExecFail;
  }
}

struct ArithmeticFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Expand Down Expand Up @@ -1142,6 +1192,21 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunction(std::string name,
return func;
}

// Like MakeUnaryArithmeticFunction, but for unary arithmetic ops with a fixed
// output type for integral inputs.
//
// One kernel is added for every type in NumericTypes(): floating-point inputs
// keep their own type as output type, all other inputs produce IntOutType.
//
// \param name the name given to the created function
// \param doc the documentation attached to the created function
// \return the function with all of its kernels added
template <typename Op, typename IntOutType>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionWithFixedIntOutType(
    std::string name, const FunctionDoc* doc) {
  const auto int_out_ty = TypeTraits<IntOutType>::type_singleton();
  // `name` is a by-value parameter that is not used again below, so hand it
  // over with std::move instead of copying the string a second time.
  auto func =
      std::make_shared<ArithmeticFunction>(std::move(name), Arity::Unary(), doc);
  for (const auto& ty : NumericTypes()) {
    auto out_ty = arrow::is_floating(ty->id()) ? ty : int_out_ty;
    auto exec = GenerateArithmeticWithFixedIntOutType<ScalarUnary, IntOutType, Op>(ty);
    DCHECK_OK(func->AddKernel({ty}, out_ty, exec));
  }
  return func;
}

// Like MakeUnaryArithmeticFunction, but for arithmetic ops that need to run
// only on non-null output.
template <typename Op>
Expand Down Expand Up @@ -1318,6 +1383,13 @@ const FunctionDoc pow_checked_doc{
"or integer overflow is encountered."),
{"base", "exponent"}};

// Documentation for the "sign" compute function; a pointer to this object is
// passed when that function is created in RegisterScalarArithmetic. The single
// argument is named "x".
const FunctionDoc sign_doc{
"Get the signedness of the arguments element-wise",
("Output is any of (-1,1) for nonzero inputs and 0 for zero input.\n"
"NaN values return NaN. Integral values return signedness as Int8 and\n"
"floating-point values return it with the same type as the input values."),
{"x"}};

// Documentation for an element-wise bit-wise negation taking one argument
// named "x".
const FunctionDoc bit_wise_not_doc{
"Bit-wise negate the arguments element-wise", ("Null values return null."), {"x"}};

Expand Down Expand Up @@ -1579,6 +1651,11 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(power_checked)));

// ----------------------------------------------------------------------
auto sign =
MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign", &sign_doc);
DCHECK_OK(registry->AddFunction(std::move(sign)));

// ----------------------------------------------------------------------
// Bitwise functions
{
Expand Down
102 changes: 89 additions & 13 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,42 +66,50 @@ class TestUnaryArithmetic : public TestBase {
return *arrow::MakeScalar(type_singleton(), value);
}

// (Scalar)
// (Scalar, Scalar)
void AssertUnaryOp(UnaryFunction func, CType argument, CType expected) {
auto arg = MakeScalar(argument);
auto exp = MakeScalar(expected);
ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true);
}

// (Scalar)
// (Scalar, Scalar)
void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Scalar>& arg,
const std::shared_ptr<Scalar>& expected) {
ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
AssertScalarsApproxEqual(*expected, *actual.scalar(), /*verbose=*/true);
}

// (Array)
void AssertUnaryOp(UnaryFunction func, const std::string& argument,
const std::string& expected) {
auto arg = ArrayFromJSON(type_singleton(), argument);
// (JSON, JSON)
void AssertUnaryOp(UnaryFunction func, const std::string& arg_json,
const std::string& expected_json) {
auto arg = ArrayFromJSON(type_singleton(), arg_json);
auto expected = ArrayFromJSON(type_singleton(), expected_json);
AssertUnaryOp(func, arg, expected);
}

// (Array)
// (Array, JSON)
void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
const std::string& expected_json) {
const auto expected = ArrayFromJSON(type_singleton(), expected_json);
return AssertUnaryOp(func, arg, expected);
AssertUnaryOp(func, arg, expected);
}

// (Array)
// (JSON, Array)
void AssertUnaryOp(UnaryFunction func, const std::string& arg_json,
const std::shared_ptr<Array>& expected) {
auto arg = ArrayFromJSON(type_singleton(), arg_json);
AssertUnaryOp(func, arg, expected);
}

// (Array, Array)
void AssertUnaryOp(UnaryFunction func, const std::shared_ptr<Array>& arg,
const std::shared_ptr<Array>& expected) {
ASSERT_OK_AND_ASSIGN(Datum actual, func(arg, options_, nullptr));
ASSERT_OK_AND_ASSIGN(auto actual, func(arg, options_, nullptr));
ValidateAndAssertApproxEqual(actual.make_array(), expected);

// Also check (Scalar) operations
// Also check (Scalar, Scalar) operations
const int64_t length = expected->length();
for (int64_t i = 0; i < length; ++i) {
const auto expected_scalar = *expected->GetScalar(i);
Expand Down Expand Up @@ -1024,22 +1032,25 @@ TEST(TestBinaryArithmetic, AddWithImplicitCastsUint64EdgeCase) {
}

TEST(TestUnaryArithmetic, DispatchBest) {
for (std::string name : {"negate", "abs", "abs_checked"}) {
// All arithmetic
for (std::string name : {"negate", "abs", "abs_checked", "sign"}) {
for (const auto& ty : {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(),
uint64(), float32(), float64()}) {
CheckDispatchBest(name, {ty}, {ty});
CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
}
}

// Signed arithmetic
for (std::string name : {"negate_checked"}) {
for (const auto& ty : {int8(), int16(), int32(), int64(), float32(), float64()}) {
CheckDispatchBest(name, {ty}, {ty});
CheckDispatchBest(name, {dictionary(int8(), ty)}, {ty});
}
}

for (std::string name : {"negate", "negate_checked", "abs", "abs_checked"}) {
// Null input
for (std::string name : {"negate", "negate_checked", "abs", "abs_checked", "sign"}) {
CheckDispatchFails(name, {null()});
}

Expand Down Expand Up @@ -1973,5 +1984,70 @@ TYPED_TEST(TestUnaryArithmeticSigned, Log) {
this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
}

// Sign() on signed integer inputs: results are always int8 values in {-1, 0, 1},
// nulls stay null, and the extreme values of the input type are covered.
TYPED_TEST(TestUnaryArithmeticSigned, Sign) {
  using CType = typename TestFixture::CType;
  auto lowest = std::numeric_limits<CType>::min();
  auto highest = std::numeric_limits<CType>::max();

  // The fixture's AssertUnaryOp helpers take a callable that also accepts
  // ArithmeticOptions. Sign() has no options, so wrap it and drop that argument.
  auto sign_fn = [](const Datum& input, ArithmeticOptions, ExecContext* exec_ctx) {
    return Sign(input, exec_ctx);
  };
  // Expected values are int8 arrays whatever the input type under test is.
  auto int8_array = [](const std::string& json) { return ArrayFromJSON(int8(), json); };

  this->AssertUnaryOp(sign_fn, "[]", int8_array("[]"));
  this->AssertUnaryOp(sign_fn, "[null]", int8_array("[null]"));
  this->AssertUnaryOp(sign_fn, "[1, null, -10]", int8_array("[1, null, -1]"));
  this->AssertUnaryOp(sign_fn, "[0]", int8_array("[0]"));
  this->AssertUnaryOp(sign_fn, "[1, 10, 127]", int8_array("[1, 1, 1]"));
  this->AssertUnaryOp(sign_fn, "[-1, -10, -127]", int8_array("[-1, -1, -1]"));

  // Extremes of the input type.
  this->AssertUnaryOp(sign_fn, this->MakeScalar(lowest), *arrow::MakeScalar(int8(), -1));
  this->AssertUnaryOp(sign_fn, this->MakeScalar(highest), *arrow::MakeScalar(int8(), 1));
}

// Sign() on unsigned integer inputs: results are int8 values in {0, 1}, nulls
// stay null, and the extreme values of the input type are covered.
TYPED_TEST(TestUnaryArithmeticUnsigned, Sign) {
  using CType = typename TestFixture::CType;
  auto lowest = std::numeric_limits<CType>::min();
  auto highest = std::numeric_limits<CType>::max();

  // The fixture's AssertUnaryOp helpers take a callable that also accepts
  // ArithmeticOptions. Sign() has no options, so wrap it and drop that argument.
  auto sign_fn = [](const Datum& input, ArithmeticOptions, ExecContext* exec_ctx) {
    return Sign(input, exec_ctx);
  };
  // Expected values are int8 arrays whatever the input type under test is.
  auto int8_array = [](const std::string& json) { return ArrayFromJSON(int8(), json); };

  this->AssertUnaryOp(sign_fn, "[]", int8_array("[]"));
  this->AssertUnaryOp(sign_fn, "[null]", int8_array("[null]"));
  this->AssertUnaryOp(sign_fn, "[1, null, 10]", int8_array("[1, null, 1]"));
  this->AssertUnaryOp(sign_fn, "[0]", int8_array("[0]"));
  this->AssertUnaryOp(sign_fn, "[1, 10, 127]", int8_array("[1, 1, 1]"));

  // Extremes of the input type: the minimum of an unsigned type is zero.
  this->AssertUnaryOp(sign_fn, this->MakeScalar(lowest), *arrow::MakeScalar(int8(), 0));
  this->AssertUnaryOp(sign_fn, this->MakeScalar(highest), *arrow::MakeScalar(int8(), 1));
}

// Sign() on floating-point inputs: results have the same type as the input,
// NaN maps to NaN, both zeros map to 0 and infinities map to +/-1.
TYPED_TEST(TestUnaryArithmeticFloating, Sign) {
  using CType = typename TestFixture::CType;
  auto most_negative = std::numeric_limits<CType>::lowest();
  auto most_positive = std::numeric_limits<CType>::max();

  // Needed so that the NaN expectation below can compare equal.
  this->SetNansEqual(true);

  // The fixture's AssertUnaryOp helpers take a callable that also accepts
  // ArithmeticOptions. Sign() has no options, so wrap it and drop that argument.
  auto sign_fn = [](const Datum& input, ArithmeticOptions, ExecContext* exec_ctx) {
    return Sign(input, exec_ctx);
  };

  // Empty and null inputs
  this->AssertUnaryOp(sign_fn, "[]", "[]");
  this->AssertUnaryOp(sign_fn, "[null]", "[null]");

  // Ordinary values, including both zeros
  this->AssertUnaryOp(sign_fn, "[1.3, null, -10.80]", "[1, null, -1]");
  this->AssertUnaryOp(sign_fn, "[0.0, -0.0]", "[0, 0]");
  this->AssertUnaryOp(sign_fn, "[1.3, 10.80, 12748.001]", "[1, 1, 1]");
  this->AssertUnaryOp(sign_fn, "[-1.3, -10.80, -12748.001]", "[-1, -1, -1]");

  // Special values
  this->AssertUnaryOp(sign_fn, "[Inf, -Inf]", "[1, -1]");
  this->AssertUnaryOp(sign_fn, "[NaN]", "[NaN]");

  // Extremes of the input type
  this->AssertUnaryOp(sign_fn, this->MakeScalar(most_negative), this->MakeScalar(-1));
  this->AssertUnaryOp(sign_fn, this->MakeScalar(most_positive), this->MakeScalar(1));
}
} // namespace compute
} // namespace arrow
70 changes: 38 additions & 32 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -258,41 +258,43 @@ Input(s) will be cast to the :ref:`common numeric type <common-numeric-type>`
(and dictionary decoded, if applicable) before the operation is applied.

The default variant of these functions does not detect overflow (the result
then typically wraps around). Each function is also available in an
then typically wraps around). Most functions are also available in an
overflow-checking variant, suffixed ``_checked``, which returns
an ``Invalid`` :class:`Status` when overflow is detected.

+------------------+--------+----------------+----------------+-------+
| Function name | Arity | Input types | Output type | Notes |
+==================+========+================+================+=======+
| abs | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------+-------+
| abs_checked | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------+-------+
| add | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| add_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| divide | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| divide_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| multiply | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| multiply_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| negate | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------+-------+
| negate_checked | Unary | Signed Numeric | Signed Numeric | |
+------------------+--------+----------------+----------------+-------+
| power | Binary | Numeric | Numeric | |
+------------------+--------+----------------+----------------+-------+
| power_checked | Binary | Numeric | Numeric | |
+------------------+--------+----------------+----------------+-------+
| subtract | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
| subtract_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------+-------+
+------------------+--------+----------------+----------------------+-------+
| Function name | Arity | Input types | Output type | Notes |
+==================+========+================+======================+=======+
| abs | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------------+-------+
| abs_checked | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------------+-------+
| add | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| add_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| divide | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| divide_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| multiply | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| multiply_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| negate | Unary | Numeric | Numeric | |
+------------------+--------+----------------+----------------------+-------+
| negate_checked | Unary | Signed Numeric | Signed Numeric | |
+------------------+--------+----------------+----------------------+-------+
| power | Binary | Numeric | Numeric | |
+------------------+--------+----------------+----------------------+-------+
| power_checked | Binary | Numeric | Numeric | |
+------------------+--------+----------------+----------------------+-------+
| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) |
+------------------+--------+----------------+----------------------+-------+
| subtract | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+
| subtract_checked | Binary | Numeric | Numeric | \(1) |
+------------------+--------+----------------+----------------------+-------+

* \(1) Precision and scale of computed DECIMAL results

Expand All @@ -315,6 +317,10 @@ an ``Invalid`` :class:`Status` when overflow is detected.
enough scale kept. Error is returned if the result precision is beyond the
decimal value range.

* \(2) Output is -1 or 1 for nonzero inputs and 0 for zero inputs.
NaN values return NaN. Integral inputs return the sign as Int8, and
floating-point inputs return it with the same type as the input values.

Bit-wise functions
~~~~~~~~~~~~~~~~~~

Expand Down
Loading