diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index ba9493ea875..46d31c8ae4c 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -182,6 +182,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)}; ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), time64(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)}; @@ -204,6 +206,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {duration(TimeUnit::SECOND), int64()}; ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); } TEST(TestDispatchBest, ReplaceTemporalTypes) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2d543e23266..bfafb6fcc19 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -240,23 +240,12 @@ struct SubtractCheckedDate32 { } }; -template +template struct AddTimeDuration { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { - T result = arrow::internal::SafeSignedAdd(left, static_cast(right)); - if (result < 0 || multiple <= result) { - *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", - multiple, ") s"); - } - return result; - } - - template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { - T result = arrow::internal::SafeSignedAdd(left, right); + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = + arrow::internal::SafeSignedAdd(static_cast(left), static_cast(right)); if (result < 0 || multiple <= result) { *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", multiple, ") s"); @@ -265,27 +254,13 @@ struct AddTimeDuration { } }; -template +template struct AddTimeDurationChecked { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { - T result = 0; - if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast(right), &result))) { - *st = Status::Invalid("overflow"); - } - if (result < 0 || multiple <= result) { - *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", - multiple, ") s"); - } - return result; - } - - template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = 0; - if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast(right), &result))) { + if (ARROW_PREDICT_FALSE( + AddWithOverflow(static_cast(left), static_cast(right), &result))) { *st = Status::Invalid("overflow"); } if (result < 0 || multiple <= result) { @@ -296,11 +271,10 @@ struct AddTimeDurationChecked { } }; -template +template struct SubtractTimeDuration { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = arrow::internal::SafeSignedSubtract(left, static_cast(right)); if (result < 0 || multiple <= result) { *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", @@ -308,38 +282,12 @@ struct SubtractTimeDuration { } return result; } - - template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { - T result = arrow::internal::SafeSignedSubtract(left, right); - if (result < 0 || multiple <= result) { - *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", - multiple, ") s"); - } - return result; - } }; -template +template struct SubtractTimeDurationChecked { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { - T result = 0; - if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { - *st = Status::Invalid("overflow"); - } - if (result < 0 || multiple <= result) { - *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", - multiple, ") s"); - } - return result; - } - - template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = 0; if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { *st = Status::Invalid("overflow"); @@ -2269,34 +2217,58 @@ std::shared_ptr MakeArithmeticFunctionFloatingPointNotNull( return func; } -template