Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,27 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) {

args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
ty = CommonTemporalResolution(args.data(), args.size());
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC"));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz));

args = {time32(TimeUnit::SECOND), duration(TimeUnit::SECOND)};
ty = CommonTemporalResolution(args.data(), args.size());
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time32(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, duration(TimeUnit::SECOND));

args = {time64(TimeUnit::MICRO), duration(TimeUnit::SECOND)};
ty = CommonTemporalResolution(args.data(), args.size());
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::MICRO));
AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO));

args = {time32(TimeUnit::SECOND), duration(TimeUnit::NANO)};
ty = CommonTemporalResolution(args.data(), args.size());
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::NANO));
AssertTypeEqual(args[1].type, duration(TimeUnit::NANO));
}

} // namespace internal
Expand Down
88 changes: 88 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,62 @@ struct SubtractCheckedDate32 {
}
};

template <bool is_32bit, int64_t multiple>
struct SubtractTimeDuration {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedSubtract(left, static_cast<T>(right));
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we could subtract a negative duration, we should also check if the output is larger than the maximum acceptable value.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively - maybe this arithmetic should wrap around? Does that make any sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh. I wonder how expected this would be. But you're proposing something like this:

if (result < 0) {
  result = result % multiple;
}

Correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Though, I'm not sure if that really makes any sense to have.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels semantically correct but I'm going to check what Pandas and others do to see what conventions are.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I poked around stdlib datetime but it doesn't seem to implement this. However, I still think we need to check whether the result is in range on both ends.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this reference: https://bugs.python.org/issue1487389#msg54803

I think raising an error is OK. If we need wrap-around, we can add that as a separate kernel perhaps.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should still raise an error if the result is >24:00 though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this! Pandas appears delegate time to Python and rather work with timedelta which is like our duration.
I made this raise on time < 0 and time >= 24:00. Lets revisit if/when desired.

return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = arrow::internal::SafeSignedSubtract(left, right);
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}
};

template <bool is_32bit, int64_t multiple>
struct SubtractTimeDurationChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_t<is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}

template <typename T, typename Arg0, typename Arg1>
static enable_if_t<!is_32bit, T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) {
*st = Status::Invalid("overflow");
}
if (result < 0 || multiple <= result) {
*st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ",
multiple, ") s");
}
return result;
}
};

struct Multiply {
static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
Expand Down Expand Up @@ -2132,6 +2188,34 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPointNotNull(
return func;
}

template <template <bool, int64_t> class Op>
void AddArithmeticFunctionTimeDurations(std::shared_ptr<ScalarFunction> func) {
// Add Op(time32, duration) -> time32
TimeUnit::type unit = TimeUnit::SECOND;
auto exec_1 = ScalarBinary<Time32Type, Time32Type, DurationType, Op<true, 86400>>::Exec;
DCHECK_OK(func->AddKernel({time32(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_1)));

unit = TimeUnit::MILLI;
auto exec_2 =
ScalarBinary<Time32Type, Time32Type, DurationType, Op<true, 86400000>>::Exec;
DCHECK_OK(func->AddKernel({time32(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_2)));

// Add Op(time64, duration) -> time64
unit = TimeUnit::MICRO;
auto exec_3 =
ScalarBinary<Time64Type, Time64Type, DurationType, Op<false, 86400000000>>::Exec;
DCHECK_OK(func->AddKernel({time64(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_3)));

unit = TimeUnit::NANO;
auto exec_4 =
ScalarBinary<Time64Type, Time64Type, DurationType, Op<false, 86400000000000>>::Exec;
DCHECK_OK(func->AddKernel({time64(unit), duration(unit)}, OutputType(FirstType),
std::move(exec_4)));
}

const FunctionDoc absolute_value_doc{
"Calculate the absolute value of the argument element-wise",
("Results will wrap around on integer overflow.\n"
Expand Down Expand Up @@ -2507,6 +2591,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(subtract->AddKernel({in_type_date_64, in_type_date_64},
duration(TimeUnit::MILLI), std::move(exec_date_64)));

AddArithmeticFunctionTimeDurations<SubtractTimeDuration>(subtract);

DCHECK_OK(registry->AddFunction(std::move(subtract)));

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -2563,6 +2649,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

AddArithmeticFunctionTimeDurations<SubtractTimeDurationChecked>(subtract_checked);

DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));

// ----------------------------------------------------------------------
Expand Down
58 changes: 58 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,64 @@ TEST_F(ScalarTemporalTest, TestTemporalSubtractTime) {
}
}

TEST_F(ScalarTemporalTest, TestTemporalSubtractTimeAndDuration) {
for (auto op : {"subtract", "subtract_checked"}) {
auto arr_s = ArrayFromJSON(time32(TimeUnit::SECOND), times_s);
auto arr_s2 = ArrayFromJSON(time32(TimeUnit::SECOND), times_s2);
auto arr_ms = ArrayFromJSON(time32(TimeUnit::MILLI), times_ms);
auto arr_ms2 = ArrayFromJSON(time32(TimeUnit::MILLI), times_ms2);
auto arr_us = ArrayFromJSON(time64(TimeUnit::MICRO), times_us);
auto arr_us2 = ArrayFromJSON(time64(TimeUnit::MICRO), times_us2);
auto arr_ns = ArrayFromJSON(time64(TimeUnit::NANO), times_ns);
auto arr_ns2 = ArrayFromJSON(time64(TimeUnit::NANO), times_ns2);

CheckScalarBinary(op, arr_s2,
ArrayFromJSON(duration(TimeUnit::SECOND), seconds_between_time),
arr_s);
CheckScalarBinary(op, arr_ms2,
ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time),
arr_ms);
CheckScalarBinary(op, arr_us2,
ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time),
arr_us);
CheckScalarBinary(op, arr_ns2,
ArrayFromJSON(duration(TimeUnit::NANO), nanoseconds_between_time),
arr_ns);

auto seconds_3 = ArrayFromJSON(time32(TimeUnit::SECOND), R"([3, null])");
auto milliseconds_2k = ArrayFromJSON(duration(TimeUnit::MILLI), R"([2000, null])");
auto milliseconds_1k = ArrayFromJSON(time32(TimeUnit::MILLI), R"([1000, null])");
auto nanoseconds_3G = ArrayFromJSON(time64(TimeUnit::NANO), R"([3000000000, null])");
auto microseconds_2M = ArrayFromJSON(duration(TimeUnit::MICRO), R"([2000000, null])");
auto nanoseconds_1M = ArrayFromJSON(time64(TimeUnit::NANO), R"([1000000000, null])");
auto microseconds_1M = ArrayFromJSON(time64(TimeUnit::MICRO), R"([1000000, null])");
CheckScalarBinary(op, seconds_3, milliseconds_2k, milliseconds_1k);
CheckScalarBinary(op, nanoseconds_3G, microseconds_2M, nanoseconds_1M);
CheckScalarBinary(op, seconds_3, microseconds_2M, microseconds_1M);

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr("-1 is not within the acceptable range of [0, 86400)"),
CallFunction(op, {ArrayFromJSON(time32(TimeUnit::SECOND), R"([1, null])"),
ArrayFromJSON(duration(TimeUnit::SECOND), R"([2, null])")}));

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr(
"-1000 is not within the acceptable range of [0, 86400000000000)"),
CallFunction(op, {ArrayFromJSON(time64(TimeUnit::MICRO), R"([1, null])"),
ArrayFromJSON(duration(TimeUnit::NANO), R"([2000, null])")}));

EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
::testing::HasSubstr(
"86400000001000 is not within the acceptable range of [0, 86400000000000)"),
CallFunction(op,
{ArrayFromJSON(time64(TimeUnit::MICRO), R"([86400000000, null])"),
ArrayFromJSON(duration(TimeUnit::NANO), R"([-1000, null])")}));
}
}

TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) {
auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([
"2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15",
Expand Down