Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,31 @@ struct SubtractChecked {
}
};

struct SubtractDate32 {
static constexpr int64_t kSecondsInDay = 86400;

template <typename T, typename Arg0, typename Arg1>
static constexpr T Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return arrow::internal::SafeSignedSubtract(left, right) * kSecondsInDay;
}
};

struct SubtractCheckedDate32 {
static constexpr int64_t kSecondsInDay = 86400;

template <typename T, typename Arg0, typename Arg1>
static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) {
T result = 0;
if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) {
*st = Status::Invalid("overflow");
}
if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(result, kSecondsInDay, &result))) {
*st = Status::Invalid("overflow");
}
return result;
}
};

struct Multiply {
static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
Expand Down Expand Up @@ -1611,6 +1636,23 @@ Result<ValueDescr> ResolveDecimalDivisionOutput(KernelContext*,
});
}

Result<ValueDescr> ResolveTemporalOutput(KernelContext*,
const std::vector<ValueDescr>& args) {
DCHECK_EQ(args[0].type->id(), args[1].type->id());
auto left_type = checked_cast<const TimestampType*>(args[0].type.get());
auto right_type = checked_cast<const TimestampType*>(args[1].type.get());
DCHECK_EQ(left_type->unit(), left_type->unit());

if ((left_type->timezone() == "" || right_type->timezone() == "") &&
left_type->timezone() != right_type->timezone()) {
return Status::Invalid("Subtraction of zoned and non-zoned times is ambiguous. (",
left_type->timezone(), right_type->timezone(), ").");
}

auto type = duration(right_type->unit());
return ValueDescr(std::move(type), GetBroadcastShape(args));
}

template <typename Op>
void AddDecimalUnaryKernels(ScalarFunction* func) {
OutputType out_type(FirstType);
Expand Down Expand Up @@ -2426,7 +2468,9 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract->AddKernel({in_type, in_type},
OutputType::Resolver(ResolveTemporalOutput),
std::move(exec)));
}

// Add subtract(timestamp, duration) -> timestamp
Expand All @@ -2437,14 +2481,36 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
std::move(exec)));
}

// Add subtract(date32, date32) -> duration(TimeUnit::SECOND)
InputType in_type_date_32(date32());
auto exec_date_32 = ScalarBinaryEqualTypes<Int64Type, Date32Type, SubtractDate32>::Exec;
DCHECK_OK(subtract->AddKernel({in_type_date_32, in_type_date_32},
duration(TimeUnit::SECOND), std::move(exec_date_32)));

// Add subtract(date64, date64) -> duration(TimeUnit::MILLI)
InputType in_type_date_64(date64());
auto exec_date_64 = ScalarBinaryEqualTypes<Int64Type, Date64Type, Subtract>::Exec;
DCHECK_OK(subtract->AddKernel({in_type_date_64, in_type_date_64},
duration(TimeUnit::MILLI), std::move(exec_date_64)));

DCHECK_OK(registry->AddFunction(std::move(subtract)));

// ----------------------------------------------------------------------
auto subtract_checked = MakeArithmeticFunctionNotNull<SubtractChecked>(
"subtract_checked", &sub_checked_doc);
AddDecimalBinaryKernels<SubtractChecked>("subtract_checked", subtract_checked.get());

// Add subtract(timestamp, duration) -> timestamp
// Add subtract_checked(timestamp, timestamp) -> duration
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec =
ArithmeticExecFromOp<ScalarBinaryEqualTypes, SubtractChecked>(Type::TIMESTAMP);
DCHECK_OK(subtract_checked->AddKernel({in_type, in_type},
OutputType::Resolver(ResolveTemporalOutput),
std::move(exec)));
}

// Add subtract_checked(timestamp, duration) -> timestamp
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec =
Expand All @@ -2453,6 +2519,20 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
OutputType(FirstType), std::move(exec)));
}

// Add subtract_checked(date32, date32) -> duration(TimeUnit::SECOND)
auto exec_date_32_checked =
ScalarBinaryEqualTypes<Int64Type, Date32Type, SubtractCheckedDate32>::Exec;
DCHECK_OK(subtract_checked->AddKernel({in_type_date_32, in_type_date_32},
duration(TimeUnit::SECOND),
std::move(exec_date_32_checked)));

// Add subtract_checked(date64, date64) -> duration(TimeUnit::MILLI)
auto exec_date_64_checked =
ScalarBinaryEqualTypes<Int64Type, Date64Type, SubtractChecked>::Exec;
DCHECK_OK(subtract_checked->AddKernel({in_type_date_64, in_type_date_64},
duration(TimeUnit::MILLI),
std::move(exec_date_64_checked)));

DCHECK_OK(registry->AddFunction(std::move(subtract_checked)));

// ----------------------------------------------------------------------
Expand Down
Loading