diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index eef5c4081bc..1e06a364509 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -109,6 +109,33 @@ void ReplaceNullWithOtherType(ValueDescr* first, size_t count) { } } +void ReplaceTemporalTypes(const TimeUnit::type unit, std::vector* descrs) { + auto* end = descrs->data() + descrs->size(); + + for (auto* it = descrs->data(); it != end; it++) { + switch (it->type->id()) { + case Type::TIMESTAMP: { + const auto& ty = checked_cast(*it->type); + it->type = timestamp(unit, ty.timezone()); + continue; + } + case Type::TIME32: + case Type::TIME64: + case Type::DURATION: { + it->type = duration(unit); + continue; + } + case Type::DATE32: + case Type::DATE64: { + it->type = timestamp(unit); + continue; + } + default: + continue; + } + } +} + void ReplaceTypes(const std::shared_ptr& type, std::vector* descrs) { ReplaceTypes(type, descrs->data(), descrs->size()); @@ -180,6 +207,47 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { return int8(); } +TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count) { + TimeUnit::type finest_unit = TimeUnit::SECOND; + const ValueDescr* end = begin + count; + for (auto it = begin; it != end; it++) { + auto id = it->type->id(); + switch (id) { + case Type::DATE32: { + // Date32's unit is days, but the coarsest we have is seconds + continue; + } + case Type::DATE64: { + finest_unit = std::max(finest_unit, TimeUnit::MILLI); + continue; + } + case Type::TIMESTAMP: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + case Type::DURATION: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + case Type::TIME32: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, 
ty.unit()); + continue; + } + case Type::TIME64: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + default: + continue; + } + } + return finest_unit; +} + std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) { TimeUnit::type finest_unit = TimeUnit::SECOND; const std::string* timezone = nullptr; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 1efd3e22f93..ff7b9161fe3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1381,6 +1381,9 @@ void ReplaceTypes(const std::shared_ptr&, std::vector* des ARROW_EXPORT void ReplaceTypes(const std::shared_ptr&, ValueDescr* descrs, size_t count); +ARROW_EXPORT +void ReplaceTemporalTypes(TimeUnit::type unit, std::vector* descrs); + ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); @@ -1390,6 +1393,9 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); ARROW_EXPORT std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count); +ARROW_EXPORT +TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count); + ARROW_EXPORT std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index d64143dea31..6b68632ceb3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -158,6 +158,95 @@ TEST(TestDispatchBest, CommonTemporal) { ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); } +TEST(TestDispatchBest, CommonTemporalResolution) { + std::vector args; + std::string tz = "Pacific/Marquesas"; + + args = {date32(), date32()}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {date32(), date64()}; + 
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), date32()}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), time32(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {time64(TimeUnit::NANO), time64(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), date32()}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {date64(), duration(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {date32(), timestamp(TimeUnit::MICRO, tz)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::MICRO, tz), date64()}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), timestamp(TimeUnit::MICRO, tz)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)}; + 
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, tz), duration(TimeUnit::MILLI)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); +} + +TEST(TestDispatchBest, ReplaceTemporalTypes) { + std::vector args; + std::string tz = "Pacific/Marquesas"; + TimeUnit::type ty; + + args = {date32(), date32()}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND)); + + args = {date64(), time32(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI)); + AssertTypeEqual(args[1].type, duration(TimeUnit::MILLI)); + + args = {duration(TimeUnit::SECOND), date64()}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, duration(TimeUnit::MILLI)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI)); + + args = {timestamp(TimeUnit::MICRO, tz), timestamp(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::NANO)); + + args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz)); + AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); + + args = {timestamp(TimeUnit::SECOND, tz), date64()}; + ty = CommonTemporalResolution(args.data(), 
args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI, tz)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI)); + + args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; + ty = CommonTemporalResolution(args.data(), args.size()); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC")); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz)); +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2f36c946f0a..b27d494d864 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -42,6 +42,7 @@ using internal::SubtractWithOverflow; namespace compute { namespace internal { +using applicator::ScalarBinary; using applicator::ScalarBinaryEqualTypes; using applicator::ScalarBinaryNotNullEqualTypes; using applicator::ScalarUnary; @@ -180,7 +181,6 @@ struct Subtract { template static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - static_assert(std::is_same::value && std::is_same::value, ""); return arrow::internal::SafeSignedSubtract(left, right); } @@ -194,7 +194,6 @@ struct SubtractChecked { template static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); @@ -1693,7 +1692,9 @@ struct ArithmeticFunction : ScalarFunction { if (values->size() == 2) { ReplaceNullWithOtherType(values); - if (auto type = CommonNumeric(*values)) { + if (auto type = CommonTemporalResolution(values->data(), values->size())) { + ReplaceTemporalTypes(type, values); + } else if (auto type = CommonNumeric(*values)) { 
ReplaceTypes(type, values); } } @@ -2428,12 +2429,30 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); } + // Add subtract(timestamp, duration) -> timestamp + for (auto unit : TimeUnit::values()) { + InputType in_type(match::TimestampTypeUnit(unit)); + auto exec = ScalarBinary::Exec; + DCHECK_OK(subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType), + std::move(exec))); + } + DCHECK_OK(registry->AddFunction(std::move(subtract))); // ---------------------------------------------------------------------- auto subtract_checked = MakeArithmeticFunctionNotNull( "subtract_checked", &sub_checked_doc); AddDecimalBinaryKernels("subtract_checked", subtract_checked.get()); + + // Add subtract(timestamp, duration) -> timestamp + for (auto unit : TimeUnit::values()) { + InputType in_type(match::TimestampTypeUnit(unit)); + auto exec = + ScalarBinary::Exec; + DCHECK_OK(subtract_checked->AddKernel({in_type, duration(unit)}, + OutputType(FirstType), std::move(exec))); + } + DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index 1960c0eaed7..410ea6132fe 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -976,6 +976,60 @@ TEST_F(ScalarTemporalTest, TestTemporalDifference) { } } +TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDuration) { + std::string op = "subtract"; + std::string milliseconds_between_time_and_date = + "[31535941000, -31706603000, 2674840000, -2604800000, 82495000," + "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000," + "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]"; + std::string microseconds_between_time_and_date = + 
"[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000," + "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, " + "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, " + "5180277000000, null]"; + auto dates32 = ArrayFromJSON(date32(), date32s2); + auto dates64 = ArrayFromJSON(date64(), date64s2); + + auto durations_ms = + ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time_and_date); + auto timestamps_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_ms, timestamps_ms); + CheckScalarBinary(op, dates64, durations_ms, timestamps_ms); + + auto durations_us = + ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time_and_date); + auto timestamps_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_us, timestamps_us); + CheckScalarBinary(op, dates64, durations_us, timestamps_us); +} + +TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDurationChecked) { + std::string op = "subtract_checked"; + std::string milliseconds_between_time_and_date = + "[31535941000, -31706603000, 2674840000, -2604800000, 82495000," + "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000," + "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]"; + std::string microseconds_between_time_and_date = + "[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000," + "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, " + "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, " + "5180277000000, null]"; + auto dates32 = ArrayFromJSON(date32(), date32s2); + auto dates64 = ArrayFromJSON(date64(), date64s2); + + auto durations_ms = + ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time_and_date); + auto timestamps_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + 
CheckScalarBinary(op, dates32, durations_ms, timestamps_ms); + CheckScalarBinary(op, dates64, durations_ms, timestamps_ms); + + auto durations_us = + ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time_and_date); + auto timestamps_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_us, timestamps_us); + CheckScalarBinary(op, dates64, durations_us, timestamps_us); +} + TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) { auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([ "2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index f958bd8d398..4173fbbfdee 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -419,40 +419,41 @@ For functions which support decimal inputs (currently ``add``, ``subtract``, precisions/scales will be promoted appropriately. Mixed decimal and floating-point arguments will cast all arguments to floating-point, while mixed decimal and integer arguments will cast all arguments to decimals. 
- -+------------------+--------+----------------+----------------------+-------+ -| Function name | Arity | Input types | Output type | Notes | -+==================+========+================+======================+=======+ -| abs | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| abs_checked | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| add | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| add_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| divide | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| divide_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| multiply | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| multiply_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| negate | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| negate_checked | Unary | Signed Numeric | Signed Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| power | Binary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| power_checked | Binary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) | -+------------------+--------+----------------+----------------------+-------+ -| subtract | Binary | Numeric | Numeric | \(1) | 
-+------------------+--------+----------------+----------------------+-------+ -| subtract_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ +Temporal inputs with mixed time resolutions will be cast to the finest input resolution. + ++------------------+--------+----------------------------+----------------------------+-------+ +| Function name | Arity | Input types | Output type | Notes | ++==================+========+============================+============================+=======+ +| abs | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| abs_checked | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| add | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| add_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| divide | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| divide_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| multiply | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| multiply_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| negate | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| negate_checked | Unary | Signed Numeric | Signed Numeric | | 
++------------------+--------+----------------------------+----------------------------+-------+ +| power | Binary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| power_checked | Binary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) | ++------------------+--------+----------------------------+----------------------------+-------+ +| subtract | Binary | Numeric/Date/Duration | Numeric/Date/Duration | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| subtract_checked | Binary | Numeric/Date/Duration | Numeric/Date/Duration | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ * \(1) Precision and scale of computed DECIMAL results