Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,37 @@ struct ArithmeticFunction : ScalarFunction {
}
};

/// An ArithmeticFunction that promotes integer arguments to double.
struct ArithmeticFloatingPointFunction : public ArithmeticFunction {
using ArithmeticFunction::ArithmeticFunction;

Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
RETURN_NOT_OK(CheckArity(*values));
RETURN_NOT_OK(CheckDecimals(values));

using arrow::compute::detail::DispatchExactImpl;
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;

EnsureDictionaryDecoded(values);

if (values->size() == 2) {
ReplaceNullWithOtherType(values);
}

for (auto& descr : *values) {
if (is_integer(descr.type->id())) {
descr.type = float64();
}
}
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}

if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
return arrow::compute::detail::NoMatchingKernel(this, *values);
}
};

template <typename Op>
std::shared_ptr<ScalarFunction> MakeArithmeticFunction(std::string name,
const FunctionDoc* doc) {
Expand Down Expand Up @@ -1164,7 +1195,8 @@ std::shared_ptr<ScalarFunction> MakeShiftFunctionNotNull(std::string name,
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarUnary, Op>(ty);
Expand All @@ -1176,7 +1208,8 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPoint(
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Unary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarUnaryNotNull, Op>(ty);
Expand All @@ -1188,7 +1221,8 @@ std::shared_ptr<ScalarFunction> MakeUnaryArithmeticFunctionFloatingPointNotNull(
template <typename Op>
std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPoint(
std::string name, const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
auto func =
std::make_shared<ArithmeticFloatingPointFunction>(name, Arity::Binary(), doc);
for (const auto& ty : FloatingPointTypes()) {
auto output = is_integer(ty->id()) ? float64() : ty;
auto exec = GenerateArithmeticFloatingPoint<ScalarBinaryEqualTypes, Op>(ty);
Expand Down
94 changes: 93 additions & 1 deletion cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,26 @@ TEST(TestUnaryArithmetic, DispatchBest) {
for (std::string name : {"negate", "negate_checked", "abs", "abs_checked"}) {
CheckDispatchFails(name, {null()});
}

for (std::string name :
{"ln", "log2", "log10", "log1p", "sin", "cos", "tan", "asin", "acos"}) {
for (std::string suffix : {"", "_checked"}) {
name += suffix;

CheckDispatchBest(name, {int32()}, {float64()});
CheckDispatchBest(name, {uint8()}, {float64()});

CheckDispatchBest(name, {dictionary(int8(), int64())}, {float64()});
}
}

CheckDispatchBest("atan", {int32()}, {float64()});
CheckDispatchBest("atan2", {int32(), float64()}, {float64(), float64()});
CheckDispatchBest("atan2", {int32(), uint8()}, {float64(), float64()});
CheckDispatchBest("atan2", {int32(), null()}, {float64(), float64()});
CheckDispatchBest("atan2", {float32(), float64()}, {float64(), float64()});
// Integer always promotes to double
CheckDispatchBest("atan2", {float32(), int8()}, {float64(), float64()});
}

TYPED_TEST(TestUnaryArithmeticSigned, Negate) {
Expand Down Expand Up @@ -1821,9 +1841,41 @@ TYPED_TEST(TestBinaryArithmeticFloating, TrigAtan2) {
-M_PI_2, 0, M_PI));
}

TYPED_TEST(TestUnaryArithmeticIntegral, Trig) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
auto atan = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
return Atan(arg, ctx);
};
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Sin, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[0, 0.8414709848078965]"));
this->AssertUnaryOp(Cos, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[1, 0.5403023058681398]"));
this->AssertUnaryOp(Tan, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), "[0, 1.5574077246549023]"));
this->AssertUnaryOp(Asin, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
this->AssertUnaryOp(Acos, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(M_PI_2, 0)));
this->AssertUnaryOp(atan, ArrayFromJSON(ty, "[0, 1]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_4)));
}
}

TYPED_TEST(TestBinaryArithmeticIntegral, Trig) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
auto atan2 = [](const Datum& y, const Datum& x, ArithmeticOptions, ExecContext* ctx) {
return Atan2(y, x, ctx);
};
this->AssertBinop(atan2, ArrayFromJSON(ty, "[0, 1]"), ArrayFromJSON(ty, "[1, 0]"),
ArrayFromJSON(float64(), MakeArray(0, M_PI_2)));
}

TYPED_TEST(TestUnaryArithmeticFloating, Log) {
using CType = typename TestFixture::CType;
auto ty = this->type_singleton();
this->SetNansEqual(true);
auto min_val = std::numeric_limits<CType>::min();
auto max_val = std::numeric_limits<CType>::max();
Expand Down Expand Up @@ -1881,5 +1933,45 @@ TYPED_TEST(TestUnaryArithmeticFloating, Log) {
Log1p(lowest_val, this->options_));
}

TYPED_TEST(TestUnaryArithmeticIntegral, Log) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
for (auto check_overflow : {false, true}) {
this->SetOverflowCheck(check_overflow);
this->AssertUnaryOp(Ln, ArrayFromJSON(ty, "[1, null]"),
ArrayFromJSON(float64(), "[0, null]"));
this->AssertUnaryOp(Log10, ArrayFromJSON(ty, "[1, 10, null]"),
ArrayFromJSON(float64(), "[0, 1, null]"));
this->AssertUnaryOp(Log2, ArrayFromJSON(ty, "[1, 2, null]"),
ArrayFromJSON(float64(), "[0, 1, null]"));
this->AssertUnaryOp(Log1p, ArrayFromJSON(ty, "[0, null]"),
ArrayFromJSON(float64(), "[0, null]"));
}
}

TYPED_TEST(TestUnaryArithmeticSigned, Log) {
// Integer arguments promoted to double, sanity check here
auto ty = this->type_singleton();
this->SetNansEqual(true);
this->SetOverflowCheck(false);
this->AssertUnaryOp(Ln, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log10, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log2, ArrayFromJSON(ty, "[-1, 0]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->AssertUnaryOp(Log1p, ArrayFromJSON(ty, "[-2, -1]"),
ArrayFromJSON(float64(), "[NaN, -Inf]"));
this->SetOverflowCheck(true);
this->AssertUnaryOpRaises(Ln, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Ln, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log10, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Log10, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log2, "[0]", "logarithm of zero");
this->AssertUnaryOpRaises(Log2, "[-1]", "logarithm of negative number");
this->AssertUnaryOpRaises(Log1p, "[-1]", "logarithm of zero");
this->AssertUnaryOpRaises(Log1p, "[-2]", "logarithm of negative number");
}

} // namespace compute
} // namespace arrow