From 6955cbff9882bacbf836bd51ff5f95579a999907 Mon Sep 17 00:00:00 2001 From: Rok Date: Wed, 12 Jan 2022 18:38:08 +0100 Subject: [PATCH 1/3] SubtractDateAndDuration --- .../compute/kernels/scalar_arithmetic.cc | 31 +++++++++++++++++++ .../compute/kernels/scalar_temporal_test.cc | 17 ++++++++++ 2 files changed, 48 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 2f36c946f0a..d8762f715ba 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -42,6 +42,7 @@ using internal::SubtractWithOverflow; namespace compute { namespace internal { +using applicator::ScalarBinary; using applicator::ScalarBinaryEqualTypes; using applicator::ScalarBinaryNotNullEqualTypes; using applicator::ScalarUnary; @@ -215,6 +216,22 @@ struct SubtractChecked { } }; +struct SubtractDate32AndDuration { + static constexpr int64_t kMillisecondsInDay = 86400000; + + template + static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return arrow::internal::SafeSignedSubtract(left * kMillisecondsInDay, right); + } +}; + +struct SubtractDate64AndDuration { + template + static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return arrow::internal::SafeSignedSubtract(left, right); + } +}; + struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); @@ -2428,6 +2445,20 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); } + // Add subtract(date32, date32) -> duration(TimeUnit::MILLI) + InputType in_type_date_32(date32()); + auto exec_date_32 = + ScalarBinary::Exec; + DCHECK_OK(subtract->AddKernel({in_type_date_32, duration(TimeUnit::MILLI)}, + timestamp(TimeUnit::MILLI), std::move(exec_date_32))); + + // Add subtract(date64, date64) -> duration(TimeUnit::MILLI) + InputType in_type_date_64(date64()); + auto exec_date_64 = + ScalarBinary::Exec; + DCHECK_OK(subtract->AddKernel({in_type_date_64, duration(TimeUnit::MILLI)}, + timestamp(TimeUnit::MILLI), std::move(exec_date_64))); + DCHECK_OK(registry->AddFunction(std::move(subtract))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index 1960c0eaed7..33ecc2d08f5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -976,6 +976,23 @@ TEST_F(ScalarTemporalTest, TestTemporalDifference) { } } +TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDuration) { + std::string op = "subtract"; + std::string milliseconds_between_time_and_date = + "[31535941000, -31706603000, 2674840000, -2604800000, 82495000," + "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000," + "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]"; + + auto durations = + ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time_and_date); + auto timestamps = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + auto dates32 = ArrayFromJSON(date32(), date32s2); + auto dates64 = ArrayFromJSON(date64(), date64s2); + + CheckScalarBinary(op, dates32, durations, timestamps); + CheckScalarBinary(op, dates64, durations, timestamps); +} + TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) { auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([ "2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15", From f638b7b62dcd72fedbd4218626b0286354fdc138 Mon Sep 17 00:00:00 2001 From: Rok Date: Thu, 20 Jan 2022 00:55:15 +0100 Subject: [PATCH 2/3] Add ReplaceTemporalTypes --- .../arrow/compute/kernels/codegen_internal.cc | 61 ++++++++++++++++ .../arrow/compute/kernels/codegen_internal.h | 4 ++ .../compute/kernels/codegen_internal_test.cc | 18 +++++ .../compute/kernels/scalar_arithmetic.cc | 52 ++++++-------- .../compute/kernels/scalar_temporal_test.cc | 45 ++++++++++-- docs/source/cpp/compute.rst | 69 ++++++++++--------- 6 files changed, 179 insertions(+), 70 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index eef5c4081bc..c71af2b6629 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -109,6 +109,58 @@ void ReplaceNullWithOtherType(ValueDescr* first, size_t count) { } } +void ReplaceTemporalTypes(const std::shared_ptr& type, + std::vector* descrs) { + auto* end = descrs->data() + descrs->size(); + TimeUnit::type finest_unit = TimeUnit::SECOND; + + switch (type->id()) { + case Type::TIMESTAMP: { + const auto& ty = checked_cast(*type); + finest_unit = ty.unit(); + break; + } + case Type::DURATION: { + const auto& ty = checked_cast(*type); + finest_unit = ty.unit(); + break; + } + case Type::DATE32: { + // Date32's unit is days, but the coarsest we have is seconds + break; + } + case Type::DATE64: { + finest_unit = std::max(finest_unit, TimeUnit::MILLI); + break; + } + default: + break; + } + + for (auto* it = descrs->data(); it != end; it++) { + switch (it->type->id()) { + case Type::TIMESTAMP: { + it->type = type; + continue; + } + case Type::DURATION: { + it->type = duration(finest_unit); + continue; + } + case Type::DATE32: { + it->type = timestamp(finest_unit); + continue; + } + case Type::DATE64: { + it->type = timestamp(finest_unit); + continue; + } + default: + continue; + } + } +} + void ReplaceTypes(const std::shared_ptr& type, std::vector* descrs) { ReplaceTypes(type, descrs->data(), descrs->size()); @@ -185,6 +237,7 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) const std::string* timezone = nullptr; bool saw_date32 = false; bool saw_date64 = false; + bool saw_duration = false; const ValueDescr* end = begin + count; for (auto it = begin; it != end; it++) { @@ -206,6 +259,12 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) finest_unit = std::max(finest_unit, ty.unit()); continue; } + case Type::DURATION: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + saw_duration = true; + continue; + } default: return nullptr; } @@ -214,6 +273,8 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) if (timezone) { // At least one timestamp seen return timestamp(finest_unit, *timezone); + } else if (saw_duration) { + return duration(finest_unit); } else if (saw_date64) { return date64(); } else if (saw_date32) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 1efd3e22f93..5e9e204a8e6 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1381,6 +1381,10 @@ void ReplaceTypes(const std::shared_ptr&, std::vector* des ARROW_EXPORT void ReplaceTypes(const std::shared_ptr&, ValueDescr* descrs, size_t count); +ARROW_EXPORT +void ReplaceTemporalTypes(const std::shared_ptr&, + std::vector* descrs); + ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index d64143dea31..090368b051c 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -156,6 +156,24 @@ TEST(TestDispatchBest, CommonTemporal) { args = {timestamp(TimeUnit::SECOND, "America/Phoenix"), timestamp(TimeUnit::SECOND, "UTC")}; ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); + args = {date32(), duration(TimeUnit::MILLI)}; + AssertTypeEqual(duration(TimeUnit::MILLI), CommonTemporal(args.data(), args.size())); + args = {date64(), duration(TimeUnit::MICRO)}; + AssertTypeEqual(duration(TimeUnit::MICRO), CommonTemporal(args.data(), args.size())); +} + +TEST(TestDispatchBest, ReplaceTemporalTypes) { + std::vector args; + + args = {date32(), duration(TimeUnit::MILLI)}; + ReplaceTemporalTypes(CommonTemporal(args.data(), args.size()), &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI)); + AssertTypeEqual(args[1].type, duration(TimeUnit::MILLI)); + + args = {date64(), duration(TimeUnit::MICRO)}; + ReplaceTemporalTypes(CommonTemporal(args.data(), args.size()), &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::MICRO)); + AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO)); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index d8762f715ba..c2afb5ffa6e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -181,7 +181,6 @@ struct Subtract { template static constexpr enable_if_signed_integer_value Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - static_assert(std::is_same::value && std::is_same::value, ""); return arrow::internal::SafeSignedSubtract(left, right); } @@ -195,7 +194,6 @@ struct SubtractChecked { template static enable_if_integer_value Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { *st = Status::Invalid("overflow"); @@ -216,22 +214,6 @@ struct SubtractChecked { } }; -struct SubtractDate32AndDuration { - static constexpr int64_t kMillisecondsInDay = 86400000; - - template - static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return arrow::internal::SafeSignedSubtract(left * kMillisecondsInDay, right); - } -}; - -struct SubtractDate64AndDuration { - template - static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) { - return arrow::internal::SafeSignedSubtract(left, right); - } -}; - struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); @@ -1710,7 +1692,9 @@ struct ArithmeticFunction : ScalarFunction { if (values->size() == 2) { ReplaceNullWithOtherType(values); - if (auto type = CommonNumeric(*values)) { + if (auto type = CommonTemporal(values->data(), values->size())) { + ReplaceTemporalTypes(type, values); + } else if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); } } @@ -2445,19 +2429,13 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); } - // Add subtract(date32, date32) -> duration(TimeUnit::MILLI) - InputType in_type_date_32(date32()); - auto exec_date_32 = - ScalarBinary::Exec; - DCHECK_OK(subtract->AddKernel({in_type_date_32, duration(TimeUnit::MILLI)}, - timestamp(TimeUnit::MILLI), std::move(exec_date_32))); - - // Add subtract(date64, date64) -> duration(TimeUnit::MILLI) - InputType in_type_date_64(date64()); - auto exec_date_64 = - ScalarBinary::Exec; - DCHECK_OK(subtract->AddKernel({in_type_date_64, duration(TimeUnit::MILLI)}, - timestamp(TimeUnit::MILLI), std::move(exec_date_64))); + // Add subtract(timestamp, duration) -> timestamp + for (auto unit : TimeUnit::values()) { + InputType in_type(match::TimestampTypeUnit(unit)); + auto exec = ScalarBinary::Exec; + DCHECK_OK(subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType), + std::move(exec))); + } DCHECK_OK(registry->AddFunction(std::move(subtract))); @@ -2465,6 +2443,16 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { auto subtract_checked = MakeArithmeticFunctionNotNull( "subtract_checked", &sub_checked_doc); AddDecimalBinaryKernels("subtract_checked", subtract_checked.get()); + + // Add subtract(timestamp, duration) -> timestamp + for (auto unit : TimeUnit::values()) { + InputType in_type(match::TimestampTypeUnit(unit)); + auto exec = + ScalarBinary::Exec; + DCHECK_OK(subtract_checked->AddKernel({in_type, duration(unit)}, + OutputType(FirstType), std::move(exec))); + } + DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc index 33ecc2d08f5..410ea6132fe 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc @@ -982,15 +982,52 @@ TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDuration) { "[31535941000, -31706603000, 2674840000, -2604800000, 82495000," "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000," "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]"; + std::string microseconds_between_time_and_date = + "[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000," + "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, " + "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, " + "5180277000000, null]"; + auto dates32 = ArrayFromJSON(date32(), date32s2); + auto dates64 = ArrayFromJSON(date64(), date64s2); - auto durations = + auto durations_ms = ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time_and_date); - auto timestamps = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + auto timestamps_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_ms, timestamps_ms); + CheckScalarBinary(op, dates64, durations_ms, timestamps_ms); + + auto durations_us = + ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time_and_date); + auto timestamps_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_us, timestamps_us); + CheckScalarBinary(op, dates64, durations_us, timestamps_us); +} + +TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDurationChecked) { + std::string op = "subtract_checked"; + std::string milliseconds_between_time_and_date = + "[31535941000, -31706603000, 2674840000, -2604800000, 82495000," + "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000," + "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]"; + std::string microseconds_between_time_and_date = + "[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000," + "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, " + "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, " + "5180277000000, null]"; auto dates32 = ArrayFromJSON(date32(), date32s2); auto dates64 = ArrayFromJSON(date64(), date64s2); - CheckScalarBinary(op, dates32, durations, timestamps); - CheckScalarBinary(op, dates64, durations, timestamps); + auto durations_ms = + ArrayFromJSON(duration(TimeUnit::MILLI), milliseconds_between_time_and_date); + auto timestamps_ms = ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_ms, timestamps_ms); + CheckScalarBinary(op, dates64, durations_ms, timestamps_ms); + + auto durations_us = + ArrayFromJSON(duration(TimeUnit::MICRO), microseconds_between_time_and_date); + auto timestamps_us = ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision); + CheckScalarBinary(op, dates32, durations_us, timestamps_us); + CheckScalarBinary(op, dates64, durations_us, timestamps_us); } TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index f958bd8d398..4173fbbfdee 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -419,40 +419,41 @@ For functions which support decimal inputs (currently ``add``, ``subtract``, precisions/scales will be promoted appropriately. Mixed decimal and floating-point arguments will cast all arguments to floating-point, while mixed decimal and integer arguments will cast all arguments to decimals. - -+------------------+--------+----------------+----------------------+-------+ -| Function name | Arity | Input types | Output type | Notes | -+==================+========+================+======================+=======+ -| abs | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| abs_checked | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| add | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| add_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| divide | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| divide_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| multiply | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| multiply_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| negate | Unary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| negate_checked | Unary | Signed Numeric | Signed Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| power | Binary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| power_checked | Binary | Numeric | Numeric | | -+------------------+--------+----------------+----------------------+-------+ -| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) | -+------------------+--------+----------------+----------------------+-------+ -| subtract | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ -| subtract_checked | Binary | Numeric | Numeric | \(1) | -+------------------+--------+----------------+----------------------+-------+ +Mixed time resolution temporal inputs will be cast to finest input resolution. + ++------------------+--------+----------------------------+----------------------------+-------+ +| Function name | Arity | Input types | Output type | Notes | ++==================+========+============================+============================+=======+ +| abs | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| abs_checked | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| add | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| add_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| divide | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| divide_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| multiply | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| multiply_checked | Binary | Numeric | Numeric | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| negate | Unary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| negate_checked | Unary | Signed Numeric | Signed Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| power | Binary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| power_checked | Binary | Numeric | Numeric | | ++------------------+--------+----------------------------+----------------------------+-------+ +| sign | Unary | Numeric | Int8/Float32/Float64 | \(2) | ++------------------+--------+----------------------------+----------------------------+-------+ +| subtract | Binary | Numeric/Date/Duration | Numeric/Date/Duration | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ +| subtract_checked | Binary | Numeric/Date/Duration | Numeric/Date/Duration | \(1) | ++------------------+--------+----------------------------+----------------------------+-------+ * \(1) Precision and scale of computed DECIMAL results From 12f68edb4ef17180b3d08fecc5ecd94f17cad60d Mon Sep 17 00:00:00 2001 From: Rok Date: Tue, 25 Jan 2022 01:50:40 +0100 Subject: [PATCH 3/3] Switch to CommonTemporalResolution --- .../arrow/compute/kernels/codegen_internal.cc | 91 ++++++++++--------- .../arrow/compute/kernels/codegen_internal.h | 6 +- .../compute/kernels/codegen_internal_test.cc | 91 +++++++++++++++++-- .../compute/kernels/scalar_arithmetic.cc | 2 +- 4 files changed, 135 insertions(+), 55 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index c71af2b6629..1e06a364509 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -109,50 +109,25 @@ void ReplaceNullWithOtherType(ValueDescr* first, size_t count) { } } -void ReplaceTemporalTypes(const std::shared_ptr& type, - std::vector* descrs) { +void ReplaceTemporalTypes(const TimeUnit::type unit, std::vector* descrs) { auto* end = descrs->data() + descrs->size(); - TimeUnit::type finest_unit = TimeUnit::SECOND; - - switch (type->id()) { - case Type::TIMESTAMP: { - const auto& ty = checked_cast(*type); - finest_unit = ty.unit(); - break; - } - case Type::DURATION: { - const auto& ty = checked_cast(*type); - finest_unit = ty.unit(); - break; - } - case Type::DATE32: { - // Date32's unit is days, but the coarsest we have is seconds - break; - } - case Type::DATE64: { - finest_unit = std::max(finest_unit, TimeUnit::MILLI); - break; - } - default: - break; - } for (auto* it = descrs->data(); it != end; it++) { switch (it->type->id()) { case Type::TIMESTAMP: { - it->type = type; + const auto& ty = checked_cast(*it->type); + it->type = timestamp(unit, ty.timezone()); continue; } + case Type::TIME32: + case Type::TIME64: case Type::DURATION: { - it->type = duration(finest_unit); - continue; - } - case Type::DATE32: { - it->type = timestamp(finest_unit); + it->type = duration(unit); continue; } + case Type::DATE32: case Type::DATE64: { - it->type = timestamp(finest_unit); + it->type = timestamp(unit); continue; } default: @@ -232,12 +207,52 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { return int8(); } +TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count) { + TimeUnit::type finest_unit = TimeUnit::SECOND; + const ValueDescr* end = begin + count; + for (auto it = begin; it != end; it++) { + auto id = it->type->id(); + switch (id) { + case Type::DATE32: { + // Date32's unit is days, but the coarsest we have is seconds + continue; + } + case Type::DATE64: { + finest_unit = std::max(finest_unit, TimeUnit::MILLI); + continue; + } + case Type::TIMESTAMP: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + case Type::DURATION: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + case Type::TIME32: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + case Type::TIME64: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + continue; + } + default: + continue; + } + } + return finest_unit; +} + std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) { TimeUnit::type finest_unit = TimeUnit::SECOND; const std::string* timezone = nullptr; bool saw_date32 = false; bool saw_date64 = false; - bool saw_duration = false; const ValueDescr* end = begin + count; for (auto it = begin; it != end; it++) { @@ -259,12 +274,6 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) finest_unit = std::max(finest_unit, ty.unit()); continue; } - case Type::DURATION: { - const auto& ty = checked_cast(*it->type); - finest_unit = std::max(finest_unit, ty.unit()); - saw_duration = true; - continue; - } default: return nullptr; } @@ -273,8 +282,6 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) if (timezone) { // At least one timestamp seen return timestamp(finest_unit, *timezone); - } else if (saw_duration) { - return duration(finest_unit); } else if (saw_date64) { return date64(); } else if (saw_date32) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 5e9e204a8e6..ff7b9161fe3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1382,8 +1382,7 @@ ARROW_EXPORT void ReplaceTypes(const std::shared_ptr&, ValueDescr* descrs, size_t count); ARROW_EXPORT -void ReplaceTemporalTypes(const std::shared_ptr&, - std::vector* descrs); +void ReplaceTemporalTypes(TimeUnit::type unit, std::vector* descrs); ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); @@ -1394,6 +1393,9 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); ARROW_EXPORT std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count); +ARROW_EXPORT +TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count); + ARROW_EXPORT std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 090368b051c..6b68632ceb3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -156,24 +156,95 @@ TEST(TestDispatchBest, CommonTemporal) { args = {timestamp(TimeUnit::SECOND, "America/Phoenix"), timestamp(TimeUnit::SECOND, "UTC")}; ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); - args = {date32(), duration(TimeUnit::MILLI)}; - AssertTypeEqual(duration(TimeUnit::MILLI), CommonTemporal(args.data(), args.size())); - args = {date64(), duration(TimeUnit::MICRO)}; - AssertTypeEqual(duration(TimeUnit::MICRO), CommonTemporal(args.data(), args.size())); +} + +TEST(TestDispatchBest, CommonTemporalResolution) { + std::vector args; + std::string tz = "Pacific/Marquesas"; + + args = {date32(), date32()}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {date32(), date64()}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), date32()}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), time32(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {time64(TimeUnit::NANO), time64(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), date32()}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {date64(), duration(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {date32(), timestamp(TimeUnit::MICRO, tz)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::MICRO, tz), date64()}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), timestamp(TimeUnit::MICRO, tz)}; + ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, tz), duration(TimeUnit::MILLI)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); } TEST(TestDispatchBest, ReplaceTemporalTypes) { std::vector args; + std::string tz = "Pacific/Marquesas"; + TimeUnit::type ty; - args = {date32(), duration(TimeUnit::MILLI)}; - ReplaceTemporalTypes(CommonTemporal(args.data(), args.size()), &args); + args = {date32(), date32()}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND)); + + args = {date64(), time32(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI)); AssertTypeEqual(args[1].type, duration(TimeUnit::MILLI)); - args = {date64(), duration(TimeUnit::MICRO)}; - ReplaceTemporalTypes(CommonTemporal(args.data(), args.size()), &args); - AssertTypeEqual(args[0].type, timestamp(TimeUnit::MICRO)); - AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO)); + args = {duration(TimeUnit::SECOND), date64()}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, duration(TimeUnit::MILLI)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI)); + + args = {timestamp(TimeUnit::MICRO, tz), timestamp(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::NANO)); + + args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz)); + AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); + + args = {timestamp(TimeUnit::SECOND, tz), date64()}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI, tz)); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI)); + + args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; + ty = CommonTemporalResolution(args.data(), args.size()); + AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC")); + AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz)); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index c2afb5ffa6e..b27d494d864 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1692,7 +1692,7 @@ struct ArithmeticFunction : ScalarFunction { if (values->size() == 2) { ReplaceNullWithOtherType(values); - if (auto type = CommonTemporal(values->data(), values->size())) { + if (auto type = CommonTemporalResolution(values->data(), values->size())) { ReplaceTemporalTypes(type, values); } else if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values);