From 0e7fabaf369187996f9c87648ddce7a4acbd3068 Mon Sep 17 00:00:00 2001 From: Rok Date: Fri, 11 Mar 2022 21:02:13 +0100 Subject: [PATCH 1/4] Temporal add/add_checked should be commutative. --- .../compute/kernels/codegen_internal_test.cc | 4 + .../compute/kernels/scalar_arithmetic.cc | 104 ++++++++++++++---- .../compute/kernels/scalar_temporal_test.cc | 14 +++ 3 files changed, 101 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index ba9493ea875..46d31c8ae4c 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -182,6 +182,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)}; ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), time64(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)}; @@ -204,6 +206,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); args = {duration(TimeUnit::SECOND), int64()}; ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); } TEST(TestDispatchBest, ReplaceTemporalTypes) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2d543e23266..8b8b0046f4a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -240,11 +240,11 @@ struct SubtractCheckedDate32 { } }; -template +template struct AddTimeDuration { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static enable_if_t<(!std::is_same::value && left_is_32bit), T> Call( + KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = arrow::internal::SafeSignedAdd(left, static_cast(right)); if (result < 0 || multiple <= result) { *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", @@ -254,8 +254,19 @@ struct AddTimeDuration { } template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static enable_if_t<(!std::is_same::value && !left_is_32bit), T> Call( + KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = arrow::internal::SafeSignedAdd(static_cast(left), right); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } + + template + static enable_if_t<(std::is_same::value), T> Call(KernelContext*, Arg0 left, + Arg1 right, Status* st) { T result = arrow::internal::SafeSignedAdd(left, right); if (result < 0 || multiple <= result) { *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", @@ -265,11 +276,11 @@ struct AddTimeDuration { } }; -template +template struct AddTimeDurationChecked { template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static enable_if_t<(!std::is_same::value && left_is_32bit), T> Call( + KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = 0; if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast(right), &result))) { *st = Status::Invalid("overflow"); @@ -282,10 +293,24 @@ struct AddTimeDurationChecked { } template - static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, - Status* st) { + static enable_if_t<(!std::is_same::value && !left_is_32bit), T> Call( + KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = 0; - if (ARROW_PREDICT_FALSE(AddWithOverflow(left, static_cast(right), &result))) { + if (ARROW_PREDICT_FALSE(AddWithOverflow(static_cast(left), right, &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } + + template + static enable_if_t<(std::is_same::value), T> Call(KernelContext*, Arg0 left, + Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); } if (result < 0 || multiple <= result) { @@ -341,7 +366,7 @@ struct SubtractTimeDurationChecked { static enable_if_t Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { T result = 0; - if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast(right), &result))) { + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); } if (result < 0 || multiple <= result) { @@ -2270,7 +2295,7 @@ std::shared_ptr MakeArithmeticFunctionFloatingPointNotNull( } template