diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index bf1dff53adb..ba9493ea875 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -202,6 +202,8 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), int64()}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); } TEST(TestDispatchBest, ReplaceTemporalTypes) { @@ -268,6 +270,11 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) { ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, time64(TimeUnit::NANO)); AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); + + args = {duration(TimeUnit::SECOND), int64()}; + ReplaceTemporalTypes(CommonTemporalResolution(args.data(), args.size()), &args); + AssertTypeEqual(args[0].type, duration(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, int64()); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2c30c87d2cd..dfbcba9ec3f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -44,6 +44,7 @@ namespace internal { using applicator::ScalarBinary; using applicator::ScalarBinaryEqualTypes; +using applicator::ScalarBinaryNotNull; using applicator::ScalarBinaryNotNullEqualTypes; using applicator::ScalarUnary; using applicator::ScalarUnaryNotNull; @@ -1647,8 +1648,8 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) { return KernelGenerator::Exec; case Type::UINT32: return KernelGenerator::Exec; - case Type::DURATION: case Type::INT64: + case Type::DURATION: case Type::TIMESTAMP: return KernelGenerator::Exec; case Type::UINT64: @@ -2802,23 +2803,63 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto multiply = MakeArithmeticFunction("multiply", &mul_doc); AddDecimalBinaryKernels("multiply", multiply.get()); + + // Add multiply(duration, int64) -> duration + for (auto unit : TimeUnit::values()) { + auto exec1 = ArithmeticExecFromOp(Type::DURATION); + DCHECK_OK( + multiply->AddKernel({duration(unit), int64()}, duration(unit), std::move(exec1))); + auto exec2 = ArithmeticExecFromOp(Type::DURATION); + DCHECK_OK( + multiply->AddKernel({int64(), duration(unit)}, duration(unit), std::move(exec2))); + } + DCHECK_OK(registry->AddFunction(std::move(multiply))); // ---------------------------------------------------------------------- auto multiply_checked = MakeArithmeticFunctionNotNull( "multiply_checked", &mul_checked_doc); AddDecimalBinaryKernels("multiply_checked", multiply_checked.get()); + + // Add multiply_checked(duration, int64) -> duration + for (auto unit : TimeUnit::values()) { + auto exec1 = + ArithmeticExecFromOp(Type::DURATION); + DCHECK_OK(multiply_checked->AddKernel({duration(unit), int64()}, duration(unit), + std::move(exec1))); + auto exec2 = + ArithmeticExecFromOp(Type::DURATION); + DCHECK_OK(multiply_checked->AddKernel({int64(), duration(unit)}, duration(unit), + std::move(exec2))); + } + DCHECK_OK(registry->AddFunction(std::move(multiply_checked))); // ---------------------------------------------------------------------- auto divide = MakeArithmeticFunctionNotNull("divide", &div_doc); AddDecimalBinaryKernels("divide", divide.get()); + + // Add divide(duration, int64) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = ScalarBinaryNotNull::Exec; + DCHECK_OK( + divide->AddKernel({duration(unit), int64()}, duration(unit), std::move(exec))); + } DCHECK_OK(registry->AddFunction(std::move(divide))); // ---------------------------------------------------------------------- auto divide_checked = MakeArithmeticFunctionNotNull("divide_checked", &div_checked_doc); AddDecimalBinaryKernels("divide_checked", divide_checked.get()); + + // Add divide_checked(duration, int64) -> duration + for (auto unit : TimeUnit::values()) { + auto exec = + ScalarBinaryNotNull::Exec; + DCHECK_OK(divide_checked->AddKernel({duration(unit), int64()}, duration(unit), + std::move(exec))); + } + DCHECK_OK(registry->AddFunction(std::move(divide_checked))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index 560699c09ed..ff7a169094f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -1455,6 +1455,50 @@ TEST_F(ScalarTemporalTest, TestTemporalSubtractDuration) { } } +TEST_F(ScalarTemporalTest, TestTemporalMultiplyDuration) { + std::shared_ptr max_array; + auto max = std::numeric_limits::max(); + ArrayFromVector({max, max, max, max, max}, &max_array); + + for (auto u : TimeUnit::values()) { + auto unit = duration(u); + auto durations = ArrayFromJSON(unit, R"([0, -1, 2, 6, null])"); + auto multipliers = ArrayFromJSON(int64(), R"([0, 3, 2, 7, null])"); + auto durations_multiplied = ArrayFromJSON(unit, R"([0, -3, 4, 42, null])"); + + CheckScalarBinary("multiply", durations, multipliers, durations_multiplied); + CheckScalarBinary("multiply", multipliers, durations, durations_multiplied); + CheckScalarBinary("multiply_checked", durations, multipliers, durations_multiplied); + CheckScalarBinary("multiply_checked", multipliers, durations, durations_multiplied); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid: overflow"), + CallFunction("multiply_checked", {durations, max_array})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid: overflow"), + CallFunction("multiply_checked", {max_array, durations})); + } +} + +TEST_F(ScalarTemporalTest, TestTemporalDivideDuration) { + for (auto u : TimeUnit::values()) { + auto unit = duration(u); + auto divided_durations = ArrayFromJSON(unit, R"([0, -1, -2, 6, null])"); + auto divisors = ArrayFromJSON(int64(), R"([3, 3, -2, 7, null])"); + auto durations = ArrayFromJSON(unit, R"([1, -3, 4, 42, null])"); + auto zeros = ArrayFromJSON(int64(), R"([0, 0, 0, 0, null])"); + CheckScalarBinary("divide", durations, divisors, divided_durations); + CheckScalarBinary("divide_checked", durations, divisors, divided_durations); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Invalid: divide by zero"), + CallFunction("divide", {durations, zeros})); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Invalid: divide by zero"), + CallFunction("divide_checked", {durations, zeros})); + } +} + TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) { auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([ "2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 56582c08a38..9c6dc321568 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -441,13 +441,13 @@ Mixed time resolution temporal inputs will be cast to finest input resolution. +------------------+--------+------------------+----------------------+-------+ | add_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+------------------+----------------------+-------+ -| divide | Binary | Numeric | Numeric | \(1) | +| divide | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+------------------+----------------------+-------+ -| divide_checked | Binary | Numeric | Numeric | \(1) | +| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+------------------+----------------------+-------+ -| multiply | Binary | Numeric | Numeric | \(1) | +| multiply | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+------------------+----------------------+-------+ -| multiply_checked | Binary | Numeric | Numeric | \(1) | +| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal | \(1) | +------------------+--------+------------------+----------------------+-------+ | negate | Unary | Numeric | Numeric | | +------------------+--------+------------------+----------------------+-------+