Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ namespace compute {
// ----------------------------------------------------------------------
// Arithmetic

#define SCALAR_ARITHMETIC_UNARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
Result<Datum> NAME(const Datum& arg, ArithmeticOptions options, ExecContext* ctx) { \
auto func_name = (options.check_overflow) ? REGISTRY_CHECKED_NAME : REGISTRY_NAME; \
return CallFunction(func_name, {arg}, ctx); \
}

SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")

#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
ExecContext* ctx) { \
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,17 @@ Result<Datum> Divide(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Negate a value. Array values can be of arbitrary length. If argument
/// is null the result will be null.
///
/// \param[in] arg the value negated
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise negation
ARROW_EXPORT
Result<Datum> Negate(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Raise the values of base array to the power of the exponent array values.
/// Array values must be the same length. If either base or exponent is null the result
/// will be null.
Expand Down
118 changes: 107 additions & 11 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cmath>

#include "arrow/compute/kernels/common.h"
#include "arrow/type_traits.h"
#include "arrow/util/int_util_internal.h"
#include "arrow/util/macros.h"

Expand All @@ -26,13 +27,16 @@ namespace arrow {
using internal::AddWithOverflow;
using internal::DivideWithOverflow;
using internal::MultiplyWithOverflow;
using internal::NegateWithOverflow;
using internal::SubtractWithOverflow;

namespace compute {
namespace internal {

using applicator::ScalarBinaryEqualTypes;
using applicator::ScalarBinaryNotNullEqualTypes;
using applicator::ScalarUnary;
using applicator::ScalarUnaryNotNull;

namespace {

Expand Down Expand Up @@ -246,6 +250,49 @@ struct DivideChecked {
}
};

struct Negate {
template <typename T, typename Arg>
static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status*) {
return -arg;
}

template <typename T, typename Arg>
static constexpr enable_if_unsigned_integer<T> Call(KernelContext*, Arg arg, Status*) {
return ~arg + 1;
}

template <typename T, typename Arg>
static constexpr enable_if_signed_integer<T> Call(KernelContext*, Arg arg, Status* st) {
return arrow::internal::SafeSignedNegate(arg);
}
};

struct NegateChecked {
template <typename T, typename Arg>
static enable_if_signed_integer<T> Call(KernelContext*, Arg arg, Status* st) {
static_assert(std::is_same<T, Arg>::value, "");
T result = 0;
if (ARROW_PREDICT_FALSE(NegateWithOverflow(arg, &result))) {
*st = Status::Invalid("overflow");
}
return result;
}

template <typename T, typename Arg>
static enable_if_unsigned_integer<T> Call(KernelContext* ctx, Arg arg, Status* st) {
static_assert(std::is_same<T, Arg>::value, "");
DCHECK(false) << "This is included only for the purposes of instantiability from the "
"arithmetic kernel generator";
return 0;
}

template <typename T, typename Arg>
static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg arg, Status* st) {
static_assert(std::is_same<T, Arg>::value, "");
return -arg;
}
};

struct Power {
ARROW_NOINLINE
static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
Expand Down Expand Up @@ -310,7 +357,7 @@ struct PowerChecked {

// Generate a kernel given an arithmetic functor
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) {
ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::INT8:
return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
Expand Down Expand Up @@ -349,10 +396,14 @@ struct ArithmeticFunction : ScalarFunction {
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;

EnsureDictionaryDecoded(values);
ReplaceNullWithOtherType(values);

if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
// Only promote types for binary functions
if (values->size() == 2) {
ReplaceNullWithOtherType(values);

if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}
}

if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
Expand All @@ -365,7 +416,7 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
for (const auto& ty : NumericTypes()) {
auto exec = NumericEqualTypesBinary<ScalarBinaryEqualTypes, Op>(ty);
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Op>(ty);
DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
}
return func;
Expand All @@ -378,12 +429,38 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunctionNotNull(std::string name,
const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
for (const auto& ty : NumericTypes()) {
auto exec = NumericEqualTypesBinary<ScalarBinaryNotNullEqualTypes, Op>(ty);
auto exec = ArithmeticExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
}
return func;
}

template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunction(std::string name,
const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
for (const auto& ty : NumericTypes()) {
auto exec = ArithmeticExecFromOp<ScalarUnary, Op>(ty);
DCHECK_OK(func->AddKernel({ty}, ty, exec));
}
return func;
}

// Like MakeUnaryArithmeticFunction, but for signed arithmetic ops that need to run
// only on non-null output.
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnarySignedArithmeticFunctionNotNull(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
for (const auto& ty : NumericTypes()) {
if (!arrow::is_unsigned_integer(ty->id())) {
auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull, Op>(ty);
DCHECK_OK(func->AddKernel({ty}, ty, exec));
}
}
return func;
}

const FunctionDoc add_doc{"Add the arguments element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"add_checked\" if you want overflow\n"
Expand All @@ -396,14 +473,14 @@ const FunctionDoc add_checked_doc{
"doesn't fail on overflow, use function \"add\"."),
{"x", "y"}};

const FunctionDoc sub_doc{"Substract the arguments element-wise",
const FunctionDoc sub_doc{"Subtract the arguments element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"subtract_checked\" if you want overflow\n"
"to return an error."),
{"x", "y"}};

const FunctionDoc sub_checked_doc{
"Substract the arguments element-wise",
"Subtract the arguments element-wise",
("This function returns an error on overflow. For a variant that\n"
"doesn't fail on overflow, use function \"subtract\"."),
{"x", "y"}};
Expand Down Expand Up @@ -434,6 +511,18 @@ const FunctionDoc div_checked_doc{
"integer overflow is encountered."),
{"dividend", "divisor"}};

const FunctionDoc negate_doc{"Negate the argument element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"negate_checked\" if you want overflow\n"
"to return an error."),
{"x"}};

const FunctionDoc negate_checked_doc{
"Negate the arguments element-wise",
("This function returns an error on overflow. For a variant that\n"
"doesn't fail on overflow, use function \"negate\"."),
{"x"}};

const FunctionDoc pow_doc{
"Raise arguments to power element-wise",
("Integer to negative integer power returns an error. However, integer overflow\n"
Expand All @@ -445,7 +534,6 @@ const FunctionDoc pow_checked_doc{
("An error is returned when integer to negative integer power is encountered,\n"
"or integer overflow is encountered."),
{"base", "exponent"}};

} // namespace

void RegisterScalarArithmetic(FunctionRegistry* registry) {
Expand All @@ -465,8 +553,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// Add subtract(timestamp, timestamp) -> duration
for (auto unit : AllTimeUnits()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec =
NumericEqualTypesBinary<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

Expand Down Expand Up @@ -495,6 +582,15 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(divide_checked)));

// ----------------------------------------------------------------------
auto negate = MakeUnaryArithmeticFunction<Negate>("negate", &negate_doc);
DCHECK_OK(registry->AddFunction(std::move(negate)));

// ----------------------------------------------------------------------
auto negate_checked = MakeUnarySignedArithmeticFunctionNotNull<NegateChecked>(
"negate_checked", &negate_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(negate_checked)));

// ----------------------------------------------------------------------
auto power = MakeArithmeticFunction<Power>("power", &pow_doc);
DCHECK_OK(registry->AddFunction(std::move(power)));
Expand Down
Loading