Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,33 @@ void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
}
}

/// \brief Re-express every temporal argument in `descrs` at resolution `unit`.
///
/// The `type` member of each descriptor is rewritten in place:
///   - timestamp              -> timestamp(unit), keeping the original timezone
///   - time32/time64/duration -> duration(unit)
///   - date32/date64          -> timestamp(unit), without a timezone
/// Descriptors of any other type are left untouched.
///
/// \param unit   the resolution to apply to all temporal descriptors
/// \param descrs the argument descriptors to update
void ReplaceTemporalTypes(const TimeUnit::type unit, std::vector<ValueDescr>* descrs) {
  for (ValueDescr& descr : *descrs) {
    switch (descr.type->id()) {
      case Type::TIMESTAMP: {
        // Only the unit changes; the timezone of the input is carried over.
        const auto& ts_type = checked_cast<const TimestampType&>(*descr.type);
        descr.type = timestamp(unit, ts_type.timezone());
        break;
      }
      case Type::TIME32:
      case Type::TIME64:
      case Type::DURATION:
        descr.type = duration(unit);
        break;
      case Type::DATE32:
      case Type::DATE64:
        descr.type = timestamp(unit);
        break;
      default:
        // Not a temporal type: nothing to replace.
        break;
    }
  }
}

void ReplaceTypes(const std::shared_ptr<DataType>& type,
std::vector<ValueDescr>* descrs) {
ReplaceTypes(type, descrs->data(), descrs->size());
Expand Down Expand Up @@ -180,6 +207,47 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}

/// \brief Find the finest time unit used by any temporal argument.
///
/// Starts from TimeUnit::SECOND and raises the result to the largest unit value
/// seen among timestamp, duration, time32 and time64 arguments. date64 counts as
/// TimeUnit::MILLI; date32 and all non-temporal types contribute nothing.
///
/// \param begin pointer to the first argument descriptor
/// \param count number of descriptors to inspect
/// \return the finest unit found, or TimeUnit::SECOND if nothing raised it
TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count) {
  TimeUnit::type result = TimeUnit::SECOND;
  // Keep whichever of the current result and `candidate` compares greater.
  const auto refine = [&result](TimeUnit::type candidate) {
    if (result < candidate) {
      result = candidate;
    }
  };

  for (size_t i = 0; i < count; ++i) {
    const auto& type = *begin[i].type;
    switch (type.id()) {
      case Type::DATE64:
        refine(TimeUnit::MILLI);
        break;
      case Type::TIMESTAMP:
        refine(checked_cast<const TimestampType&>(type).unit());
        break;
      case Type::DURATION:
        refine(checked_cast<const DurationType&>(type).unit());
        break;
      case Type::TIME32:
        refine(checked_cast<const Time32Type&>(type).unit());
        break;
      case Type::TIME64:
        refine(checked_cast<const Time64Type&>(type).unit());
        break;
      case Type::DATE32:
        // Date32's unit is days, but the coarsest we have is seconds
      default:
        break;
    }
  }
  return result;
}

std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,9 @@ void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* des
ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr<DataType>&, ValueDescr* descrs, size_t count);

// Rewrite, in place, the type of every temporal descriptor in `descrs` so that it
// uses `unit`: timestamps become timestamp(unit) with their timezone preserved;
// time32, time64 and duration become duration(unit); date32 and date64 become
// timestamp(unit) without a timezone. Descriptors of other types are unchanged.
ARROW_EXPORT
void ReplaceTemporalTypes(TimeUnit::type unit, std::vector<ValueDescr>* descrs);

ARROW_EXPORT
std::shared_ptr<DataType> CommonNumeric(const std::vector<ValueDescr>& descrs);

Expand All @@ -1390,6 +1393,9 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count);
ARROW_EXPORT
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count);

// Return the finest TimeUnit found among the `count` descriptors starting at
// `begin`, considering timestamp, duration, time32 and time64 units. date64 is
// treated as TimeUnit::MILLI; date32 and non-temporal types do not affect the
// result. TimeUnit::SECOND is returned when nothing finer is present, so a
// SECOND result alone does not tell whether any temporal argument was seen.
ARROW_EXPORT
TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count);

ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count);

Expand Down
89 changes: 89 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,95 @@ TEST(TestDispatchBest, CommonTemporal) {
ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
}

// Checks that CommonTemporalResolution picks the finest unit across mixed
// date/time/duration/timestamp arguments, regardless of timezone.
TEST(TestDispatchBest, CommonTemporalResolution) {
  const std::string tz = "Pacific/Marquesas";

  // Shorthand: common resolution of an inline list of argument types.
  const auto resolution_of = [](const std::vector<ValueDescr>& descrs) {
    return CommonTemporalResolution(descrs.data(), descrs.size());
  };

  // Dates and times of day
  ASSERT_EQ(TimeUnit::SECOND, resolution_of({date32(), date32()}));
  ASSERT_EQ(TimeUnit::MILLI, resolution_of({date32(), date64()}));
  ASSERT_EQ(TimeUnit::MILLI, resolution_of({time32(TimeUnit::MILLI), date32()}));
  ASSERT_EQ(TimeUnit::MILLI,
            resolution_of({time32(TimeUnit::MILLI), time32(TimeUnit::SECOND)}));
  ASSERT_EQ(TimeUnit::MICRO,
            resolution_of({time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)}));
  ASSERT_EQ(TimeUnit::NANO,
            resolution_of({time64(TimeUnit::NANO), time64(TimeUnit::MICRO)}));

  // Durations combined with durations, dates and times
  ASSERT_EQ(TimeUnit::MICRO,
            resolution_of({duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)}));
  ASSERT_EQ(TimeUnit::MILLI, resolution_of({duration(TimeUnit::MILLI), date32()}));
  ASSERT_EQ(TimeUnit::MILLI, resolution_of({date64(), duration(TimeUnit::SECOND)}));
  ASSERT_EQ(TimeUnit::SECOND,
            resolution_of({duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)}));
  ASSERT_EQ(TimeUnit::NANO,
            resolution_of({time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}));

  // Timestamps, with and without a timezone
  ASSERT_EQ(TimeUnit::MICRO, resolution_of({timestamp(TimeUnit::SECOND, tz),
                                            timestamp(TimeUnit::MICRO)}));
  ASSERT_EQ(TimeUnit::MICRO, resolution_of({date32(), timestamp(TimeUnit::MICRO, tz)}));
  ASSERT_EQ(TimeUnit::MICRO, resolution_of({timestamp(TimeUnit::MICRO, tz), date64()}));
  ASSERT_EQ(TimeUnit::MICRO, resolution_of({time32(TimeUnit::MILLI),
                                            timestamp(TimeUnit::MICRO, tz)}));
  ASSERT_EQ(TimeUnit::NANO, resolution_of({timestamp(TimeUnit::MICRO, tz),
                                           time64(TimeUnit::NANO)}));
  ASSERT_EQ(TimeUnit::MILLI, resolution_of({timestamp(TimeUnit::SECOND, tz),
                                            duration(TimeUnit::MILLI)}));
  ASSERT_EQ(TimeUnit::SECOND, resolution_of({timestamp(TimeUnit::SECOND, "UTC"),
                                             timestamp(TimeUnit::SECOND, tz)}));
}

// Checks that ReplaceTemporalTypes, fed with the resolution computed by
// CommonTemporalResolution, rewrites each argument to the expected type:
// dates -> timestamp, times/durations -> duration, timestamps keep their timezone.
TEST(TestDispatchBest, ReplaceTemporalTypes) {
  std::vector<ValueDescr> args;
  std::string tz = "Pacific/Marquesas";
  TimeUnit::type ty;

  // Dates become timezone-less timestamps.
  args = {date32(), date32()};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND));
  AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND));

  // A time of day becomes a duration at the common (finer) resolution.
  args = {date64(), time32(TimeUnit::SECOND)};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI));
  AssertTypeEqual(args[1].type, duration(TimeUnit::MILLI));

  args = {duration(TimeUnit::SECOND), date64()};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, duration(TimeUnit::MILLI));
  AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));

  // Each timestamp keeps its own timezone (or lack of one); only the unit changes.
  args = {timestamp(TimeUnit::MICRO, tz), timestamp(TimeUnit::NANO)};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
  AssertTypeEqual(args[1].type, timestamp(TimeUnit::NANO));

  args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
  AssertTypeEqual(args[1].type, duration(TimeUnit::NANO));

  args = {timestamp(TimeUnit::SECOND, tz), date64()};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI, tz));
  AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));

  // Same unit, different timezones: the replacement must leave both types as is.
  // ReplaceTemporalTypes has to be invoked here, otherwise the assertions below
  // would only compare the untouched inputs and verify nothing.
  args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
  ty = CommonTemporalResolution(args.data(), args.size());
  ReplaceTemporalTypes(ty, &args);
  AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC"));
  AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz));
}

} // namespace internal
} // namespace compute
} // namespace arrow
25 changes: 22 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ using internal::SubtractWithOverflow;
namespace compute {
namespace internal {

using applicator::ScalarBinary;
using applicator::ScalarBinaryEqualTypes;
using applicator::ScalarBinaryNotNullEqualTypes;
using applicator::ScalarUnary;
Expand Down Expand Up @@ -180,7 +181,6 @@ struct Subtract {
// Unchecked subtraction overload, selected through
// enable_if_signed_integer_value<T>. The KernelContext and Status parameters are
// unnamed and therefore unused: this overload never reports an error.
// Returns the result of SafeSignedSubtract(left, right).
template <typename T, typename Arg0, typename Arg1>
static constexpr enable_if_signed_integer_value<T> Call(KernelContext*, Arg0 left,
Arg1 right, Status*) {
// Both operand types must be exactly the output value type T.
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
// NOTE(review): SafeSignedSubtract presumably wraps on overflow rather than
// triggering signed-overflow undefined behavior -- confirm against its definition.
return arrow::internal::SafeSignedSubtract(left, right);
}

Expand All @@ -194,7 +194,6 @@ struct SubtractChecked {
template <typename T, typename Arg0, typename Arg1>
static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
Status* st) {
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
*st = Status::Invalid("overflow");
Expand Down Expand Up @@ -1693,7 +1692,9 @@ struct ArithmeticFunction : ScalarFunction {
if (values->size() == 2) {
ReplaceNullWithOtherType(values);

if (auto type = CommonNumeric(*values)) {
if (auto type = CommonTemporalResolution(values->data(), values->size())) {
ReplaceTemporalTypes(type, values);
} else if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}
}
Expand Down Expand Up @@ -2428,12 +2429,30 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
}

// Add subtract(timestamp, duration) -> timestamp
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ScalarBinary<TimestampType, DurationType, TimestampType, Subtract>::Exec;
DCHECK_OK(subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType),
std::move(exec)));
}

DCHECK_OK(registry->AddFunction(std::move(subtract)));

// ----------------------------------------------------------------------
auto subtract_checked = MakeArithmeticFunctionNotNull<SubtractChecked>(
"subtract_checked", &sub_checked_doc);
AddDecimalBinaryKernels<SubtractChecked>("subtract_checked", subtract_checked.get());

// Add subtract(timestamp, duration) -> timestamp
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec =
ScalarBinary<TimestampType, DurationType, TimestampType, SubtractChecked>::Exec;
DCHECK_OK(subtract_checked->AddKernel({in_type, duration(unit)},
OutputType(FirstType), std::move(exec)));
}

DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));

// ----------------------------------------------------------------------
Expand Down
54 changes: 54 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,60 @@ TEST_F(ScalarTemporalTest, TestTemporalDifference) {
}
}

// "subtract" applied to (date32 | date64, duration) must produce a timestamp
// whose unit matches the duration's unit (checked for milli- and microseconds).
TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDuration) {
  const std::string func_name = "subtract";
  // The same per-element offsets, spelled out at two resolutions.
  const std::string offsets_ms_json =
      "[31535941000, -31706603000, 2674840000, -2604800000, 82495000,"
      "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000,"
      "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]";
  const std::string offsets_us_json =
      "[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000,"
      "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, "
      "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, "
      "5180277000000, null]";
  const auto days32 = ArrayFromJSON(date32(), date32s2);
  const auto days64 = ArrayFromJSON(date64(), date64s2);

  {
    // Millisecond durations -> millisecond timestamps
    const auto offsets = ArrayFromJSON(duration(TimeUnit::MILLI), offsets_ms_json);
    const auto expected =
        ArrayFromJSON(timestamp(TimeUnit::MILLI), times_seconds_precision);
    for (const auto& days : {days32, days64}) {
      CheckScalarBinary(func_name, days, offsets, expected);
    }
  }

  {
    // Microsecond durations -> microsecond timestamps
    const auto offsets = ArrayFromJSON(duration(TimeUnit::MICRO), offsets_us_json);
    const auto expected =
        ArrayFromJSON(timestamp(TimeUnit::MICRO), times_seconds_precision);
    for (const auto& days : {days32, days64}) {
      CheckScalarBinary(func_name, days, offsets, expected);
    }
  }
}

// "subtract_checked" applied to (date32 | date64, duration) must produce a
// timestamp whose unit matches the duration's unit (milli- and microseconds).
TEST_F(ScalarTemporalTest, TestTemporalSubtractDateAndDurationChecked) {
  const std::string func_name = "subtract_checked";
  // The same per-element offsets, spelled out at two resolutions.
  const std::string offsets_ms_json =
      "[31535941000, -31706603000, 2674840000, -2604800000, 82495000,"
      "-180610000, -11715000, -15620000, -19525000, -23430000, -27335000,"
      "-31240000, -35145000, -86400000, -26352000000, 5180277000, null]";
  const std::string offsets_us_json =
      "[31535941000000, -31706603000000, 2674840000000, -2604800000000, 82495000000,"
      "-180610000000, -11715000000, -15620000000, -19525000000, -23430000000, "
      "-27335000000, -31240000000, -35145000000, -86400000000, -26352000000000, "
      "5180277000000, null]";
  const auto days32 = ArrayFromJSON(date32(), date32s2);
  const auto days64 = ArrayFromJSON(date64(), date64s2);

  // Runs the kernel on both date widths against one duration/expected pair.
  const auto check_both_date_types = [&](const TimeUnit::type unit,
                                         const std::string& offsets_json) {
    const auto offsets = ArrayFromJSON(duration(unit), offsets_json);
    const auto expected = ArrayFromJSON(timestamp(unit), times_seconds_precision);
    CheckScalarBinary(func_name, days32, offsets, expected);
    CheckScalarBinary(func_name, days64, offsets, expected);
  };

  check_both_date_types(TimeUnit::MILLI, offsets_ms_json);
  check_both_date_types(TimeUnit::MICRO, offsets_us_json);
}

TEST_F(ScalarTemporalTest, TestTemporalDifferenceWeeks) {
auto raw_days = ArrayFromJSON(timestamp(TimeUnit::SECOND), R"([
"2021-08-09", "2021-08-10", "2021-08-11", "2021-08-12", "2021-08-13", "2021-08-14", "2021-08-15",
Expand Down
Loading