From f25bf77abc832f347062296bdd645594303771b8 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 28 Sep 2021 19:21:24 +0200 Subject: [PATCH] ARROW-10898: [C++] Improve table sort performance Use the same strategy as for chunked array sorting: - first sort each RecordBatch individually, taking advantage of data contiguity for fast indexing - then merge sorted batches recursively, using slower chunked indexing Benchmarks show up to 200% speedups on some benchmark parameters, along with very minor regressions. --- cpp/src/arrow/compute/kernels/vector_sort.cc | 709 ++++++++++-------- .../arrow/compute/kernels/vector_sort_test.cc | 39 +- 2 files changed, 437 insertions(+), 311 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 1077cf42032..83aa40a23a0 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -286,6 +287,15 @@ struct NullPartitionResult { } } + static NullPartitionResult NullsOnly(uint64_t* indices_begin, uint64_t* indices_end, + NullPlacement null_placement) { + if (null_placement == NullPlacement::AtStart) { + return {indices_end, indices_end, indices_begin, indices_end}; + } else { + return {indices_begin, indices_begin, indices_begin, indices_end}; + } + } + static NullPartitionResult NullsAtEnd(uint64_t* indices_begin, uint64_t* indices_end, uint64_t* midpoint) { DCHECK_GE(midpoint, indices_begin); @@ -440,6 +450,116 @@ NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_en std::max(q.nulls_end, p.nulls_end)}; } +struct MergeImpl { + using MergeNullsFunc = std::function; + + using MergeNonNullsFunc = + std::function; + + MergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls, + MergeNonNullsFunc&& merge_non_nulls) + : null_placement_(null_placement), + merge_nulls_(std::move(merge_nulls)), + merge_non_nulls_(std::move(merge_non_nulls)) {} + + Status Init(ExecContext* ctx, int64_t temp_indices_length) { + ARROW_ASSIGN_OR_RAISE( + temp_buffer_, + AllocateBuffer(sizeof(int64_t) * temp_indices_length, ctx->memory_pool())); + temp_indices_ = reinterpret_cast(temp_buffer_->mutable_data()); + return Status::OK(); + } + + NullPartitionResult Merge(const NullPartitionResult& left, + const NullPartitionResult& right, int64_t null_count) const { + if (null_placement_ == NullPlacement::AtStart) { + return MergeNullsAtStart(left, right, null_count); + } else { + return MergeNullsAtEnd(left, right, null_count); + } + } + + NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left, + const NullPartitionResult& right, + int64_t null_count) const { + // Input layout: + // [left nulls .... left non-nulls .... right nulls .... right non-nulls] + DCHECK_EQ(left.nulls_end, left.non_nulls_begin); + DCHECK_EQ(left.non_nulls_end, right.nulls_begin); + DCHECK_EQ(right.nulls_end, right.non_nulls_begin); + + // Mutate the input, stably, to obtain the following layout: + // [left nulls .... right nulls .... left non-nulls .... right non-nulls] + std::rotate(left.non_nulls_begin, right.nulls_begin, right.nulls_end); + + const auto p = NullPartitionResult::NullsAtStart( + left.nulls_begin, right.non_nulls_end, + left.nulls_begin + left.null_count() + right.null_count()); + + // If the type has null-like values (such as NaN), ensure those plus regular + // nulls are partitioned in the right order. Note this assumes that all + // null-like values (e.g. NaN) are ordered equally. + if (p.null_count()) { + merge_nulls_(p.nulls_begin, p.nulls_begin + left.null_count(), p.nulls_end, + temp_indices_, null_count); + } + + // Merge the non-null values into temp area + DCHECK_EQ(right.non_nulls_begin - p.non_nulls_begin, left.non_null_count()); + DCHECK_EQ(p.non_nulls_end - right.non_nulls_begin, right.non_null_count()); + if (p.non_null_count()) { + merge_non_nulls_(p.non_nulls_begin, right.non_nulls_begin, p.non_nulls_end, + temp_indices_); + } + return p; + } + + NullPartitionResult MergeNullsAtEnd(const NullPartitionResult& left, + const NullPartitionResult& right, + int64_t null_count) const { + // Input layout: + // [left non-nulls .... left nulls .... right non-nulls .... right nulls] + DCHECK_EQ(left.non_nulls_end, left.nulls_begin); + DCHECK_EQ(left.nulls_end, right.non_nulls_begin); + DCHECK_EQ(right.non_nulls_end, right.nulls_begin); + + // Mutate the input, stably, to obtain the following layout: + // [left non-nulls .... right non-nulls .... left nulls .... right nulls] + std::rotate(left.nulls_begin, right.non_nulls_begin, right.non_nulls_end); + + const auto p = NullPartitionResult::NullsAtEnd( + left.non_nulls_begin, right.nulls_end, + left.non_nulls_begin + left.non_null_count() + right.non_null_count()); + + // If the type has null-like values (such as NaN), ensure those plus regular + // nulls are partitioned in the right order. Note this assumes that all + // null-like values (e.g. NaN) are ordered equally. + if (p.null_count()) { + merge_nulls_(p.nulls_begin, p.nulls_begin + left.null_count(), p.nulls_end, + temp_indices_, null_count); + } + + // Merge the non-null values into temp area + DCHECK_EQ(left.non_nulls_end - p.non_nulls_begin, left.non_null_count()); + DCHECK_EQ(p.non_nulls_end - left.non_nulls_end, right.non_null_count()); + if (p.non_null_count()) { + merge_non_nulls_(p.non_nulls_begin, left.non_nulls_end, p.non_nulls_end, + temp_indices_); + } + return p; + } + + private: + NullPlacement null_placement_; + MergeNullsFunc merge_nulls_; + MergeNonNullsFunc merge_non_nulls_; + std::unique_ptr temp_buffer_; + uint64_t* temp_indices_ = nullptr; +}; + // ---------------------------------------------------------------------- // partition_nth_indices implementation @@ -860,56 +980,15 @@ void AddSortingKernels(VectorKernel base, VectorFunction* func) { } // ---------------------------------------------------------------------- -// ChunkedArray sorting implementations - -// Sort a chunked array directly without sorting each array in the -// chunked array. This is used for processing the second and following -// sort keys in TableRadixSorter. -// -// This uses the same algorithm as ArrayCompareSorter. -template -class ChunkedArrayCompareSorter { - using ArrayType = typename TypeTraits::ArrayType; +// ChunkedArray sorting implementation - public: - NullPartitionResult Sort(uint64_t* indices_begin, uint64_t* indices_end, - const std::vector& arrays, int64_t null_count, - const ArraySortOptions& options) { - const auto p = PartitionNulls( - indices_begin, indices_end, ChunkedArrayResolver(arrays), null_count, - options.null_placement); - ChunkedArrayResolver resolver(arrays); - if (options.order == SortOrder::Ascending) { - std::stable_sort(p.non_nulls_begin, p.non_nulls_end, - [&](uint64_t left, uint64_t right) { - const auto chunk_left = resolver.Resolve(left); - const auto chunk_right = resolver.Resolve(right); - return chunk_left.Value() < chunk_right.Value(); - }); - } else { - std::stable_sort(p.non_nulls_begin, p.non_nulls_end, - [&](uint64_t left, uint64_t right) { - const auto chunk_left = resolver.Resolve(left); - const auto chunk_right = resolver.Resolve(right); - // We don't use 'left > right' here to reduce required operator. - // If we use 'right < left' here, '<' is only required. - return chunk_right.Value() < chunk_left.Value(); - }); - } - return p; - } -}; - -// Sort a chunked array by sorting each array in the chunked array. -// -// TODO: This is a naive implementation. We'll be able to improve -// performance of this. For example, we'll be able to use threads for -// sorting each array. +// Sort a chunked array by sorting each array in the chunked array, +// then merging the sorted chunks recursively. class ChunkedArraySorter : public TypeVisitor { public: ChunkedArraySorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const ChunkedArray& chunked_array, const SortOrder order, - const NullPlacement null_placement, bool can_use_array_sorter = true) + const NullPlacement null_placement) : TypeVisitor(), indices_begin_(indices_begin), indices_end_(indices_end), @@ -918,7 +997,6 @@ class ChunkedArraySorter : public TypeVisitor { physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), order_(order), null_placement_(null_placement), - can_use_array_sorter_(can_use_array_sorter), ctx_(ctx) {} Status Sort() { return physical_type_->Accept(this); } @@ -946,38 +1024,48 @@ class ChunkedArraySorter : public TypeVisitor { } const auto arrays = GetArrayPointers(physical_chunks_); - if (can_use_array_sorter_) { - // Sort each chunk independently and merge to sorted indices. - // This is a serial implementation. - ArraySorter sorter; - std::vector sorted(num_chunks); - - // First sort all individual chunks - int64_t begin_offset = 0; - int64_t end_offset = 0; - int64_t null_count = 0; - for (int i = 0; i < num_chunks; ++i) { - const auto array = checked_cast(arrays[i]); - end_offset += array->length(); - null_count += array->null_count(); - sorted[i] = - sorter.impl.Sort(indices_begin_ + begin_offset, indices_begin_ + end_offset, - *array, begin_offset, options); - begin_offset = end_offset; - } - DCHECK_EQ(end_offset, indices_end_ - indices_begin_); - - std::unique_ptr temp_buffer; - uint64_t* temp_indices = nullptr; - if (sorted.size() > 1) { - ARROW_ASSIGN_OR_RAISE( - temp_buffer, - AllocateBuffer(sizeof(int64_t) * (indices_end_ - indices_begin_ - null_count), - ctx_->memory_pool())); - temp_indices = reinterpret_cast(temp_buffer->mutable_data()); - } + // Sort each chunk independently and merge to sorted indices. + // This is a serial implementation. + ArraySorter sorter; + std::vector sorted(num_chunks); + + // First sort all individual chunks + int64_t begin_offset = 0; + int64_t end_offset = 0; + int64_t null_count = 0; + for (int i = 0; i < num_chunks; ++i) { + const auto array = checked_cast(arrays[i]); + end_offset += array->length(); + null_count += array->null_count(); + sorted[i] = + sorter.impl.Sort(indices_begin_ + begin_offset, indices_begin_ + end_offset, + *array, begin_offset, options); + begin_offset = end_offset; + } + DCHECK_EQ(end_offset, indices_end_ - indices_begin_); + + // Then merge them by pairs, recursively + if (sorted.size() > 1) { + auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle, + uint64_t* nulls_end, uint64_t* temp_indices, + int64_t null_count) { + if (NullTraits::has_null_like_values) { + PartitionNullsOnly(nulls_begin, nulls_end, + ChunkedArrayResolver(arrays), null_count, + null_placement_); + } + }; + auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices) { + MergeNonNulls(range_begin, range_middle, range_end, arrays, + temp_indices); + }; + + MergeImpl merge_impl{null_placement_, std::move(merge_nulls), + std::move(merge_non_nulls)}; + // std::merge is only called on non-null values, so size temp indices accordingly + RETURN_NOT_OK(merge_impl.Init(ctx_, indices_end_ - indices_begin_ - null_count)); - // Then merge them by pairs, recursively while (sorted.size() > 1) { auto out_it = sorted.begin(); auto it = sorted.begin(); @@ -985,8 +1073,7 @@ class ChunkedArraySorter : public TypeVisitor { const auto& left = *it++; const auto& right = *it++; DCHECK_EQ(left.overall_end(), right.overall_begin()); - const auto merged = - Merge(left, right, arrays, null_count, temp_indices); + const auto merged = merge_impl.Merge(left, right, null_count); *out_it++ = merged; } if (it < sorted.end()) { @@ -994,103 +1081,15 @@ class ChunkedArraySorter : public TypeVisitor { } sorted.erase(out_it, sorted.end()); } - DCHECK_EQ(sorted.size(), 1); - DCHECK_EQ(sorted[0].overall_begin(), indices_begin_); - DCHECK_EQ(sorted[0].overall_end(), indices_end_); - // Note that "nulls" can also include NaNs, hence the >= check - DCHECK_GE(sorted[0].null_count(), null_count); - } else { - // Sort the chunked array directory. - ChunkedArrayCompareSorter sorter; - sorter.Sort(indices_begin_, indices_end_, arrays, chunked_array_.null_count(), - options); } - return Status::OK(); - } - - // Merge two adjacent sorted indices arrays - template - NullPartitionResult Merge(const NullPartitionResult& left, - const NullPartitionResult& right, - const std::vector& arrays, int64_t null_count, - uint64_t* temp_indices) { - if (null_placement_ == NullPlacement::AtStart) { - return MergeNullsAtStart(left, right, arrays, null_count, temp_indices); - } else { - return MergeNullsAtEnd(left, right, arrays, null_count, temp_indices); - } - } - - template - NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left, - const NullPartitionResult& right, - const std::vector& arrays, - int64_t null_count, uint64_t* temp_indices) { - // Input layout: - // [left nulls .... left non-nulls .... right nulls .... right non-nulls] - DCHECK_EQ(left.nulls_end, left.non_nulls_begin); - DCHECK_EQ(left.non_nulls_end, right.nulls_begin); - DCHECK_EQ(right.nulls_end, right.non_nulls_begin); - - // Mutate the input, stably, to obtain the following layout: - // [left nulls .... right nulls .... left non-nulls .... right non-nulls] - std::rotate(left.non_nulls_begin, right.nulls_begin, right.nulls_end); - - const auto p = NullPartitionResult::NullsAtStart( - left.nulls_begin, right.non_nulls_end, - left.nulls_begin + left.null_count() + right.null_count()); - - // If the type has null-like values (such as NaN), ensure those plus regular - // nulls are partitioned in the right order. Note this assumes that all - // null-like values (e.g. NaN) are ordered equally. - if (NullTraits::has_null_like_values) { - PartitionNullsOnly(p.nulls_begin, p.nulls_end, - ChunkedArrayResolver(arrays), null_count, - null_placement_); - } - - // Merge the non-null values into temp area - DCHECK_EQ(right.non_nulls_begin - p.non_nulls_begin, left.non_null_count()); - DCHECK_EQ(p.non_nulls_end - right.non_nulls_begin, right.non_null_count()); - MergeNonNulls(p.non_nulls_begin, right.non_nulls_begin, p.non_nulls_end, - arrays, temp_indices); - return p; - } - template - NullPartitionResult MergeNullsAtEnd(const NullPartitionResult& left, - const NullPartitionResult& right, - const std::vector& arrays, - int64_t null_count, uint64_t* temp_indices) { - // Input layout: - // [left non-nulls .... left nulls .... right non-nulls .... right nulls] - DCHECK_EQ(left.non_nulls_end, left.nulls_begin); - DCHECK_EQ(left.nulls_end, right.non_nulls_begin); - DCHECK_EQ(right.non_nulls_end, right.nulls_begin); - - // Mutate the input, stably, to obtain the following layout: - // [left non-nulls .... right non-nulls .... left nulls .... right nulls] - std::rotate(left.nulls_begin, right.non_nulls_begin, right.non_nulls_end); + DCHECK_EQ(sorted.size(), 1); + DCHECK_EQ(sorted[0].overall_begin(), indices_begin_); + DCHECK_EQ(sorted[0].overall_end(), indices_end_); + // Note that "nulls" can also include NaNs, hence the >= check + DCHECK_GE(sorted[0].null_count(), null_count); - const auto p = NullPartitionResult::NullsAtEnd( - left.non_nulls_begin, right.nulls_end, - left.non_nulls_begin + left.non_null_count() + right.non_null_count()); - - // If the type has null-like values (such as NaN), ensure those plus regular - // nulls are partitioned in the right order. Note this assumes that all - // null-like values (e.g. NaN) are ordered equally. - if (NullTraits::has_null_like_values) { - PartitionNullsOnly(p.nulls_begin, p.nulls_end, - ChunkedArrayResolver(arrays), null_count, - null_placement_); - } - - // Merge the non-null values into temp area - DCHECK_EQ(left.non_nulls_end - p.non_nulls_begin, left.non_null_count()); - DCHECK_EQ(p.non_nulls_end - left.non_nulls_end, right.non_null_count()); - MergeNonNulls(p.non_nulls_begin, left.non_nulls_end, p.non_nulls_end, - arrays, temp_indices); - return p; + return Status::OK(); } template @@ -1128,7 +1127,6 @@ class ChunkedArraySorter : public TypeVisitor { const ArrayVector physical_chunks_; const SortOrder order_; const NullPlacement null_placement_; - const bool can_use_array_sorter_; ExecContext* ctx_; }; @@ -1139,7 +1137,7 @@ class ChunkedArraySorter : public TypeVisitor { // to be non-null. template void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin, - uint64_t* indices_end, Visitor&& visit) { + uint64_t* indices_end, int64_t offset, Visitor&& visit) { using GetView = GetViewType; if (indices_begin == indices_end) { @@ -1147,9 +1145,9 @@ void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin, } auto range_start = indices_begin; auto range_cur = range_start; - auto last_value = GetView::LogicalValue(array.GetView(*range_cur)); + auto last_value = GetView::LogicalValue(array.GetView(*range_cur - offset)); while (++range_cur != indices_end) { - auto v = GetView::LogicalValue(array.GetView(*range_cur)); + auto v = GetView::LogicalValue(array.GetView(*range_cur - offset)); if (v != last_value) { visit(range_start, range_cur); range_start = range_cur; @@ -1169,7 +1167,8 @@ class RecordBatchColumnSorter { : next_column_(next_column) {} virtual ~RecordBatchColumnSorter() {} - virtual void SortRange(uint64_t* indices_begin, uint64_t* indices_end) = 0; + virtual NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end, + int64_t offset) = 0; protected: RecordBatchColumnSorter* next_column_; @@ -1190,10 +1189,10 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { null_placement_(null_placement), null_count_(array_.null_count()) {} - void SortRange(uint64_t* indices_begin, uint64_t* indices_end) { + NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end, + int64_t offset) override { using GetView = GetViewType; - constexpr int64_t offset = 0; NullPartitionResult p; if (null_count_ == 0) { p = NullPartitionResult::NoNulls(indices_begin, indices_end, null_placement_); @@ -1231,19 +1230,22 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter { if (next_column_ != nullptr) { // Visit all ranges of equal values in this column and sort them on // the next column. - SortNextColumn(q.nulls_begin, q.nulls_end); - SortNextColumn(p.nulls_begin, p.nulls_end); - VisitConstantRanges(array_, q.non_nulls_begin, q.non_nulls_end, + SortNextColumn(q.nulls_begin, q.nulls_end, offset); + SortNextColumn(p.nulls_begin, p.nulls_end, offset); + VisitConstantRanges(array_, q.non_nulls_begin, q.non_nulls_end, offset, [&](uint64_t* range_start, uint64_t* range_end) { - SortNextColumn(range_start, range_end); + SortNextColumn(range_start, range_end, offset); }); } + return NullPartitionResult{q.non_nulls_begin, q.non_nulls_end, + std::min(q.nulls_begin, p.nulls_begin), + std::max(q.nulls_end, p.nulls_end)}; } - void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end) { + void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end, int64_t offset) { // Avoid the cost of a virtual method call in trivial cases if (indices_end - indices_begin > 1) { - next_column_->SortRange(indices_begin, indices_end); + next_column_->SortRange(indices_begin, indices_end, offset); } } @@ -1261,16 +1263,18 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter ConcreteRecordBatchColumnSorter(std::shared_ptr array, SortOrder order, NullPlacement null_placement, RecordBatchColumnSorter* next_column = nullptr) - : RecordBatchColumnSorter(next_column), owned_array_(std::move(array)) {} + : RecordBatchColumnSorter(next_column), null_placement_(null_placement) {} - void SortRange(uint64_t* indices_begin, uint64_t* indices_end) { + NullPartitionResult SortRange(uint64_t* indices_begin, uint64_t* indices_end, + int64_t offset) { if (next_column_ != nullptr) { - next_column_->SortRange(indices_begin, indices_end); + next_column_->SortRange(indices_begin, indices_end, offset); } + return NullPartitionResult::NullsOnly(indices_begin, indices_end, null_placement_); } protected: - const std::shared_ptr owned_array_; + const NullPlacement null_placement_; }; // Sort a batch using a single-pass left-to-right radix sort. @@ -1283,7 +1287,8 @@ class RadixRecordBatchSorter { indices_begin_(indices_begin), indices_end_(indices_end) {} - Status Sort() { + // Offset is for table sorting + Result Sort(int64_t offset = 0) { ARROW_ASSIGN_OR_RAISE(const auto sort_keys, ResolveSortKeys(batch_, options_.sort_keys)); @@ -1297,8 +1302,7 @@ class RadixRecordBatchSorter { } // Sort from left to right - column_sorts.front()->SortRange(indices_begin_, indices_end_); - return Status::OK(); + return column_sorts.front()->SortRange(indices_begin_, indices_end_, offset); } protected: @@ -1419,6 +1423,8 @@ class MultipleKeyComparator { #undef VISIT + Status Visit(const NullType& type) { return Status::OK(); } + Status Visit(const DataType& type) { return Status::TypeError("Unsupported type for RecordBatch sorting: ", type.ToString()); @@ -1663,40 +1669,17 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { }; // ---------------------------------------------------------------------- -// Table sorting implementations - -// Sort a table using a radix sort-like algorithm. -// A distinct stable sort is called for each sort key, from the last key to the first. -class TableRadixSorter { - public: - Status Sort(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, - const Table& table, const SortOptions& options) { - for (auto i = options.sort_keys.size(); i > 0; --i) { - const auto& sort_key = options.sort_keys[i - 1]; - const auto& chunked_array = table.GetColumnByName(sort_key.name); - if (!chunked_array) { - return Status::Invalid("Nonexistent sort key column: ", sort_key.name); - } - // We can use ArraySorter only for the sort key that is - // processed first because ArraySorter doesn't care about - // existing indices. - const auto can_use_array_sorter = (i == 0); - ChunkedArraySorter sorter(ctx, indices_begin, indices_end, *chunked_array.get(), - sort_key.order, options.null_placement, - can_use_array_sorter); - ARROW_RETURN_NOT_OK(sorter.Sort()); - } - return Status::OK(); - } -}; +// Table sorting implementation(s) -// Sort a table using a single sort and multiple-key comparisons. -class MultipleKeyTableSorter : public TypeVisitor { - public: +// Sort a table using an explicit merge sort. +// Each batch is first sorted individually (taking advantage of the fact +// that batch columns are contiguous and therefore have less indexing +// overhead), then sorted batches are merged recursively. +class TableSorter { // TODO instead of resolving chunks for each column independently, we could - // split the table into RecordBatches and pay the cost of chunked indexing - // at the first column only. + // rely on RecordBatches and resolve the first column only. + public: // Preprocessed sort key. struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, @@ -1709,7 +1692,7 @@ class MultipleKeyTableSorter : public TypeVisitor { num_chunks(chunked_array->num_chunks()), resolver(chunk_pointers) {} - // Finds the target chunk and index in the target chunk from an + // Find the target chunk and index in the target chunk from an // index in chunked array. template ResolvedChunk GetChunk(int64_t index) const { @@ -1725,34 +1708,27 @@ class MultipleKeyTableSorter : public TypeVisitor { const ChunkedArrayResolver resolver; }; - using Comparator = MultipleKeyComparator; - - public: - MultipleKeyTableSorter(uint64_t* indices_begin, uint64_t* indices_end, - const Table& table, const SortOptions& options) - : indices_begin_(indices_begin), + TableSorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + const Table& table, const SortOptions& options) + : ctx_(ctx), + table_(table), + options_(options), + indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), null_placement_(options.null_placement), comparator_(sort_keys_, null_placement_) {} - // This is optimized for the first sort key. The first sort key sort - // is processed in this class. The second and following sort keys - // are processed in Comparator. + // This is optimized for null partitioning and merging along the first sort key. + // Other sort keys are delegated to the Comparator class. Status Sort() { ARROW_RETURN_NOT_OK(status_); - return sort_keys_[0].type->Accept(this); + return SortInternal(); } -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) override { return SortInternal(); } - - VISIT_PHYSICAL_TYPES(VISIT) - VISIT(NullType) - -#undef VISIT - private: + using Comparator = MultipleKeyComparator; + static std::vector ResolveSortKeys( const Table& table, const std::vector& sort_keys, Status* status) { const auto maybe_resolved = @@ -1764,78 +1740,216 @@ class MultipleKeyTableSorter : public TypeVisitor { return *std::move(maybe_resolved); } + Status SortInternal() { + // Sort each batch independently and merge to sorted indices. + RecordBatchVector batches; + { + TableBatchReader reader(table_); + RETURN_NOT_OK(reader.ReadAll(&batches)); + } + const int64_t num_batches = static_cast(batches.size()); + if (num_batches == 0) { + return Status::OK(); + } + std::vector sorted(num_batches); + + // First sort all individual batches + int64_t begin_offset = 0; + int64_t end_offset = 0; + int64_t null_count = 0; + for (int64_t i = 0; i < num_batches; ++i) { + const auto& batch = *batches[i]; + end_offset += batch.num_rows(); + RadixRecordBatchSorter sorter(indices_begin_ + begin_offset, + indices_begin_ + end_offset, batch, options_); + ARROW_ASSIGN_OR_RAISE(sorted[i], sorter.Sort(begin_offset)); + DCHECK_EQ(sorted[i].overall_begin(), indices_begin_ + begin_offset); + DCHECK_EQ(sorted[i].overall_end(), indices_begin_ + end_offset); + DCHECK_EQ(sorted[i].non_null_count() + sorted[i].null_count(), batch.num_rows()); + begin_offset = end_offset; + // XXX this is an upper bound on the true null count + null_count += sorted[i].null_count(); + } + DCHECK_EQ(end_offset, indices_end_ - indices_begin_); + + // Then merge them by pairs, recursively + if (sorted.size() > 1) { + struct Visitor { + TableSorter* sorter; + std::vector* sorted; + int64_t null_count; + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + return sorter->MergeInternal(std::move(*sorted), null_count); \ + } + + VISIT_PHYSICAL_TYPES(VISIT) + VISIT(NullType) +#undef VISIT + + Status Visit(const DataType& type) { + return Status::NotImplemented("Unsupported type for sorting: ", + type.ToString()); + } + }; + Visitor visitor{this, &sorted, null_count}; + RETURN_NOT_OK(VisitTypeInline(*sort_keys_[0].type, &visitor)); + } + return Status::OK(); + } + + // Recursive merge routine, typed on the first sort key template - enable_if_t::value, Status> SortInternal() { + Status MergeInternal(std::vector sorted, int64_t null_count) { + auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle, + uint64_t* nulls_end, uint64_t* temp_indices, + int64_t null_count) { + MergeNulls(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count); + }; + auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices) { + MergeNonNulls(range_begin, range_middle, range_end, temp_indices); + }; + + MergeImpl merge_impl(options_.null_placement, std::move(merge_nulls), + std::move(merge_non_nulls)); + RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); + + while (sorted.size() > 1) { + auto out_it = sorted.begin(); + auto it = sorted.begin(); + while (it < sorted.end() - 1) { + const auto& left = *it++; + const auto& right = *it++; + DCHECK_EQ(left.overall_end(), right.overall_begin()); + *out_it++ = merge_impl.Merge(left, right, null_count); + } + if (it < sorted.end()) { + *out_it++ = *it++; + } + sorted.erase(out_it, sorted.end()); + } + DCHECK_EQ(sorted.size(), 1); + DCHECK_EQ(sorted[0].overall_begin(), indices_begin_); + DCHECK_EQ(sorted[0].overall_end(), indices_end_); + return comparator_.status(); + } + + // Merge rows with a null or a null-like in the first sort key + template + enable_if_t::has_null_like_values> MergeNulls(uint64_t* nulls_begin, + uint64_t* nulls_middle, + uint64_t* nulls_end, + uint64_t* temp_indices, + int64_t null_count) { using ArrayType = typename TypeTraits::ArrayType; auto& comparator = comparator_; const auto& first_sort_key = sort_keys_[0]; - const auto p = PartitionNullsInternal(first_sort_key); - std::stable_sort(p.non_nulls_begin, p.non_nulls_end, - [&](uint64_t left, uint64_t right) { - // Both values are never null nor NaN. - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); - auto value_left = chunk_left.Value(); - auto value_right = chunk_right.Value(); - if (value_left == value_right) { - // If the left value equals to the right value, - // we need to compare the second and following - // sort keys. - return comparator.Compare(left, right, 1); - } else { - auto compared = value_left < value_right; - if (first_sort_key.order == SortOrder::Ascending) { - return compared; - } else { - return !compared; - } - } - }); - return comparator_.status(); + std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // First column is either null or nan + auto chunk_left = first_sort_key.GetChunk(left); + auto chunk_right = first_sort_key.GetChunk(right); + const auto left_is_null = chunk_left.IsNull(); + const auto right_is_null = chunk_right.IsNull(); + if (left_is_null == right_is_null) { + return comparator.Compare(left, right, 1); + } else if (options_.null_placement == NullPlacement::AtEnd) { + return right_is_null; + } else { + return left_is_null; + } + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin); } template - enable_if_null SortInternal() { - std::stable_sort(indices_begin_, indices_end_, [&](uint64_t left, uint64_t right) { - return comparator_.Compare(left, right, 1); - }); - return comparator_.status(); + enable_if_t::has_null_like_values> MergeNulls(uint64_t* nulls_begin, + uint64_t* nulls_middle, + uint64_t* nulls_end, + uint64_t* temp_indices, + int64_t null_count) { + MergeNullsOnly(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count); + } + + void MergeNullsOnly(uint64_t* nulls_begin, uint64_t* nulls_middle, uint64_t* nulls_end, + uint64_t* temp_indices, int64_t null_count) { + // Untyped implementation + auto& comparator = comparator_; + + std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // First column is always null + return comparator.Compare(left, right, 1); + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin); } - // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // Merge rows with a non-null in the first sort key // template - NullPartitionResult PartitionNullsInternal(const ResolvedSortKey& first_sort_key) { + enable_if_t::value> MergeNonNulls(uint64_t* range_begin, + uint64_t* range_middle, + uint64_t* range_end, + uint64_t* temp_indices) { using ArrayType = typename TypeTraits::ArrayType; - const auto p = PartitionNullsOnly( - indices_begin_, indices_end_, first_sort_key.resolver, first_sort_key.null_count, - null_placement_); - DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; - const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, null_placement_); + std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN. + auto chunk_left = first_sort_key.GetChunk(left); + auto chunk_right = first_sort_key.GetChunk(right); + DCHECK(!chunk_left.IsNull()); + DCHECK(!chunk_right.IsNull()); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } else { + auto compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); + } + template + enable_if_null MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices) { auto& comparator = comparator_; - // Sort all NaNs by the second and following sort keys. - std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - // Sort all nulls by the second and following sort keys. - std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - return q; + std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, + [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); } + ExecContext* ctx_; + const Table& table_; + const SortOptions& options_; uint64_t* indices_begin_; uint64_t* indices_end_; Status status_; std::vector sort_keys_; - NullPlacement null_placement_; + const NullPlacement null_placement_; Comparator comparator_; }; @@ -1990,16 +2104,9 @@ class SortIndicesMetaFunction : public MetaFunction { auto out_end = out_begin + length; std::iota(out_begin, out_end, 0); - // TODO: We should choose suitable sort implementation - // automatically. The current TableRadixSorter implementation is - // faster than MultipleKeyTableSorter only when the number of - // sort keys is 2 and counting sort is used. So we always - // MultipleKeyTableSorter for now. - // - // TableRadixSorter sorter; - // ARROW_RETURN_NOT_OK(sorter.Sort(ctx, out_begin, out_end, table, options)); - MultipleKeyTableSorter sorter(out_begin, out_end, table, options); - ARROW_RETURN_NOT_OK(sorter.Sort()); + TableSorter sorter(ctx, out_begin, out_end, table, options); + RETURN_NOT_OK(sorter.Sort()); + return Datum(out); } }; @@ -2365,7 +2472,7 @@ class RecordBatchSelecter : public TypeVisitor { class TableSelecter : public TypeVisitor { private: - using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey; + using ResolvedSortKey = TableSorter::ResolvedSortKey; using Comparator = MultipleKeyComparator; public: diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index acdf894aa53..811060b660f 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1182,22 +1182,31 @@ TEST_F(TestRecordBatchSortIndices, NullType) { field("f", int32()), field("g", int32()), field("h", int32()), - field("i", int32()), + field("i", null()), }); auto batch = RecordBatchFromJSON(schema, R"([ - {"a": null, "b": 5, "c": 0, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": 5}, - {"a": null, "b": 5, "c": 1, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": 5}, - {"a": null, "b": 2, "c": 2, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": 5}, - {"a": null, "b": 4, "c": 3, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": 5} + {"a": null, "b": 5, "c": 0, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null}, + {"a": null, "b": 5, "c": 1, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null}, + {"a": null, "b": 2, "c": 2, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null}, + {"a": null, "b": 4, "c": 3, "d": 0, "e": 1, "f": 2, "g": 3, "h": 4, "i": null} ])"); for (const auto null_placement : AllNullPlacements()) { for (const auto order : AllOrders()) { // Uses radix sorter + AssertSortIndices(batch, + SortOptions( + { + SortKey("a", order), + SortKey("i", order), + }, + null_placement), + "[0, 1, 2, 3]"); AssertSortIndices(batch, SortOptions( { SortKey("a", order), SortKey("b", SortOrder::Ascending), + SortKey("i", order), }, null_placement), "[2, 3, 0, 1]"); @@ -1213,7 +1222,7 @@ TEST_F(TestRecordBatchSortIndices, NullType) { SortKey("f", SortOrder::Ascending), SortKey("g", SortOrder::Ascending), SortKey("h", SortOrder::Ascending), - SortKey("i", SortOrder::Ascending), + SortKey("i", order), }, null_placement), "[2, 3, 0, 1]"); @@ -1444,23 +1453,33 @@ TEST_F(TestTableSortIndices, NullType) { field("a", null()), field("b", int32()), field("c", int32()), + field("d", null()), }); auto table = TableFromJSON(schema, { R"([ - {"a": null, "b": 5, "c": 0}, - {"a": null, "b": 5, "c": 1}, - {"a": null, "b": 2, "c": 2} + {"a": null, "b": 5, "c": 0, "d": null}, + {"a": null, "b": 5, "c": 1, "d": null}, + {"a": null, "b": 2, "c": 2, "d": null} ])", R"([])", - R"([{"a": null, "b": 4, "c": 3}])", + R"([{"a": null, "b": 4, "c": 3, "d": null}])", }); for (const auto null_placement : AllNullPlacements()) { for (const auto order : AllOrders()) { + AssertSortIndices(table, + SortOptions( + { + SortKey("a", order), + SortKey("d", order), + }, + null_placement), + "[0, 1, 2, 3]"); AssertSortIndices(table, SortOptions( { SortKey("a", order), SortKey("b", SortOrder::Ascending), + SortKey("d", order), }, null_placement), "[2, 3, 0, 1]");