diff --git a/cpp/src/arrow/compute/kernels/gather_internal.h b/cpp/src/arrow/compute/kernels/gather_internal.h index 4c161533a72..dfc893a4da2 100644 --- a/cpp/src/arrow/compute/kernels/gather_internal.h +++ b/cpp/src/arrow/compute/kernels/gather_internal.h @@ -20,8 +20,14 @@ #include #include #include +#include +#include "arrow/array/array_base.h" #include "arrow/array/data.h" +#include "arrow/chunk_resolver.h" +#include "arrow/chunked_array.h" +#include "arrow/type_fwd.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_block_counter.h" #include "arrow/util/bit_run_reader.h" #include "arrow/util/bit_util.h" @@ -52,6 +58,15 @@ class GatherBaseCRTP { ARROW_DEFAULT_MOVE_AND_ASSIGN(GatherBaseCRTP); protected: + template + bool IsSrcValid(const ArraySpan& src_validity, const IndexCType* idx, + int64_t position) const { + // Translate position into index on the source + const int64_t index = idx[position]; + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + return src_validity.IsValid(index); + } + ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) { auto* self = static_cast(this); for (int64_t position = 0; position < idx_length; position++) { @@ -76,8 +91,12 @@ class GatherBaseCRTP { // doesn't have to be called for resulting null positions. A position is // considered null if either the index or the source value is null at that // position. - template - ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity, + // + // ValiditySpan is any class that `GatherImpl::IsSrcValid(src_validity, idx, position)` + // can be called with. + template + ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ValiditySpan& src_validity, int64_t idx_length, const IndexCType* idx, const ArraySpan& idx_validity, uint8_t* out_is_valid) { @@ -116,12 +135,11 @@ class GatherBaseCRTP { position += block.length; } } else { - // Source values may be null, so we must do random access into src_validity + // Source values may be null, so we must do random access with IsSrcValid() if (block.popcount == block.length) { // Faster path: indices are not null but source values may be for (int64_t i = 0; i < block.length; ++i) { - ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); - if (src_validity.IsValid(idx[position])) { + if (self->IsSrcValid(src_validity, idx, position)) { // value is not null self->WriteValue(position); bit_util::SetBit(out_is_valid, position); @@ -136,9 +154,9 @@ class GatherBaseCRTP { // random access in general we have to check the value nullness one by // one. 
for (int64_t i = 0; i < block.length; ++i) { - ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); - if (idx_validity.IsValid(position) && src_validity.IsValid(idx[position])) { + if (idx_validity.IsValid(position) && + self->IsSrcValid(src_validity, idx, position)) { // index is not null && value is not null self->WriteValue(position); bit_util::SetBit(out_is_valid, position); @@ -303,4 +321,139 @@ class Gather } }; +template +struct ChunkedValiditySpan { + const ChunkedArray& chunks_validity; + const TypedChunkLocation* chunk_location_vec; + const bool may_have_nulls; + + ChunkedValiditySpan(const ChunkedArray& chunks_validity, + const TypedChunkLocation* chunk_location_vec) + : chunks_validity(chunks_validity), + chunk_location_vec(chunk_location_vec), + may_have_nulls(chunks_validity.null_count() > 0) {} + + bool MayHaveNulls() const { return may_have_nulls; } + + bool IsSrcValid(const IndexCType* idx, int64_t position) const { + // idx is unused because all the indices have been pre-resolved into + // `chunk_location_vec` by ChunkResolver::ResolveMany. + ARROW_UNUSED(idx); + auto loc = chunk_location_vec[position]; + return chunks_validity.chunk(static_cast(loc.chunk_index)) + ->IsValid(loc.index_in_chunk); + } +}; + +template +class GatherFromChunks + : public GatherBaseCRTP< + GatherFromChunks> { + private: + static_assert(!kWithFactor || kValueWidthInBits == 8, + "kWithFactor is only supported for kValueWidthInBits == 8"); + static_assert(kValueWidthInBits == 1 || kValueWidthInBits % 8 == 0); + // kValueWidth should not be used if kValueWidthInBits == 1. + static constexpr int kValueWidth = kValueWidthInBits / 8; + + // src_residual_bit_offsets_[i] is used to store the bit offset of the first byte (0-7) + // in src_chunks_[i] iff kValueWidthInBits == 1. + const int* src_residual_bit_offsets_ = NULLPTR; + // Pre-computed pointers to the start of the values in each chunk. + const uint8_t* const* src_chunks_; + // Number indices resolved in chunk_location_vec_. 
+ const int64_t idx_length_; + const TypedChunkLocation* chunk_location_vec_; + + uint8_t* out_; + int64_t factor_; + + public: + void WriteValue(int64_t position) { + auto loc = chunk_location_vec_[position]; + auto* chunk = src_chunks_[loc.chunk_index]; + if constexpr (kValueWidthInBits == 1) { + auto src_offset = src_residual_bit_offsets_[loc.chunk_index]; + bit_util::SetBitTo(out_, position, + bit_util::GetBit(chunk, src_offset + loc.index_in_chunk)); + } else if constexpr (kWithFactor) { + const int64_t scaled_factor = kValueWidth * factor_; + memcpy(out_ + position * scaled_factor, chunk + loc.index_in_chunk * scaled_factor, + scaled_factor); + } else { + memcpy(out_ + position * kValueWidth, chunk + loc.index_in_chunk * kValueWidth, + kValueWidth); + } + } + + void WriteZero(int64_t position) { + if constexpr (kValueWidthInBits == 1) { + bit_util::ClearBit(out_, position); + } else if constexpr (kWithFactor) { + const int64_t scaled_factor = kValueWidth * factor_; + memset(out_ + position * scaled_factor, 0, scaled_factor); + } else { + memset(out_ + position * kValueWidth, 0, kValueWidth); + } + } + + void WriteZeroSegment(int64_t position, int64_t block_length) { + if constexpr (kValueWidthInBits == 1) { + bit_util::SetBitsTo(out_, position, block_length, false); + } else if constexpr (kWithFactor) { + const int64_t scaled_factor = kValueWidth * factor_; + memset(out_ + position * scaled_factor, 0, block_length * scaled_factor); + } else { + memset(out_ + position * kValueWidth, 0, block_length * kValueWidth); + } + } + + bool IsSrcValid(const ChunkedValiditySpan& src_validity, + const IndexCType* idx, int64_t position) const { + return src_validity.IsSrcValid(idx, position); + } + + public: + GatherFromChunks(const int* src_residual_bit_offsets, const uint8_t* const* src_chunks, + const int64_t idx_length, + const TypedChunkLocation* chunk_location_vec, uint8_t* out, + int64_t factor = 1) + : src_residual_bit_offsets_(src_residual_bit_offsets), + src_chunks_(src_chunks), + idx_length_(idx_length), + chunk_location_vec_(chunk_location_vec), + out_(out), + factor_(factor) { + assert(src_chunks && chunk_location_vec_ && out); + if constexpr (kValueWidthInBits == 1) { + assert(src_residual_bit_offsets); + } + assert((kWithFactor || factor == 1) && + "When kWithFactor is false, the factor is assumed to be 1 at compile time"); + } + + ARROW_FORCE_INLINE int64_t Execute() { return this->ExecuteNoNulls(idx_length_); } + + /// \pre If kOutputIsZeroInitialized, then this->out_ has to be zero initialized. + /// \pre Bits in out_is_valid have to always be zero initialized. + /// \post The bits for the valid elements (and only those) are set in out_is_valid. + /// \post If !kOutputIsZeroInitialized, then positions in this->_out containing null + /// elements have 0s written to them. This might be less efficient than + /// zero-initializing first and calling this->Execute() afterwards. + /// \return The number of valid elements in out. + template + ARROW_FORCE_INLINE int64_t Execute(const ChunkedArray& src_validity, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + assert(idx_length_ == idx_validity.length); + assert(out_is_valid); + assert(idx_validity.type->byte_width() == sizeof(IndexCType)); + ChunkedValiditySpan src_validity_span{src_validity, chunk_location_vec_}; + assert(src_validity_span.MayHaveNulls() || idx_validity.MayHaveNulls()); + // idx=NULLPTR because when it's passed to IsSrcValid() defined above, it's not used. 
+ return this->template ExecuteWithNulls( + src_validity_span, idx_length_, /*idx=*/NULLPTR, idx_validity, out_is_valid); + } +}; + } // namespace arrow::internal diff --git a/cpp/src/arrow/compute/kernels/vector_selection.cc b/cpp/src/arrow/compute/kernels/vector_selection.cc index 6c6f1b36b84..ddc96485239 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection.cc @@ -294,6 +294,8 @@ std::shared_ptr MakeIndicesNonZeroFunction(std::string name, VectorKernel kernel; kernel.null_handling = NullHandling::OUTPUT_NOT_NULL; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + // "array_take" ensures that the output will be be chunked when at least one + // input is chunked, so we need to set this to false. kernel.output_chunked = false; kernel.exec = IndicesNonZeroExec; kernel.exec_chunked = IndicesNonZeroExecChunked; @@ -339,6 +341,7 @@ void RegisterVectorSelection(FunctionRegistry* registry) { VectorKernel take_base; take_base.init = TakeState::Init; take_base.can_execute_chunkwise = false; + take_base.output_chunked = false; RegisterSelectionFunction("array_take", array_take_doc, take_base, std::move(take_kernels), GetDefaultTakeOptions(), registry); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc index 194c3591337..1cecbd3f2fd 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc @@ -895,18 +895,23 @@ Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult } // Transform filter to selection indices and then use Take. -Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx, +Status FilterWithTakeExec(TakeKernelExec take_aaa_exec, KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - std::shared_ptr indices; + std::shared_ptr indices_data; RETURN_NOT_OK(GetTakeIndices(batch[1].array, FilterState::Get(ctx).null_selection_behavior, ctx->memory_pool()) - .Value(&indices)); + .Value(&indices_data)); + KernelContext take_ctx(*ctx); TakeState state{TakeOptions::NoBoundsCheck()}; take_ctx.SetState(&state); - ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length); - return take_exec(&take_ctx, take_batch, out); + + ValuesSpan values(batch[0].array); + std::shared_ptr out_data = out->array_data(); + RETURN_NOT_OK(take_aaa_exec(&take_ctx, values, *indices_data, &out_data)); + out->value = std::move(out_data); + return Status::OK(); } // Due to the special treatment with their Take kernels, we filter Struct and SparseUnion diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 7fe8d9b8866..acc6c5a2fc4 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "arrow/array/array_binary.h" @@ -60,6 +61,7 @@ void RegisterSelectionFunction(const std::string& name, FunctionDoc doc, {std::move(kernel_data.value_type), std::move(kernel_data.selection_type)}, OutputType(FirstType)); base_kernel.exec = kernel_data.exec; + base_kernel.exec_chunked = kernel_data.chunked_exec; DCHECK_OK(func->AddKernel(base_kernel)); } kernels.clear(); @@ -192,20 +194,34 @@ struct Selection { }; KernelContext* ctx; - const ArraySpan& values; - const ArraySpan& selection; + 
const ArraySpan values; + const ArraySpan selection; int64_t output_length; ArrayData* out; TypedBufferBuilder validity_builder; - Selection(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, - ExecResult* out) + Selection(KernelContext* ctx, ArraySpan values, ArraySpan selection, + int64_t output_length, ArrayData* out) : ctx(ctx), - values(batch[0].array), - selection(batch[1].array), + values(std::move(values)), + selection(std::move(selection)), output_length(output_length), - out(out->array_data().get()), - validity_builder(ctx->memory_pool()) {} + out(out), + validity_builder(ctx->memory_pool()) { + // If the selection is an array of indices, the output length should + // match the number of indices in the selection array. + DCHECK(!is_integer(selection.type->id()) || output_length == selection.length); + } + + Selection(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, + ExecResult* out) + : Selection(ctx, batch[0].array, batch[1].array, output_length, + out->array_data().get()) {} + + Selection(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices, + std::shared_ptr* out) + : Selection(ctx, values.array(), indices, /*output_length=*/indices.length, + out->get()) {} virtual ~Selection() = default; @@ -483,9 +499,9 @@ struct VarBinarySelectionImpl : public Selection, T static constexpr int64_t kOffsetLimit = std::numeric_limits::max() - 1; - VarBinarySelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, - ExecResult* out) - : Base(ctx, batch, output_length, out), + template + explicit VarBinarySelectionImpl(KernelContext* ctx, Args... args) + : Base(ctx, std::forward(args)...), offset_builder(ctx->memory_pool()), data_builder(ctx->memory_pool()) {} @@ -557,9 +573,9 @@ struct ListSelectionImpl : public Selection, Type> { TypedBufferBuilder offset_builder; typename TypeTraits::OffsetBuilderType child_index_builder; - ListSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, - ExecResult* out) - : Base(ctx, batch, output_length, out), + template + explicit ListSelectionImpl(KernelContext* ctx, Args... args) + : Base(ctx, std::forward(args)...), offset_builder(ctx->memory_pool()), child_index_builder(ctx->memory_pool()) {} @@ -622,9 +638,9 @@ struct ListViewSelectionImpl : public Selection, Typ TypedBufferBuilder offsets_builder; TypedBufferBuilder sizes_builder; - ListViewSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, - ExecResult* out) - : Base(ctx, batch, output_length, out), + template + explicit ListViewSelectionImpl(KernelContext* ctx, Args... args) + : Base(ctx, std::forward(args)...), offsets_builder(ctx->memory_pool()), sizes_builder(ctx->memory_pool()) {} @@ -679,9 +695,9 @@ struct DenseUnionSelectionImpl std::vector type_codes_; std::vector child_indices_builders_; - DenseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, - int64_t output_length, ExecResult* out) - : Base(ctx, batch, output_length, out), + template + explicit DenseUnionSelectionImpl(KernelContext* ctx, Args... 
args) + : Base(ctx, std::forward(args)...), value_offset_buffer_builder_(ctx->memory_pool()), child_id_buffer_builder_(ctx->memory_pool()), type_codes_(checked_cast(*this->values.type).type_codes()), @@ -760,9 +776,9 @@ struct SparseUnionSelectionImpl TypedBufferBuilder child_id_buffer_builder_; const int8_t type_code_for_null_; - SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, - int64_t output_length, ExecResult* out) - : Base(ctx, batch, output_length, out), + template + explicit SparseUnionSelectionImpl(KernelContext* ctx, Args... args) + : Base(ctx, std::forward(args)...), child_id_buffer_builder_(ctx->memory_pool()), type_code_for_null_( checked_cast(*this->values.type).type_codes()[0]) {} @@ -811,9 +827,9 @@ struct FSLSelectionImpl : public Selection using Base = Selection; LIFT_BASE_MEMBERS(); - FSLSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, - ExecResult* out) - : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {} + template + explicit FSLSelectionImpl(KernelContext* ctx, Args... args) + : Base(ctx, std::forward(args)...), child_index_builder(ctx->memory_pool()) {} template Status GenerateOutput() { @@ -952,69 +968,80 @@ Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) namespace { -template -Status TakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { +template +Status TakeAAAExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices, + std::shared_ptr* out) { + DCHECK(!values.is_chunked()) + << "TakeAAAExec kernels can't be called with chunked array values"; if (TakeState::Get(ctx).boundscheck) { - RETURN_NOT_OK(CheckIndexBounds(batch[1].array, batch[0].length())); + RETURN_NOT_OK(CheckIndexBounds(indices, values.length())); } - Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out); + SelectionImpl kernel(ctx, values, indices, out); return kernel.ExecTake(); } } // namespace -Status VarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status VarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status LargeVarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch, - ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status LargeVarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status ListTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status LargeListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status LargeListTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status ListViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status ListViewTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status LargeListViewTakeExec(KernelContext* ctx, const ExecSpan& 
batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status LargeListViewTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } -Status FSLTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const ArraySpan& values = batch[0].array; - +Status FSLTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices, + std::shared_ptr* out) { // If a FixedSizeList wraps a fixed-width type we can, in some cases, use // FixedWidthTakeExec for a fixed-size list array. - if (util::IsFixedWidthLike(values, + if (util::IsFixedWidthLike(values.array(), /*force_null_count=*/true, /*exclude_bool_and_dictionary=*/true)) { - return FixedWidthTakeExec(ctx, batch, out); + return FixedWidthTakeExec(ctx, values, indices, out); } - return TakeExec(ctx, batch, out); + return TakeAAAExec(ctx, values, indices, out); } -Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec(ctx, batch, out); +Status DenseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec(ctx, values, indices, out); } -Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec(ctx, batch, out); +Status SparseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec(ctx, values, indices, out); } -Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec(ctx, batch, out); +Status StructTakeExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, std::shared_ptr* out) { + return TakeAAAExec(ctx, values, indices, out); } -Status MapTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - return TakeExec>(ctx, batch, out); +Status MapTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices, + std::shared_ptr* out) { + return TakeAAAExec>(ctx, values, indices, out); } } // namespace compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h index 887bf083541..049a0e9abc7 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h @@ -17,11 +17,16 @@ #pragma once +#include #include +#include +#include #include +#include #include #include "arrow/array/data.h" +#include "arrow/chunk_resolver.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/compute/function.h" @@ -33,10 +38,91 @@ namespace arrow::compute::internal { using FilterState = OptionsWrapper; using TakeState = OptionsWrapper; +/// \brief A class used to represent the values argument in take kernels. +/// +/// It can represent either a chunked array or a single array. When the values +/// are chunked, the class provides a ChunkResolver to resolve the target array +/// and index in the chunked array. 
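The doc comment above says that `ValuesSpan` lazily exposes a `ChunkResolver` when the values are chunked. As a rough, standalone illustration of what such a resolver does (a simplified model with invented names, not Arrow's `ChunkResolver` implementation), a logical index over a chunked sequence can be mapped to a (chunk, index-in-chunk) pair by binary-searching cumulative chunk lengths:

```cpp
// Minimal standalone model of index resolution over chunked data.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct ChunkLocation {
  int64_t chunk_index;
  int64_t index_in_chunk;
};

class SimpleChunkResolver {
 public:
  explicit SimpleChunkResolver(const std::vector<int64_t>& chunk_lengths) {
    offsets_.reserve(chunk_lengths.size() + 1);
    offsets_.push_back(0);
    for (int64_t len : chunk_lengths) offsets_.push_back(offsets_.back() + len);
  }

  ChunkLocation Resolve(int64_t logical_index) const {
    assert(logical_index >= 0 && logical_index < offsets_.back());
    // The first cumulative offset strictly greater than logical_index marks
    // the chunk *after* the one containing it.
    auto it = std::upper_bound(offsets_.begin(), offsets_.end(), logical_index);
    const int64_t chunk = static_cast<int64_t>(it - offsets_.begin()) - 1;
    return {chunk, logical_index - offsets_[chunk]};
  }

 private:
  std::vector<int64_t> offsets_;  // cumulative lengths, offsets_[0] == 0
};

int main() {
  SimpleChunkResolver resolver({3, 5, 2});  // chunk lengths 3 + 5 + 2 = 10
  ChunkLocation loc = resolver.Resolve(6);  // falls in the second chunk
  std::cout << loc.chunk_index << " " << loc.index_in_chunk << "\n";  // 1 3
}
```

Arrow's real resolver additionally supports batched resolution (`ResolveMany`) with a chunk hint, which the take kernels in this patch rely on.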
+class ValuesSpan { + private: + const std::shared_ptr chunked_ = nullptr; + const ArraySpan chunk0_; // first chunk or the whole array + mutable std::optional chunk_resolver_; + + public: + explicit ValuesSpan(const std::shared_ptr values) + : chunked_(std::move(values)), chunk0_{*chunked_->chunk(0)->data()} { + assert(chunked_); + assert(chunked_->num_chunks() > 0); + } + + explicit ValuesSpan(const ArraySpan& values) // NOLINT(modernize-pass-by-value) + : chunk0_(values) {} + + explicit ValuesSpan(const ArrayData& values) : chunk0_{ArraySpan{values}} {} + + bool is_chunked() const { return chunked_ != nullptr; } + + const ChunkedArray& chunked_array() const { + assert(is_chunked()); + return *chunked_; + } + + /// \brief Lazily builds a ChunkResolver from the underlying chunked array. + /// + /// \note This method is not thread-safe. + /// \pre is_chunked() + const ChunkResolver& chunk_resolver() const { + assert(is_chunked()); + if (!chunk_resolver_.has_value()) { + chunk_resolver_.emplace(chunked_->chunks()); + } + return *chunk_resolver_; + } + + const ArraySpan& chunk0() const { return chunk0_; } + + const ArraySpan& array() const { + assert(!is_chunked()); + return chunk0_; + } + + const DataType* type() const { return chunk0_.type; } + + int64_t length() const { return is_chunked() ? chunked_->length() : array().length; } + + bool MayHaveNulls() const { + return is_chunked() ? chunked_->null_count() != 0 : array().MayHaveNulls(); + } +}; + +/// \brief Type for a single "array_take" kernel function. +/// +/// Instead of implementing both `ArrayKernelExec` and `ChunkedExec` typed +/// functions for each configurations of `array_take` parameters, we use +/// templates wrapping `TakeKernelExec` functions to expose exec functions +/// that can be registered in the kernel registry. +/// +/// A `TakeKernelExec` always returns a single array, which is the result of +/// taking values from a single array (AA->A) or multiple arrays (CA->A). The +/// wrappers take care of converting the output of a CA call to C or calling +/// the kernel multiple times to process a CC call. 
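To make the AA->A / CA->A / CC decomposition described in the comment above concrete, here is a self-contained sketch using plain `std::vector` instead of Arrow types; the names `TakeCA` and `TakeCC` are invented for illustration and do not exist in the patch. A CC take is handled by issuing one CA->A call per index chunk and collecting the resulting arrays into a chunked result.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using Chunked = std::vector<std::vector<int32_t>>;      // chunked values
using IndexChunks = std::vector<std::vector<int64_t>>;  // chunked indices

// CA->A: take from chunked values with one flat index array -> one array.
std::vector<int32_t> TakeCA(const Chunked& values, const std::vector<int64_t>& indices) {
  std::vector<int32_t> out;
  out.reserve(indices.size());
  for (int64_t idx : indices) {
    // Linear chunk walk; the real kernels resolve indices with a ChunkResolver.
    for (const auto& chunk : values) {
      if (idx < static_cast<int64_t>(chunk.size())) {
        out.push_back(chunk[idx]);
        break;
      }
      idx -= static_cast<int64_t>(chunk.size());
    }
  }
  return out;
}

// CC->C: one output chunk per index chunk, each produced by a CA->A call.
Chunked TakeCC(const Chunked& values, const IndexChunks& index_chunks) {
  Chunked out;
  out.reserve(index_chunks.size());
  for (const auto& idx_chunk : index_chunks) {
    out.push_back(TakeCA(values, idx_chunk));
  }
  return out;
}

int main() {
  Chunked values = {{10, 11, 12}, {13, 14}};
  IndexChunks index_chunks = {{4, 0}, {2}};
  for (const auto& chunk : TakeCC(values, index_chunks)) {
    for (int32_t v : chunk) std::cout << v << ' ';
    std::cout << '\n';
  }
  // Prints: 14 10
  //         12
}
```

The linear chunk walk inside `TakeCA` is only a stand-in for the resolver-based lookup used by the real kernels; the point of the sketch is the one-output-chunk-per-index-chunk structure of the CC case.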
+using TakeKernelExec = Status (*)(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); + struct SelectionKernelData { + SelectionKernelData(InputType value_type, InputType selection_type, + ArrayKernelExec exec, + VectorKernel::ChunkedExec chunked_exec = NULLPTR) + : value_type(std::move(value_type)), + selection_type(std::move(selection_type)), + exec(exec), + chunked_exec(chunked_exec) {} + InputType value_type; InputType selection_type; ArrayKernelExec exec; + VectorKernel::ChunkedExec chunked_exec; }; void RegisterSelectionFunction(const std::string& name, FunctionDoc doc, @@ -73,17 +159,32 @@ Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*); -Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status ListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status LargeListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*); +// Take kernels compatible with the TakeKernelExec signature +Status VarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status LargeVarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status FixedWidthTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status FixedWidthTakeChunkedExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status ListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status LargeListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status ListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status LargeListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status FSLTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status DenseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status SparseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status StructTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); +Status MapTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&, + std::shared_ptr*); } // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index fedafeb5bea..d93b37839fc 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -327,6 +327,105 @@ namespace { using TakeState = OptionsWrapper; +struct ChunkedFixedWidthValuesSpan { + private: + // 
src_residual_bit_offsets_[i] is used to store the bit offset of the first byte (0-7) + // in src_chunks_[i] iff kValueWidthInBits == 1. + std::vector src_residual_bit_offsets; + // Pre-computed pointers to the start of the values in each chunk. + std::vector src_chunks; + + public: + explicit ChunkedFixedWidthValuesSpan(const ChunkedArray& values) { + const bool chunk_values_are_bit_sized = values.type()->id() == Type::BOOL; + DCHECK_EQ(chunk_values_are_bit_sized, util::FixedWidthInBytes(*values.type()) == -1); + if (chunk_values_are_bit_sized) { + src_residual_bit_offsets.resize(values.num_chunks()); + } + src_chunks.resize(values.num_chunks()); + + for (int i = 0; i < values.num_chunks(); ++i) { + const ArraySpan chunk{*values.chunk(i)->data()}; + DCHECK(util::IsFixedWidthLike(chunk)); + + auto offset_pointer = util::OffsetPointerOfFixedBitWidthValues(chunk); + if (chunk_values_are_bit_sized) { + src_residual_bit_offsets[i] = offset_pointer.first; + } else { + DCHECK_EQ(offset_pointer.first, 0); + } + src_chunks[i] = offset_pointer.second; + } + } + + const int* src_residual_bit_offsets_data() const { + return src_residual_bit_offsets.empty() ? nullptr : src_residual_bit_offsets.data(); + } + + const uint8_t* const* src_chunks_data() const { return src_chunks.data(); } +}; + +/// \brief Buffer for chunk locations resolved against a chunked array. +struct BoundedLocationBuffer { + private: + std::unique_ptr chunk_location_buffer = NULLPTR; + + Status Allocate(int64_t n_locations, int64_t sizeof_location, MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(chunk_location_buffer, + AllocateBuffer(n_locations * sizeof_location, pool)); + return Status::OK(); + } + + public: + ~BoundedLocationBuffer() = default; + + template + Status InitWithCapacity(int64_t n_locations, MemoryPool* pool) { + RETURN_NOT_OK(Allocate(n_locations, sizeof(TypedChunkLocation), pool)); + return Status::OK(); + } + + /// \brief The capacity in terms of number of resolved chunk locations. + /// + /// One location is needed for each index. + template + int64_t Capacity() const { + return chunk_location_buffer->size() / sizeof(TypedChunkLocation); + } + + /// \pre idx_length <= Capacity() + template + Status ResolveIndices(const ChunkResolver& chunk_resolver, int64_t idx_length, + const IndexCType* idx, IndexCType chunk_hint) { + DCHECK_LE(idx_length, Capacity()); + auto* chunk_location_vec = mutable_chunk_location_vec(); + // All indices are resolved in one go without checking the validity bitmap. + // This is OK as long the output corresponding to the invalid indices is not used. 
+ bool enough_precision = chunk_resolver.ResolveMany( + /*n_indices=*/idx_length, /*logical_index_vec=*/idx, chunk_location_vec, + chunk_hint); + if (ARROW_PREDICT_FALSE(!enough_precision)) { + return Status::IndexError("IndexCType is too small"); + } + return Status::OK(); + } + + template + TypedChunkLocation* mutable_chunk_location_vec() { + return chunk_location_buffer->mutable_data_as>(); + } + + template + const TypedChunkLocation* chunk_location_vec() const { + return chunk_location_buffer->data_as>(); + } + + template + IndexCType chunk_index(int64_t position) const { + return chunk_location_vec()[position].chunk_index; + } +}; + // ---------------------------------------------------------------------- // Implement optimized take for primitive types from boolean to // 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more @@ -358,15 +457,21 @@ template 0 && kValueWidthInBits == 8 && // factors are used with bytes static_cast(factor * kValueWidthInBits) == bit_width)); #endif + return values.is_chunked() ? ChunkedExec(ctx, values, indices, out_arr, factor) + : Exec(ctx, values.array(), indices, out_arr, factor); + } + + static Status Exec(KernelContext* ctx, const ArraySpan& values, + const ArraySpan& indices, ArrayData* out_arr, int64_t factor) { const bool out_has_validity = values.MayHaveNulls() || indices.MayHaveNulls(); const uint8_t* src; @@ -396,10 +501,89 @@ struct FixedWidthTakeImpl { out_arr->null_count = out_arr->length - valid_count; return Status::OK(); } + + static Status ChunkedExec(KernelContext* ctx, const ValuesSpan& values, + const ArraySpan& indices, ArrayData* out_arr, + int64_t factor) { + constexpr int64_t kIndexBlockCapacityInBytes = 16 * 1024; + // Must be a multiple of 8 so `GatherFromChunks` can always be + // constructed with byte-aligned output pointers in the loop. 
+ constexpr int64_t kIndexBlockCapacity = + kIndexBlockCapacityInBytes / sizeof(IndexCType); + static_assert((kIndexBlockCapacity * kValueWidthInBits) % 8 == 0); + + ChunkedFixedWidthValuesSpan chunked_values{values.chunked_array()}; + BoundedLocationBuffer location_buffer; + // TODO(felipecrv): find a way to share the buffer on TakeCC kernel + RETURN_NOT_OK(location_buffer.InitWithCapacity( + /*n_locations=*/std::min(kIndexBlockCapacity, indices.length), + ctx->memory_pool())); + + return DoChunkedExec(ctx, values, chunked_values, indices, &location_buffer, out_arr, + factor); + } + + // \pre location_buffer is initialized + // \pre location_buffer->Capacity() is a multiple of 8 + static Status DoChunkedExec(KernelContext* ctx, const ValuesSpan& values, + const ChunkedFixedWidthValuesSpan& chunked_values, + const ArraySpan& indices, + BoundedLocationBuffer* location_buffer, ArrayData* out_arr, + int64_t factor) { + const bool out_has_validity = + values.chunked_array().null_count() > 0 || indices.MayHaveNulls(); + + const auto& chunk_resolver = values.chunk_resolver(); + const auto location_buffer_capacity = location_buffer->Capacity(); + const auto* idx = indices.GetValues(1); + uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr); + int64_t valid_count = 0; + IndexCType chunk_hint = 0; + int64_t idx_offset = 0; + while (idx_offset < indices.length) { + const int64_t block_length = + std::min(location_buffer_capacity, indices.length - idx_offset); + + RETURN_NOT_OK(location_buffer->ResolveIndices( + chunk_resolver, /*idx_length=*/block_length, idx, chunk_hint)); + arrow::internal::GatherFromChunks + gather{chunked_values.src_residual_bit_offsets_data(), + chunked_values.src_chunks_data(), + /*idx_length=*/block_length, + location_buffer->chunk_location_vec(), + out, + factor}; + if (out_has_validity) { + DCHECK_EQ(out_arr->offset, 0); + // out_is_valid must be zero-initiliazed, because Gather::Execute + // saves time by not having to ClearBit on every null element. + auto out_is_valid = out_arr->GetMutableValues(0); + memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length)); + valid_count += gather.template Execute( + /*src_validity=*/values.chunked_array(), /*idx_validity=*/indices, + out_is_valid); + } else { + valid_count += gather.Execute(); + } + // Prepare for the next iteration + chunk_hint = location_buffer->chunk_index(block_length - 1); + idx_offset += block_length; + if constexpr (WithFactor::value) { + static_assert(kValueWidthInBits == 8); + out += block_length * factor; + } else { + out += (block_length * kValueWidthInBits) / 8; + // The last `out` produced in this loop might not be byte-aligned, + // but that is not a poblem because no value is written to it. + } + } + out_arr->null_count = out_arr->length - valid_count; + return Status::OK(); + } }; template
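To summarize the control flow of `DoChunkedExec` above, the following standalone sketch (plain C++, invented names, null handling omitted) mirrors the blocked strategy: a bounded location buffer is filled by resolving one block of indices at a time, the block is gathered into the output, and the last resolved chunk index is carried over as a hint for the next block. The hint-walking resolver here is only a stand-in for `ChunkResolver::ResolveMany`.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Location { int64_t chunk; int64_t index_in_chunk; };

// Resolve `count` logical indices into `locations` (pre-sized), starting the
// search at `hint` so mostly-ascending indices resolve in near-constant time.
void ResolveBlock(const std::vector<int64_t>& chunk_offsets,  // cumulative, starts at 0
                  const int64_t* indices, int64_t count, int64_t* hint,
                  Location* locations) {
  for (int64_t i = 0; i < count; ++i) {
    int64_t chunk = *hint;
    // Walk the hint backward/forward until the index falls inside the chunk.
    while (indices[i] < chunk_offsets[chunk]) --chunk;
    while (indices[i] >= chunk_offsets[chunk + 1]) ++chunk;
    locations[i] = {chunk, indices[i] - chunk_offsets[chunk]};
    *hint = chunk;
  }
}

int main() {
  std::vector<std::vector<int32_t>> chunks = {{1, 2, 3}, {4, 5}, {6, 7, 8, 9}};
  std::vector<int64_t> chunk_offsets = {0};
  for (const auto& c : chunks) {
    chunk_offsets.push_back(chunk_offsets.back() + static_cast<int64_t>(c.size()));
  }

  std::vector<int64_t> indices = {8, 0, 3, 5, 1, 7};
  std::vector<int32_t> out(indices.size());

  constexpr int64_t kBlockCapacity = 4;  // stand-in for kIndexBlockCapacity
  std::vector<Location> locations(kBlockCapacity);
  int64_t hint = 0;
  for (int64_t offset = 0; offset < static_cast<int64_t>(indices.size());) {
    const int64_t remaining = static_cast<int64_t>(indices.size()) - offset;
    const int64_t block = std::min(kBlockCapacity, remaining);
    ResolveBlock(chunk_offsets, indices.data() + offset, block, &hint, locations.data());
    for (int64_t i = 0; i < block; ++i) {  // the "gather" step for this block
      out[offset + i] = chunks[locations[i].chunk][locations[i].index_in_chunk];
    }
    offset += block;
  }
  for (int32_t v : out) std::cout << v << ' ';  // 9 1 4 6 2 8
  std::cout << '\n';
}
```

Bounding the scratch buffer to a fixed capacity (16 KiB worth of locations in the patch) keeps the temporary allocation small and independent of how many indices are being taken, while the chunk hint amortizes resolution cost across a block.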