diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index a5cb61d6b55..6855e43d4d0 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -170,10 +170,12 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_place null_placement(null_placement) {} constexpr char PartitionNthOptions::kTypeName[]; -SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) +SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys, + NullPlacement null_placement) : FunctionOptions(internal::kSelectKOptionsType), k(k), - sort_keys(std::move(sort_keys)) {} + sort_keys(std::move(sort_keys)), + null_placement(null_placement) {} constexpr char SelectKOptions::kTypeName[]; namespace internal { diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 9e53cfcf640..abbde034576 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -142,7 +142,8 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// \brief SelectK options class ARROW_EXPORT SelectKOptions : public FunctionOptions { public: - explicit SelectKOptions(int64_t k = -1, std::vector sort_keys = {}); + explicit SelectKOptions(int64_t k = -1, std::vector sort_keys = {}, + NullPlacement null_placement = NullPlacement::AtEnd); static constexpr char const kTypeName[] = "SelectKOptions"; static SelectKOptions Defaults() { return SelectKOptions(); } @@ -172,6 +173,8 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { int64_t k; /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; }; /// \brief Partitioning options for NthToIndices diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.h b/cpp/src/arrow/compute/kernels/chunked_internal.h index b007d6cbfb8..120685e48d9 100644 --- a/cpp/src/arrow/compute/kernels/chunked_internal.h +++ b/cpp/src/arrow/compute/kernels/chunked_internal.h @@ -144,6 +144,19 @@ struct ChunkedArrayResolver : protected ChunkResolver { checked_cast(chunks_[loc.chunk_index]), loc.index_in_chunk); } + template + ResolvedChunk Resolve(ChunkLocation loc) const { + return ResolvedChunk( + checked_cast(chunks_[loc.chunk_index]), loc.index_in_chunk); + } + + template + ChunkLocation ResolveChunkLocation(int64_t index) const { + const auto loc = ChunkResolver::Resolve(index); + return ResolvedChunk( + checked_cast(chunks_[loc.chunk_index]), loc.index_in_chunk); + } + protected: static std::vector MakeLengths(const std::vector& chunks) { std::vector lengths(chunks.size()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index dd5bead58aa..c505488af48 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -52,6 +52,13 @@ struct SortField { SortOrder order; }; +struct LocationContainer { + public: + LocationContainer() {} + uint64_t index_; + ChunkLocation chunk_location_; +}; + Status CheckNonNested(const FieldRef& ref) { if (ref.IsNested()) { return Status::KeyError("Nested keys not supported for SortKeys"); @@ -1682,7 +1689,7 @@ class TableSelecter : public TypeVisitor { null_count(chunked_array->null_count()), resolver(chunk_pointers) {} - using LocationType = int64_t; + using LocationType = ChunkLocation; // Find the target chunk and index in the target chunk from an // index in chunked array. @@ -1691,6 +1698,21 @@ class TableSelecter : public TypeVisitor { return resolver.Resolve(index); } + template + ResolvedChunk GetChunk(LocationContainer& location_container) const { + if (location_container.chunk_location_.chunk_index == -1) { + location_container.chunk_location_ = + resolver.ResolveChunkLocation(location_container.index_); + } + auto loc = location_container.chunk_location_; + return resolver.Resolve(loc); + } + + template + ResolvedChunk GetChunk(ChunkLocation loc) const { + return resolver.Resolve(loc); + } + const SortOrder order; const std::shared_ptr type; const ArrayVector chunks; @@ -1709,7 +1731,11 @@ class TableSelecter : public TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(table, options.sort_keys)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + batches_(TableSelecter::MakeBatches(table, &status_)), + left_resolver_(ChunkResolver::FromBatches(batches_)), + right_resolver_(ChunkResolver::FromBatches(batches_)), + comparator_(sort_keys_, NullPlacement::AtEnd), + options_(options) {} Status Run() { return sort_keys_[0].type->Accept(this); } @@ -1723,6 +1749,14 @@ class TableSelecter : public TypeVisitor { VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) #undef VISIT + static RecordBatchVector MakeBatches(const Table& table, Status* status) { + const auto maybe_batches = BatchesFromTable(table); + if (!maybe_batches.ok()) { + *status = maybe_batches.status(); + return {}; + } + return *std::move(maybe_batches); + } static std::vector ResolveSortKeys( const Table& table, const std::vector& sort_keys) { @@ -1734,45 +1768,8 @@ class TableSelecter : public TypeVisitor { return resolved; } - // Behaves like PartitionNulls() but this supports multiple sort keys. - template - NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, - uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - - const auto p = PartitionNullsOnly( - indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - NullPlacement::AtEnd); - DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); - - const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, - NullPlacement::AtEnd); - - auto& comparator = comparator_; - // Sort all NaNs by the second and following sort keys. - std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - // Sort all nulls by the second and following sort keys. - std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - - return q; - } - - // XXX this implementation is rather inefficient as it computes chunk indices - // at every comparison. Instead we should iterate over individual batches - // and remember ChunkLocation entries in the max-heap. - template Status SelectKthInternal() { - using ArrayType = typename TypeTraits::ArrayType; - auto& comparator = comparator_; - const auto& first_sort_key = sort_keys_[0]; - const auto num_rows = table_.num_rows(); if (num_rows == 0) { return Status::OK(); @@ -1780,59 +1777,146 @@ class TableSelecter : public TypeVisitor { if (k_ > table_.num_rows()) { k_ = table_.num_rows(); } - std::function cmp; - SelectKComparator select_k_comparator; - cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { - auto chunk_left = first_sort_key.template GetChunk(left); - auto chunk_right = first_sort_key.template GetChunk(right); - auto value_left = chunk_left.Value(); - auto value_right = chunk_right.Value(); - if (value_left == value_right) { - return comparator.Compare(left, right, 1); - } - return select_k_comparator(value_left, value_right); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - std::vector indices(num_rows); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); + SortOptions sort_options(options_.sort_keys, NullPlacement::AtEnd); + ARROW_ASSIGN_OR_RAISE( + auto row_indices, + MakeMutableUInt64Array(uint64(), table_.num_rows(), ctx_->memory_pool())); + auto indices_begin = row_indices->GetMutableValues(1); + auto indices_end = indices_begin + table_.num_rows(); std::iota(indices_begin, indices_end, 0); - const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto end_iter = p.non_nulls_end; - auto kth_begin = std::min(indices_begin + k_, end_iter); + RecordBatchVector batches; + { + TableBatchReader reader(table_); + RETURN_NOT_OK(reader.ReadAll(&batches)); + } + const int64_t num_batches = static_cast(batches.size()); + if (num_batches == 0) { + return Status::OK(); + } + std::vector sorted(num_batches); - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - uint64_t top_item = heap.top(); - if (cmp(x_index, top_item)) { - heap.pop(); - heap.push(x_index); - } + // First sort all individual batches + int64_t begin_offset = 0; + int64_t end_offset = 0; + int64_t null_count = 0; + for (int64_t i = 0; i < num_batches; ++i) { + const auto& batch = *batches[i]; + end_offset += batch.num_rows(); + RadixRecordBatchSorter sorter(indices_begin + begin_offset, + indices_begin + end_offset, batch, sort_options); + ARROW_ASSIGN_OR_RAISE(sorted[i], sorter.Sort(begin_offset)); + DCHECK_EQ(sorted[i].overall_begin(), indices_begin + begin_offset); + DCHECK_EQ(sorted[i].overall_end(), indices_begin + end_offset); + DCHECK_EQ(sorted[i].non_null_count() + sorted[i].null_count(), batch.num_rows()); + begin_offset = end_offset; + // XXX this is an upper bound on the true null count + null_count += sorted[i].null_count(); } - int64_t out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, - ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; + DCHECK_EQ(end_offset, indices_end - indices_begin); + + if (sorted.size() > 1) { + struct Visitor { + TableSelecter* sorter; + std::vector* sorted; + int64_t null_count; + + Status Visit(const InType& type) { + return sorter->MergeInternal(std::move(*sorted), null_count); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("Unsupported type for sorting: ", + type.ToString()); + } + }; + Visitor visitor{this, &sorted, null_count}; + RETURN_NOT_OK(VisitTypeInline(*sort_keys_[0].type, &visitor)); } + + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(uint64(), k_, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1); + std::copy(indices_begin, indices_begin + k_, out_cbegin); *output_ = Datum(take_indices); return Status::OK(); } + // Recursive merge routine, typed on the first sort key + template + Status MergeInternal(std::vector sorted, int64_t null_count) { + auto merge_general = [&](uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices, uint64_t k) { + MergeNonNulls(range_begin, range_middle, range_end, temp_indices, k); + }; + + MergeImpl merge_impl(options_.null_placement, std::move(merge_general)); + RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); + + while (sorted.size() > 1) { + auto out_it = sorted.begin(); + auto it = sorted.begin(); + while (it < sorted.end() - 1) { + const auto& left = *it++; + const auto& right = *it++; + DCHECK_EQ(left.overall_end(), right.overall_begin()); + *out_it++ = merge_impl.MergeKElements(left, right, null_count, k_); + } + if (it < sorted.end()) { + *out_it++ = *it++; + } + sorted.erase(out_it, sorted.end()); + } + DCHECK_EQ(sorted.size(), 1); + return comparator_.status(); + } + + template + void MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, uint64_t* range_end, + uint64_t* temp_indices, uint64_t k) { + using ArrayType = typename TypeTraits::ArrayType; + auto left_end = std::max(range_begin + k, range_middle); + auto right_end = std::max(range_middle + k, range_end); + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + std::merge(range_begin, left_end, range_middle, right_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN. + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + auto chunk_left = first_sort_key.GetChunk(left_loc); + auto chunk_right = first_sort_key.GetChunk(right_loc); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left_loc, right_loc, 1); + } else { + auto compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + k, range_begin); + } + ExecContext* ctx_; const Table& table_; int64_t k_; Datum* output_; std::vector sort_keys_; + const RecordBatchVector batches_; + const ChunkResolver left_resolver_, right_resolver_; + Status status_; Comparator comparator_; + const SelectKOptions& options_; }; static Status CheckConsistency(const Schema& schema, @@ -1936,3 +2020,4 @@ void RegisterVectorSort(FunctionRegistry* registry) { } // namespace internal } // namespace compute } // namespace arrow + diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index d8b024525c8..ee025216e5e 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -342,12 +342,19 @@ struct MergeImpl { std::function; + using MergeGeneralFunc = + std::function; + MergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls, MergeNonNullsFunc&& merge_non_nulls) : null_placement_(null_placement), merge_nulls_(std::move(merge_nulls)), merge_non_nulls_(std::move(merge_non_nulls)) {} + MergeImpl(NullPlacement null_placement, MergeGeneralFunc&& merge_general) + : null_placement_(null_placement), merge_general_(std::move(merge_general)) {} + Status Init(ExecContext* ctx, int64_t temp_indices_length) { ARROW_ASSIGN_OR_RAISE( temp_buffer_, @@ -365,6 +372,24 @@ struct MergeImpl { } } + NullPartitionResult MergeK(const NullPartitionResult& left, + const NullPartitionResult& right, int64_t null_count, + uint64_t k) const { + return MergeKNullsAtEnd(left, right, null_count, k); + } + + NullPartitionResult MergeKElements(const NullPartitionResult& left, + const NullPartitionResult& right, int64_t null_count, + uint64_t k) const { + uint64_t total_elements = left.non_null_count() + left.null_count() + + right.non_null_count() + right.null_count(); + if (total_elements <= k) { + return Merge(left, right, null_count); + } else { + return MergeK(left, right, null_count, k); + } + } + NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left, const NullPartitionResult& right, int64_t null_count) const { @@ -435,10 +460,22 @@ struct MergeImpl { return p; } + NullPartitionResult MergeKNullsAtEnd(const NullPartitionResult& left, + const NullPartitionResult& right, + int64_t null_count, uint64_t k) const { + merge_general_(left.non_nulls_begin, right.non_nulls_begin, right.nulls_end, + temp_indices_, k); + // left.nulls_end and right.non_nulls_begin are unused for mergeK. + const auto p = NullPartitionResult{left.non_nulls_begin, left.nulls_end, + right.non_nulls_begin, right.nulls_end};; + return p; + } + private: NullPlacement null_placement_; MergeNullsFunc merge_nulls_; MergeNonNullsFunc merge_non_nulls_; + MergeGeneralFunc merge_general_; std::unique_ptr temp_buffer_; uint64_t* temp_indices_ = nullptr; };