diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index c2e17e377f2..bf1dff53adb 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -247,8 +247,27 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) { args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC")); AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz)); + + args = {time32(TimeUnit::SECOND), duration(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time32(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, duration(TimeUnit::SECOND)); + + args = {time64(TimeUnit::MICRO), duration(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time64(TimeUnit::MICRO)); + AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO)); + + args = {time32(TimeUnit::SECOND), duration(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time64(TimeUnit::NANO)); + AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index dd0476584fc..3c3fd92d53b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -239,6 +239,62 @@ struct SubtractCheckedDate32 { } }; +template +struct SubtractTimeDuration { + template + static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result = arrow::internal::SafeSignedSubtract(left, static_cast(right)); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } + + template + static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result = arrow::internal::SafeSignedSubtract(left, right); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template +struct SubtractTimeDurationChecked { + template + static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } + + template + static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, + Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); @@ -2132,6 +2188,34 @@ std::shared_ptr MakeArithmeticFunctionFloatingPointNotNull( return func; } +template