Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,62 @@ struct SubtractCheckedDate32 {
}
};

template <bool is_32bit, int64_t multiple>
struct AddTimeDuration {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedAdd(left, static_cast<T>(right));
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedAdd(left, right);
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}
};

template <bool is_32bit, int64_t multiple>
struct AddTimeDurationChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}
};

template <bool is_32bit, int64_t multiple>
struct SubtractTimeDuration {
template <typename T, typename Arg0, typename Arg1>
Expand Down Expand Up @@ -2553,6 +2609,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<AddTimeDuration>(add);

DCHECK_OK(registry->AddFunction(std::move(add)));

// ----------------------------------------------------------------------
Expand All @@ -2577,6 +2635,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
add_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<AddTimeDurationChecked>(add_checked);

DCHECK_OK(registry->AddFunction(std::move(add_checked)));

// ----------------------------------------------------------------------
Expand Down
59 changes: 59 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,65 @@ TEST_F(ScalarTemporalTest, TestTemporalSubtractTime) {
}
}

TEST_F(ScalarTemporalTest, TestTemporalAddTimeAndDuration) {
for (auto op : {"add", "add_checked"}) {
auto arr_s = ArrayFromJSON(time32(TimeUnit::SECOND), times_s);
auto arr_s2 = ArrayFromJSON(time32(TimeUnit::SECOND), times_s2);
auto arr_ms = ArrayFromJSON(time32(TimeUnit::MILLI), times_ms);
auto arr_ms2 = ArrayFromJSON(time32(TimeUnit::MILLI), times_ms2);
auto arr_us = ArrayFromJSON(time64(TimeUnit::MICRO), times_us);
auto arr_us2 = ArrayFromJSON(time64(TimeUnit::MICRO), times_us2);
auto arr_ns = ArrayFromJSON(time64(TimeUnit::NANO), times_ns);
auto arr_ns2 = ArrayFromJSON(time64(TimeUnit::NANO), times_ns2);

CheckScalarBinary(op, arr_s,
ArrayFromJSON(duration(TimeUnit::SECOND), seconds_between_time),
arr_s2);
CheckScalarBinary(op, arr_ms,
ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time),
arr_ms2);
CheckScalarBinary(op, arr_us,
ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time),
arr_us2);
CheckScalarBinary(op, arr_ns,
ArrayFromJSON(duration(TimeUnit::NANO), nanoseconds_between_time),
arr_ns2);

auto seconds_1 = ArrayFromJSON(time32(TimeUnit::SECOND), R"([1, null])");
auto milliseconds_2k = ArrayFromJSON(duration(TimeUnit::MILLI), R"([2000, null])");
auto milliseconds_3k = ArrayFromJSON(time32(TimeUnit::MILLI), R"([3000, null])");
auto nanoseconds_1G = ArrayFromJSON(time64(TimeUnit::NANO), R"([1000000000, null])");
auto microseconds_2M = ArrayFromJSON(duration(TimeUnit::MICRO), R"([2000000, null])");
auto nanoseconds_3M = ArrayFromJSON(time64(TimeUnit::NANO), R"([3000000000, null])");
auto microseconds_3M = ArrayFromJSON(time64(TimeUnit::MICRO), R"([3000000, null])");
CheckScalarBinary(op, seconds_1, milliseconds_2k, milliseconds_3k);
CheckScalarBinary(op, nanoseconds_1G, microseconds_2M, nanoseconds_3M);
CheckScalarBinary(op, seconds_1, microseconds_2M, microseconds_3M);

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr("-1 is not within the acceptable range of [0, 86400)"),
CallFunction(op, {ArrayFromJSON(time32(TimeUnit::SECOND), R"([0, null])"),
ArrayFromJSON(duration(TimeUnit::SECOND), R"([-1, null])")}));

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr(
"86400000000001 is not within the acceptable range of [0, 86400000000000)"),
CallFunction(op,
{ArrayFromJSON(time64(TimeUnit::MICRO), R"([86400000000, null])"),
ArrayFromJSON(duration(TimeUnit::NANO), R"([1, null])")}));

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr(
"86400000001 is not within the acceptable range of [0, 86400000000)"),
CallFunction(op,
{ArrayFromJSON(time64(TimeUnit::MICRO), R"([86400000000, null])"),
ArrayFromJSON(duration(TimeUnit::MICRO), R"([1, null])")}));
}
}

TEST_F(ScalarTemporalTest, TestTemporalSubtractTimeAndDuration) {
for (auto op : {"subtract", "subtract_checked"}) {
auto arr_s = ArrayFromJSON(time32(TimeUnit::SECOND), times_s);
Expand Down