Skip to content
Closed
39 changes: 30 additions & 9 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,27 @@ std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) {
return std::make_shared<SameTypeIdMatcher>(type_id);
}

class TimestampUnitMatcher : public TypeMatcher {
template <typename ArrowType>
class TimeUnitMatcher : public TypeMatcher {
using ThisType = TimeUnitMatcher<ArrowType>;

public:
explicit TimestampUnitMatcher(TimeUnit::type accepted_unit)
explicit TimeUnitMatcher(TimeUnit::type accepted_unit)
: accepted_unit_(accepted_unit) {}

bool Matches(const DataType& type) const override {
if (type.id() != Type::TIMESTAMP) {
if (type.id() != ArrowType::type_id) {
return false;
}
const auto& ts_type = checked_cast<const TimestampType&>(type);
return ts_type.unit() == accepted_unit_;
const auto& time_type = checked_cast<const ArrowType&>(type);
return time_type.unit() == accepted_unit_;
}

bool Equals(const TypeMatcher& other) const override {
if (this == &other) {
return true;
}
auto casted = dynamic_cast<const TimestampUnitMatcher*>(&other);
auto casted = dynamic_cast<const ThisType*>(&other);
if (casted == nullptr) {
return false;
}
Expand All @@ -138,16 +141,34 @@ class TimestampUnitMatcher : public TypeMatcher {

std::string ToString() const override {
std::stringstream ss;
ss << "timestamp(" << ::arrow::internal::ToString(accepted_unit_) << ")";
ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_)
<< ")";
return ss.str();
}

private:
TimeUnit::type accepted_unit_;
};

std::shared_ptr<TypeMatcher> TimestampUnit(TimeUnit::type unit) {
return std::make_shared<TimestampUnitMatcher>(unit);
using DurationTypeUnitMatcher = TimeUnitMatcher<DurationType>;
using Time32TypeUnitMatcher = TimeUnitMatcher<Time32Type>;
using Time64TypeUnitMatcher = TimeUnitMatcher<Time64Type>;
using TimestampTypeUnitMatcher = TimeUnitMatcher<TimestampType>;

std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit) {
return std::make_shared<TimestampTypeUnitMatcher>(unit);
}

std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit) {
return std::make_shared<Time32TypeUnitMatcher>(unit);
}

std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit) {
return std::make_shared<Time64TypeUnitMatcher>(unit);
}

std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit) {
return std::make_shared<DurationTypeUnitMatcher>(unit);
}

class IntegerMatcher : public TypeMatcher {
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,10 @@ ARROW_EXPORT std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id);

/// \brief Match any TimestampType instance having the same unit, but the time
/// zones can be different.
ARROW_EXPORT std::shared_ptr<TypeMatcher> TimestampUnit(TimeUnit::type unit);
ARROW_EXPORT std::shared_ptr<TypeMatcher> TimestampTypeUnit(TimeUnit::type unit);
ARROW_EXPORT std::shared_ptr<TypeMatcher> Time32TypeUnit(TimeUnit::type unit);
ARROW_EXPORT std::shared_ptr<TypeMatcher> Time64TypeUnit(TimeUnit::type unit);
ARROW_EXPORT std::shared_ptr<TypeMatcher> DurationTypeUnit(TimeUnit::type unit);

// \brief Match any integer type
ARROW_EXPORT std::shared_ptr<TypeMatcher> Integer();
Expand Down
22 changes: 13 additions & 9 deletions cpp/src/arrow/compute/kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,27 @@ TEST(TypeMatcher, SameTypeId) {
ASSERT_FALSE(matcher->Equals(*match::SameTypeId(Type::TIMESTAMP)));
}

TEST(TypeMatcher, TimestampUnit) {
std::shared_ptr<TypeMatcher> matcher = match::TimestampUnit(TimeUnit::MILLI);
TEST(TypeMatcher, TimestampTypeUnit) {
auto matcher = match::TimestampTypeUnit(TimeUnit::MILLI);
auto matcher2 = match::Time32TypeUnit(TimeUnit::MILLI);

ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI)));
ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc")));
ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND)));
ASSERT_FALSE(matcher->Matches(*time32(TimeUnit::MILLI)));
ASSERT_TRUE(matcher2->Matches(*time32(TimeUnit::MILLI)));

// Check ToString representation
ASSERT_EQ("timestamp(s)", match::TimestampUnit(TimeUnit::SECOND)->ToString());
ASSERT_EQ("timestamp(ms)", match::TimestampUnit(TimeUnit::MILLI)->ToString());
ASSERT_EQ("timestamp(us)", match::TimestampUnit(TimeUnit::MICRO)->ToString());
ASSERT_EQ("timestamp(ns)", match::TimestampUnit(TimeUnit::NANO)->ToString());
ASSERT_EQ("timestamp(s)", match::TimestampTypeUnit(TimeUnit::SECOND)->ToString());
ASSERT_EQ("timestamp(ms)", match::TimestampTypeUnit(TimeUnit::MILLI)->ToString());
ASSERT_EQ("timestamp(us)", match::TimestampTypeUnit(TimeUnit::MICRO)->ToString());
ASSERT_EQ("timestamp(ns)", match::TimestampTypeUnit(TimeUnit::NANO)->ToString());

// Equals implementation
ASSERT_TRUE(matcher->Equals(*matcher));
ASSERT_TRUE(matcher->Equals(*match::TimestampUnit(TimeUnit::MILLI)));
ASSERT_FALSE(matcher->Equals(*match::TimestampUnit(TimeUnit::MICRO)));
ASSERT_TRUE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MILLI)));
ASSERT_FALSE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MICRO)));
ASSERT_FALSE(matcher->Equals(*match::Time32TypeUnit(TimeUnit::MILLI)));
}

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -135,7 +139,7 @@ TEST(InputType, Constructors) {
ASSERT_EQ("array[Type::DECIMAL]", ty2_array.ToString());
ASSERT_EQ("scalar[Type::DECIMAL]", ty2_scalar.ToString());

InputType ty7(match::TimestampUnit(TimeUnit::MICRO));
InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO));
ASSERT_EQ("any[timestamp(us)]", ty7.ToString());
}

Expand Down
11 changes: 6 additions & 5 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ void ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ctx->SetStatus(Status::NotImplemented("This kernel is malformed"));
}

void BinaryExecFlipped(KernelContext* ctx, ArrayKernelExec exec, const ExecBatch& batch,
Datum* out) {
ExecBatch flipped_batch = batch;
std::swap(flipped_batch.values[0], flipped_batch.values[1]);
exec(ctx, flipped_batch, out);
ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) {
return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ExecBatch flipped_batch = batch;
std::swap(flipped_batch.values[0], flipped_batch.values[1]);
exec(ctx, flipped_batch, out);
};
}

std::vector<std::shared_ptr<DataType>> g_signed_int_types;
Expand Down
Loading