diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 4747e63c159..d52df9a669c 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -112,24 +112,27 @@ std::shared_ptr SameTypeId(Type::type type_id) { return std::make_shared(type_id); } -class TimestampUnitMatcher : public TypeMatcher { +template +class TimeUnitMatcher : public TypeMatcher { + using ThisType = TimeUnitMatcher; + public: - explicit TimestampUnitMatcher(TimeUnit::type accepted_unit) + explicit TimeUnitMatcher(TimeUnit::type accepted_unit) : accepted_unit_(accepted_unit) {} bool Matches(const DataType& type) const override { - if (type.id() != Type::TIMESTAMP) { + if (type.id() != ArrowType::type_id) { return false; } - const auto& ts_type = checked_cast(type); - return ts_type.unit() == accepted_unit_; + const auto& time_type = checked_cast(type); + return time_type.unit() == accepted_unit_; } bool Equals(const TypeMatcher& other) const override { if (this == &other) { return true; } - auto casted = dynamic_cast(&other); + auto casted = dynamic_cast(&other); if (casted == nullptr) { return false; } @@ -138,7 +141,8 @@ class TimestampUnitMatcher : public TypeMatcher { std::string ToString() const override { std::stringstream ss; - ss << "timestamp(" << ::arrow::internal::ToString(accepted_unit_) << ")"; + ss << ArrowType::type_name() << "(" << ::arrow::internal::ToString(accepted_unit_) + << ")"; return ss.str(); } @@ -146,8 +150,25 @@ class TimestampUnitMatcher : public TypeMatcher { TimeUnit::type accepted_unit_; }; -std::shared_ptr TimestampUnit(TimeUnit::type unit) { - return std::make_shared(unit); +using DurationTypeUnitMatcher = TimeUnitMatcher; +using Time32TypeUnitMatcher = TimeUnitMatcher; +using Time64TypeUnitMatcher = TimeUnitMatcher; +using TimestampTypeUnitMatcher = TimeUnitMatcher; + +std::shared_ptr TimestampTypeUnit(TimeUnit::type unit) { + return std::make_shared(unit); +} + +std::shared_ptr Time32TypeUnit(TimeUnit::type 
unit) { + return std::make_shared(unit); +} + +std::shared_ptr Time64TypeUnit(TimeUnit::type unit) { + return std::make_shared(unit); +} + +std::shared_ptr DurationTypeUnit(TimeUnit::type unit) { + return std::make_shared(unit); } class IntegerMatcher : public TypeMatcher { diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index b51efe5a953..68875ccfd66 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -148,7 +148,10 @@ ARROW_EXPORT std::shared_ptr SameTypeId(Type::type type_id); /// \brief Match any TimestampType instance having the same unit, but the time /// zones can be different. -ARROW_EXPORT std::shared_ptr TimestampUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr TimestampTypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr Time32TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr Time64TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr DurationTypeUnit(TimeUnit::type unit); // \brief Match any integer type ARROW_EXPORT std::shared_ptr Integer(); diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index fbdaf124d81..2eb7fd11449 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -45,23 +45,27 @@ TEST(TypeMatcher, SameTypeId) { ASSERT_FALSE(matcher->Equals(*match::SameTypeId(Type::TIMESTAMP))); } -TEST(TypeMatcher, TimestampUnit) { - std::shared_ptr matcher = match::TimestampUnit(TimeUnit::MILLI); +TEST(TypeMatcher, TimestampTypeUnit) { + auto matcher = match::TimestampTypeUnit(TimeUnit::MILLI); + auto matcher2 = match::Time32TypeUnit(TimeUnit::MILLI); ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI))); ASSERT_TRUE(matcher->Matches(*timestamp(TimeUnit::MILLI, "utc"))); ASSERT_FALSE(matcher->Matches(*timestamp(TimeUnit::SECOND))); + ASSERT_FALSE(matcher->Matches(*time32(TimeUnit::MILLI))); + ASSERT_TRUE(matcher2->Matches(*time32(TimeUnit::MILLI))); // Check ToString 
representation - ASSERT_EQ("timestamp(s)", match::TimestampUnit(TimeUnit::SECOND)->ToString()); - ASSERT_EQ("timestamp(ms)", match::TimestampUnit(TimeUnit::MILLI)->ToString()); - ASSERT_EQ("timestamp(us)", match::TimestampUnit(TimeUnit::MICRO)->ToString()); - ASSERT_EQ("timestamp(ns)", match::TimestampUnit(TimeUnit::NANO)->ToString()); + ASSERT_EQ("timestamp(s)", match::TimestampTypeUnit(TimeUnit::SECOND)->ToString()); + ASSERT_EQ("timestamp(ms)", match::TimestampTypeUnit(TimeUnit::MILLI)->ToString()); + ASSERT_EQ("timestamp(us)", match::TimestampTypeUnit(TimeUnit::MICRO)->ToString()); + ASSERT_EQ("timestamp(ns)", match::TimestampTypeUnit(TimeUnit::NANO)->ToString()); // Equals implementation ASSERT_TRUE(matcher->Equals(*matcher)); - ASSERT_TRUE(matcher->Equals(*match::TimestampUnit(TimeUnit::MILLI))); - ASSERT_FALSE(matcher->Equals(*match::TimestampUnit(TimeUnit::MICRO))); + ASSERT_TRUE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MILLI))); + ASSERT_FALSE(matcher->Equals(*match::TimestampTypeUnit(TimeUnit::MICRO))); + ASSERT_FALSE(matcher->Equals(*match::Time32TypeUnit(TimeUnit::MILLI))); } // ---------------------------------------------------------------------- @@ -135,7 +139,7 @@ TEST(InputType, Constructors) { ASSERT_EQ("array[Type::DECIMAL]", ty2_array.ToString()); ASSERT_EQ("scalar[Type::DECIMAL]", ty2_scalar.ToString()); - InputType ty7(match::TimestampUnit(TimeUnit::MICRO)); + InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO)); ASSERT_EQ("any[timestamp(us)]", ty7.ToString()); } diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 427ad2b0c54..2606c978e39 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -31,11 +31,12 @@ void ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out) { ctx->SetStatus(Status::NotImplemented("This kernel is malformed")); } -void BinaryExecFlipped(KernelContext* ctx, 
ArrayKernelExec exec, const ExecBatch& batch, - Datum* out) { - ExecBatch flipped_batch = batch; - std::swap(flipped_batch.values[0], flipped_batch.values[1]); - exec(ctx, flipped_batch, out); +ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec) { + return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ExecBatch flipped_batch = batch; + std::swap(flipped_batch.values[0], flipped_batch.values[1]); + exec(ctx, flipped_batch, out); + }; } std::vector> g_signed_int_types; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index c547c807757..3941b67fba9 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -23,8 +23,10 @@ #include #include +#include "arrow/array/builder_binary.h" #include "arrow/array/data.h" #include "arrow/buffer.h" +#include "arrow/buffer_builder.h" #include "arrow/compute/exec.h" #include "arrow/compute/kernel.h" #include "arrow/datum.h" @@ -121,10 +123,26 @@ struct ArrayIterator> { template struct ArrayIterator> { - int64_t position = 0; - typename TypeTraits::ArrayType arr; - explicit ArrayIterator(const ArrayData& data) : arr(data.Copy()) {} - util::string_view operator()() { return arr.GetView(position++); } + using offset_type = typename Type::offset_type; + const ArrayData& arr; + const offset_type* offsets; + offset_type cur_offset; + const char* data; + int64_t position; + explicit ArrayIterator(const ArrayData& arr) + : arr(arr), + offsets(reinterpret_cast(arr.buffers[1]->data()) + + arr.offset), + cur_offset(offsets[0]), + data(reinterpret_cast(arr.buffers[2]->data())), + position(0) {} + + util::string_view operator()() { + offset_type next_offset = offsets[position++ + 1]; + auto result = util::string_view(data + cur_offset, next_offset - cur_offset); + cur_offset = next_offset; + return result; + } }; template @@ -132,7 +150,7 @@ struct UnboxScalar; template struct UnboxScalar> { - 
using ScalarType = typename TypeTraits::ScalarType; + using ScalarType = ::arrow::internal::PrimitiveScalar; static typename Type::c_type Unbox(const Datum& datum) { return datum.scalar_as().value; } @@ -229,8 +247,7 @@ Result FirstType(KernelContext*, const std::vector& desc void ExecFail(KernelContext* ctx, const ExecBatch& batch, Datum* out); -void BinaryExecFlipped(KernelContext* ctx, ArrayKernelExec exec, const ExecBatch& batch, - Datum* out); +ArrayKernelExec MakeFlippedBinaryExec(ArrayKernelExec exec); // ---------------------------------------------------------------------- // Helpers for iterating over common DataType instances for adding kernels to @@ -308,34 +325,6 @@ void SimpleBinary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { } } -// A ArrayKernelExec-creation template that iterates over primitive non-boolean -// inputs and writes into non-boolean primitive outputs. -// -// It may be possible to create a more generic template that can deal with any -// input writing to any output, but we will need to write benchmarks to -// investigate that on all compiler targets to ensure that the additional -// template abstractions do not incur performance overhead. This template -// provides a reference point for performance when there are no templates -// dealing with value iteration. 
-// -// TODO: Run benchmarks to determine if OutputAdapter is a zero-cost abstraction -template -void ScalarPrimitiveExecUnary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - using OUT = typename OutType::c_type; - using ARG0 = typename Arg0Type::c_type; - - if (batch[0].kind() == Datum::SCALAR) { - ctx->SetStatus(Status::NotImplemented("NYI")); - } else { - ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues(1); - auto arg0_data = batch[0].array()->GetValues(1); - for (int64_t i = 0; i < batch.length; ++i) { - *out_data++ = Op::template Call(ctx, *arg0_data++); - } - } -} - // OutputAdapter allows passing an inlineable lambda that provides a sequence // of output values to write into output memory. Boolean and primitive outputs // are currently implemented, and the validity bitmap is presumed to be handled @@ -485,6 +474,10 @@ struct ScalarUnaryNotNullStateful { struct ArrayExec> { static void Exec(const ThisType& functor, KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // NOTE: This code is not currently used by any kernels and has + // suboptimal performance because it's recomputing the validity bitmap + // that is already computed by the kernel execution layer. Consider + // writing a lower-level "output adapter" for base binary types. typename TypeTraits::BuilderType builder; VisitArrayValuesInline(*batch[0].array(), [&](util::optional v) { if (v.has_value()) { @@ -644,15 +637,41 @@ struct ScalarBinary { template using ScalarBinaryEqualTypes = ScalarBinary; +} // namespace codegen + // ---------------------------------------------------------------------- -// Dynamic kernel selectors. These functors allow a kernel implementation to be -// selected given a arrow::DataType instance. 
Using these functors triggers the -// corresponding template that generate's the kernel's Exec function to be -// instantiated +// BEGIN of kernel generator-dispatchers ("GD") +// +// These GD functions instantiate kernel functor templates and select one of +// the instantiated kernels dynamically based on the data type or Type::type id +// that is passed. This enables functions to be populated with kernels by +// looping over vectors of data types rather than using macros or other +// approaches. +// +// The kernel functor must be of the form: +// +// template <typename Type0, typename Type1, typename... Args> +// struct FUNCTOR { +// static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { +// // IMPLEMENTATION +// } +// }; +// +// When you pass FUNCTOR to a GD function, you must pass at least one static +// type along with the functor -- this is often the fixed return type of the +// functor. This Type0 argument is passed as the first argument to the functor +// during instantiation. The 2nd type passed to the functor is the DataType +// subclass corresponding to the type passed as argument (not template type) to +// the function. +// +// For example, GenerateNumeric<FUNCTOR, Type0>(int32()) will select a kernel +// instantiated like FUNCTOR<Type0, Int32Type>. Any additional variadic +// template arguments will be passed as additional template arguments to the +// kernel template. 
namespace detail { -// Convenience so we can pass DataType or Type::type into these kernel selectors +// Convenience so we can pass DataType or Type::type for the GD's struct GetTypeId { Type::type id; GetTypeId(const std::shared_ptr& type) // NOLINT implicit construction @@ -665,57 +684,9 @@ struct GetTypeId { } // namespace detail -// Generate a kernel given a functor of type -// -// struct OPERATOR_NAME { -// template -// static OUT Call(KernelContext*, ARG0 val) { -// // IMPLEMENTATION -// } -// }; -template -ArrayKernelExec NumericEqualTypesUnary(detail::GetTypeId get_id) { - switch (get_id.id) { - case Type::INT8: - return ScalarPrimitiveExecUnary; - case Type::UINT8: - return ScalarPrimitiveExecUnary; - case Type::INT16: - return ScalarPrimitiveExecUnary; - case Type::UINT16: - return ScalarPrimitiveExecUnary; - case Type::INT32: - return ScalarPrimitiveExecUnary; - case Type::UINT32: - return ScalarPrimitiveExecUnary; - case Type::INT64: - return ScalarPrimitiveExecUnary; - case Type::UINT64: - return ScalarPrimitiveExecUnary; - case Type::FLOAT: - return ScalarPrimitiveExecUnary; - case Type::DOUBLE: - return ScalarPrimitiveExecUnary; - default: - DCHECK(false); - return ExecFail; - } -} - -// Generate a kernel given a templated functor. This template effectively -// "curries" the first type argument. The functor must be of the form: -// -// template -// struct FUNCTOR { -// static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { -// // IMPLEMENTATION -// } -// }; -// -// This function will generate exec functions where Type1 is one of the numeric -// types +// GD for numeric types (integer and floating point) template