Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) {
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
args = {duration(TimeUnit::SECOND), time64(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)};
Expand All @@ -204,6 +206,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) {
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
args = {duration(TimeUnit::SECOND), int64()};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
}

TEST(TestDispatchBest, ReplaceTemporalTypes) {
Expand Down
157 changes: 64 additions & 93 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,12 @@ struct SubtractCheckedDate32 {
}
};

template <bool is_32bit, int64_t multiple>
template <int64_t multiple>
struct AddTimeDuration {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedAdd(left, static_cast<T>(right));
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedAdd(left, right);
static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
T result =
arrow::internal::SafeSignedAdd(static_cast<T>(left), static_cast<T>(right));
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
Expand All @@ -265,27 +254,13 @@ struct AddTimeDuration {
}
};

template <bool is_32bit, int64_t multiple>
template <int64_t multiple>
struct AddTimeDurationChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast<T>(right), &result))) {
if (ARROW_PREDICT_FALSE(
AddWithOverflow(static_cast<T>(left), static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
Expand All @@ -296,50 +271,23 @@ struct AddTimeDurationChecked {
}
};

template <bool is_32bit, int64_t multiple>
template <int64_t multiple>
struct SubtractTimeDuration {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
T result = arrow::internal::SafeSignedSubtract(left, static_cast<T>(right));
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedSubtract(left, right);
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}
};

template <bool is_32bit, int64_t multiple>
template <int64_t multiple>
struct SubtractTimeDurationChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
Expand Down Expand Up @@ -2269,34 +2217,58 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPointNotNull(
return func;
}

template <template <bool, int64_t> class Op>
void AddArithmeticFunctionTimeDurations(std::shared_ptr<ScalarFunction> func) {
template <template <int64_t> class Op>
void AddArithmeticFunctionTimeDuration(std::shared_ptr<ScalarFunction> func) {
// Add Op(time32, duration) -> time32
TimeUnit::type unit = TimeUnit::SECOND;
auto exec_1 = ScalarBinary<Time32Type, Time32Type, DurationType, Op<true, 86400>>::Exec;
auto exec_1 = ScalarBinary<Time32Type, Time32Type, DurationType, Op<86400>>::Exec;
DCHECK_OK(func->AddKernel({time32(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_1)));

unit = TimeUnit::MILLI;
auto exec_2 =
ScalarBinary<Time32Type, Time32Type, DurationType, Op<true, 86400000>>::Exec;
auto exec_2 = ScalarBinary<Time32Type, Time32Type, DurationType, Op<86400000>>::Exec;
DCHECK_OK(func->AddKernel({time32(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_2)));

// Add Op(time64, duration) -> time64
unit = TimeUnit::MICRO;
auto exec_3 =
ScalarBinary<Time64Type, Time64Type, DurationType, Op<false, 86400000000>>::Exec;
auto exec_3 = ScalarBinary<Time64Type, Time64Type, DurationType, Op<86400000000>>::Exec;
DCHECK_OK(func->AddKernel({time64(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_3)));

unit = TimeUnit::NANO;
auto exec_4 =
ScalarBinary<Time64Type, Time64Type, DurationType, Op<false, 86400000000000>>::Exec;
ScalarBinary<Time64Type, Time64Type, DurationType, Op<86400000000000>>::Exec;
DCHECK_OK(func->AddKernel({time64(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_4)));
}

template <template <int64_t> class Op>
void AddArithmeticFunctionDurationTime(std::shared_ptr<ScalarFunction> func) {
// Add Op(duration, time32) -> time32
TimeUnit::type unit = TimeUnit::SECOND;
auto exec_1 = ScalarBinary<Time32Type, DurationType, Time32Type, Op<86400>>::Exec;
DCHECK_OK(func->AddKernel({duration(unit), time32(unit)}, OutputType(LastType),
std::move(exec_1)));

unit = TimeUnit::MILLI;
auto exec_2 = ScalarBinary<Time32Type, DurationType, Time32Type, Op<86400000>>::Exec;
DCHECK_OK(func->AddKernel({duration(unit), time32(unit)}, OutputType(LastType),
std::move(exec_2)));

// Add Op(duration, time64) -> time64
unit = TimeUnit::MICRO;
auto exec_3 = ScalarBinary<Time64Type, DurationType, Time64Type, Op<86400000000>>::Exec;
DCHECK_OK(func->AddKernel({duration(unit), time64(unit)}, OutputType(LastType),
std::move(exec_3)));

unit = TimeUnit::NANO;
auto exec_4 =
ScalarBinary<Time64Type, DurationType, Time64Type, Op<86400000000000>>::Exec;
DCHECK_OK(func->AddKernel({duration(unit), time64(unit)}, OutputType(LastType),
std::move(exec_4)));
}

const FunctionDoc absolute_value_doc{
"Calculate the absolute value of the argument element-wise",
("Results will wrap around on integer overflow.\n"
Expand Down Expand Up @@ -2638,8 +2610,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ScalarBinary<Int64Type, Int64Type, Int64Type, Add>::Exec;
DCHECK_OK(add->AddKernel({in_type, duration(unit)}, OutputType(FirstType),
std::move(exec)));
DCHECK_OK(add->AddKernel({in_type, duration(unit)}, OutputType(FirstType), exec));
DCHECK_OK(add->AddKernel({duration(unit), in_type}, OutputType(LastType), exec));
}

// Add add(duration, duration) -> duration
Expand All @@ -2649,7 +2621,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<AddTimeDuration>(add);
AddArithmeticFunctionTimeDuration<AddTimeDuration>(add);
AddArithmeticFunctionDurationTime<AddTimeDuration>(add);

DCHECK_OK(registry->AddFunction(std::move(add)));

Expand All @@ -2662,8 +2635,10 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ScalarBinary<Int64Type, Int64Type, Int64Type, AddChecked>::Exec;
DCHECK_OK(add_checked->AddKernel({in_type, duration(unit)}, OutputType(FirstType),
std::move(exec)));
DCHECK_OK(
add_checked->AddKernel({in_type, duration(unit)}, OutputType(FirstType), exec));
DCHECK_OK(
add_checked->AddKernel({duration(unit), in_type}, OutputType(LastType), exec));
}

// Add add(duration, duration) -> duration
Expand All @@ -2674,7 +2649,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
add_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<AddTimeDurationChecked>(add_checked);
AddArithmeticFunctionTimeDuration<AddTimeDurationChecked>(add_checked);
AddArithmeticFunctionDurationTime<AddTimeDurationChecked>(add_checked);

DCHECK_OK(registry->AddFunction(std::move(add_checked)));

Expand Down Expand Up @@ -2732,7 +2708,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(subtract->AddKernel({in_type_date_64, in_type_date_64},
duration(TimeUnit::MILLI), std::move(exec_date_64)));

AddArithmeticFunctionTimeDurations<SubtractTimeDuration>(subtract);
AddArithmeticFunctionTimeDuration<SubtractTimeDuration>(subtract);

DCHECK_OK(registry->AddFunction(std::move(subtract)));

Expand Down Expand Up @@ -2798,7 +2774,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<SubtractTimeDurationChecked>(subtract_checked);
AddArithmeticFunctionTimeDuration<SubtractTimeDurationChecked>(subtract_checked);

DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));

Expand All @@ -2808,12 +2784,9 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {

// Add multiply(duration, int64) -> duration
for (auto unit : TimeUnit::values()) {
auto exec1 = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Multiply>(Type::DURATION);
DCHECK_OK(
multiply->AddKernel({duration(unit), int64()}, duration(unit), std::move(exec1)));
auto exec2 = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Multiply>(Type::DURATION);
DCHECK_OK(
multiply->AddKernel({int64(), duration(unit)}, duration(unit), std::move(exec2)));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Multiply>(Type::DURATION);
DCHECK_OK(multiply->AddKernel({duration(unit), int64()}, duration(unit), exec));
DCHECK_OK(multiply->AddKernel({int64(), duration(unit)}, duration(unit), exec));
}

DCHECK_OK(registry->AddFunction(std::move(multiply)));
Expand All @@ -2825,14 +2798,12 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {

// Add multiply_checked(duration, int64) -> duration
for (auto unit : TimeUnit::values()) {
auto exec1 =
ArithmeticExecFromOp<ScalarBinaryEqualTypes, MultiplyChecked>(Type::DURATION);
DCHECK_OK(multiply_checked->AddKernel({duration(unit), int64()}, duration(unit),
std::move(exec1)));
auto exec2 =
auto exec =
ArithmeticExecFromOp<ScalarBinaryEqualTypes, MultiplyChecked>(Type::DURATION);
DCHECK_OK(multiply_checked->AddKernel({int64(), duration(unit)}, duration(unit),
std::move(exec2)));
DCHECK_OK(
multiply_checked->AddKernel({duration(unit), int64()}, duration(unit), exec));
DCHECK_OK(
multiply_checked->AddKernel({int64(), duration(unit)}, duration(unit), exec));
}

DCHECK_OK(registry->AddFunction(std::move(multiply_checked)));
Expand Down
Loading