diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 9e53cfcf640..ca3afc61285 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -172,6 +172,7 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { int64_t k; /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; + NullPlacement null_placement = NullPlacement::AtEnd; }; /// \brief Partitioning options for NthToIndices diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.h b/cpp/src/arrow/compute/kernels/chunked_internal.h index b007d6cbfb8..01bc9a2696b 100644 --- a/cpp/src/arrow/compute/kernels/chunked_internal.h +++ b/cpp/src/arrow/compute/kernels/chunked_internal.h @@ -41,6 +41,8 @@ struct ResolvedChunk { // The index in the target array. const int64_t index; + ResolvedChunk() : index(-1) {} + ResolvedChunk(const ArrayType* array, int64_t index) : array(array), index(index) {} bool IsNull() const { return array->IsNull(index); } @@ -139,11 +141,20 @@ struct ChunkedArrayResolver : protected ChunkResolver { template ResolvedChunk Resolve(int64_t index) const { - const auto loc = ChunkResolver::Resolve(index); + const auto loc = ResolveChunkLocation(index); + return Resolve(loc); + } + + template + ResolvedChunk Resolve(ChunkLocation loc) const { return ResolvedChunk( checked_cast(chunks_[loc.chunk_index]), loc.index_in_chunk); } + ChunkLocation ResolveChunkLocation(int64_t index) const { + return ChunkResolver::Resolve(index); + } + protected: static std::vector MakeLengths(const std::vector& chunks) { std::vector lengths(chunks.size()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index dd5bead58aa..39184d7a8f3 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -52,6 +52,14 @@ struct SortField { SortOrder order; }; +struct LocationContainer { + public: + LocationContainer() {} + LocationContainer(uint64_t index) : index_(index), chunk_location_({-1, -1}) {} + uint64_t index_; + ChunkLocation chunk_location_; +}; + Status CheckNonNested(const FieldRef& ref) { if (ref.IsNested()) { return Status::KeyError("Nested keys not supported for SortKeys"); @@ -1682,7 +1690,7 @@ class TableSelecter : public TypeVisitor { null_count(chunked_array->null_count()), resolver(chunk_pointers) {} - using LocationType = int64_t; + using LocationType = ChunkLocation; // Find the target chunk and index in the target chunk from an // index in chunked array. @@ -1691,6 +1699,20 @@ class TableSelecter : public TypeVisitor { return resolver.Resolve(index); } + template + ResolvedChunk GetChunk(LocationContainer& location_container) const { + if (location_container.chunk_location_.chunk_index == -1){ + location_container.chunk_location_ = resolver.ResolveChunkLocation(location_container.index_); + } + auto loc = location_container.chunk_location_; + return resolver.Resolve(loc); + } + + template + ResolvedChunk GetChunk(ChunkLocation loc) const { + return resolver.Resolve(loc); + } + const SortOrder order; const std::shared_ptr type; const ArrayVector chunks; @@ -1709,11 +1731,16 @@ class TableSelecter : public TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(table, options.sort_keys)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + batches_(TableSelecter::MakeBatches(table, &status_)), + left_resolver_(ChunkResolver::FromBatches(batches_)), + right_resolver_(ChunkResolver::FromBatches(batches_)), + comparator_(sort_keys_, NullPlacement::AtEnd), + options_(options) {} Status Run() { return sort_keys_[0].type->Accept(this); } protected: + #define VISIT(TYPE) \ Status Visit(const TYPE& type) { \ if (sort_keys_[0].order == SortOrder::Descending) \ @@ -1723,6 +1750,14 @@ class TableSelecter : public TypeVisitor { VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) #undef VISIT + static RecordBatchVector MakeBatches(const Table& table, Status* status) { + const auto maybe_batches = BatchesFromTable(table); + if (!maybe_batches.ok()) { + *status = maybe_batches.status(); + return {}; + } + return *std::move(maybe_batches); + } static std::vector ResolveSortKeys( const Table& table, const std::vector& sort_keys) { @@ -1734,45 +1769,8 @@ class TableSelecter : public TypeVisitor { return resolved; } - // Behaves like PartitionNulls() but this supports multiple sort keys. - template - NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, - uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - - const auto p = PartitionNullsOnly( - indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - NullPlacement::AtEnd); - DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); - - const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, - NullPlacement::AtEnd); - - auto& comparator = comparator_; - // Sort all NaNs by the second and following sort keys. - std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - // Sort all nulls by the second and following sort keys. - std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - - return q; - } - - // XXX this implementation is rather inefficient as it computes chunk indices - // at every comparison. Instead we should iterate over individual batches - // and remember ChunkLocation entries in the max-heap. - template Status SelectKthInternal() { - using ArrayType = typename TypeTraits::ArrayType; - auto& comparator = comparator_; - const auto& first_sort_key = sort_keys_[0]; - const auto num_rows = table_.num_rows(); if (num_rows == 0) { return Status::OK(); @@ -1780,59 +1778,222 @@ class TableSelecter : public TypeVisitor { if (k_ > table_.num_rows()) { k_ = table_.num_rows(); } - std::function cmp; - SelectKComparator select_k_comparator; - cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { - auto chunk_left = first_sort_key.template GetChunk(left); - auto chunk_right = first_sort_key.template GetChunk(right); - auto value_left = chunk_left.Value(); - auto value_right = chunk_right.Value(); - if (value_left == value_right) { - return comparator.Compare(left, right, 1); - } - return select_k_comparator(value_left, value_right); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - - std::vector indices(num_rows); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); + // ALVIN NEW CODE + SortOptions sort_options(options_.sort_keys, NullPlacement::AtEnd); + ARROW_ASSIGN_OR_RAISE(auto row_indices, + MakeMutableUInt64Array(uint64(), table_.num_rows(), ctx_->memory_pool())); + auto indices_begin = row_indices->GetMutableValues(1); + auto indices_end = indices_begin + table_.num_rows(); std::iota(indices_begin, indices_end, 0); - const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto end_iter = p.non_nulls_end; - auto kth_begin = std::min(indices_begin + k_, end_iter); + RecordBatchVector batches; + { + TableBatchReader reader(table_); + RETURN_NOT_OK(reader.ReadAll(&batches)); + } + const int64_t num_batches = static_cast(batches.size()); + if (num_batches == 0) { + return Status::OK(); + } + std::vector sorted(num_batches); - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - uint64_t top_item = heap.top(); - if (cmp(x_index, top_item)) { - heap.pop(); - heap.push(x_index); - } + // First sort all individual batches + int64_t begin_offset = 0; + int64_t end_offset = 0; + int64_t null_count = 0; + for (int64_t i = 0; i < num_batches; ++i) { + const auto& batch = *batches[i]; + end_offset += batch.num_rows(); + RadixRecordBatchSorter sorter(indices_begin + begin_offset, + indices_begin + end_offset, batch, sort_options); + ARROW_ASSIGN_OR_RAISE(sorted[i], sorter.Sort(begin_offset)); + DCHECK_EQ(sorted[i].overall_begin(), indices_begin + begin_offset); + DCHECK_EQ(sorted[i].overall_end(), indices_begin + end_offset); + DCHECK_EQ(sorted[i].non_null_count() + sorted[i].null_count(), batch.num_rows()); + begin_offset = end_offset; + // XXX this is an upper bound on the true null count + null_count += sorted[i].null_count(); } - int64_t out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, - ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; + DCHECK_EQ(end_offset, indices_end - indices_begin); + + if (sorted.size() > 1) { + struct Visitor { + TableSelecter* sorter; + std::vector* sorted; + int64_t null_count; + + Status Visit(const InType& type) { + return sorter->MergeInternal(std::move(*sorted), null_count); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("Unsupported type for sorting: ", + type.ToString()); + } + }; + Visitor visitor{this, &sorted, null_count}; + RETURN_NOT_OK(VisitTypeInline(*sort_keys_[0].type, &visitor)); } + + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), k_, + ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1); + std::copy(indices_begin, indices_begin + k_ , out_cbegin); *output_ = Datum(take_indices); return Status::OK(); } + // Recursive merge routine, typed on the first sort key + template + Status MergeInternal(std::vector sorted, int64_t null_count) { + auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle, + uint64_t* nulls_end, uint64_t* temp_indices, + int64_t null_count) { + MergeNulls(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count); + }; + auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices) { + MergeNonNulls(range_begin, range_middle, range_end, temp_indices); + }; + + MergeImpl merge_impl(options_.null_placement, std::move(merge_nulls), + std::move(merge_non_nulls)); + RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); + + while (sorted.size() > 1) { + auto out_it = sorted.begin(); + auto it = sorted.begin(); + while (it < sorted.end() - 1) { + const auto& left = *it++; + const auto& right = *it++; + DCHECK_EQ(left.overall_end(), right.overall_begin()); + *out_it++ = merge_impl.Merge(left, right, null_count); + } + if (it < sorted.end()) { + *out_it++ = *it++; + } + sorted.erase(out_it, sorted.end()); + } + DCHECK_EQ(sorted.size(), 1); + return comparator_.status(); + } + // Merge rows with a null or a null-like in the first sort key + template + enable_if_t::value> MergeNulls(uint64_t* nulls_begin, + uint64_t* nulls_middle, + uint64_t* nulls_end, + uint64_t* temp_indices, + int64_t null_count) { + using ArrayType = typename TypeTraits::ArrayType; + + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + + std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // First column is either null or nan + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + auto chunk_left = first_sort_key.GetChunk(left_loc); + auto chunk_right = first_sort_key.GetChunk(right_loc); + const auto left_is_null = chunk_left.IsNull(); + const auto right_is_null = chunk_right.IsNull(); + if (left_is_null == right_is_null) { + return comparator.Compare(left_loc, right_loc, 1); + } else if (options_.null_placement == NullPlacement::AtEnd) { + return right_is_null; + } else { + return left_is_null; + } + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin); + } + + template + enable_if_t::value> MergeNulls(uint64_t* nulls_begin, + uint64_t* nulls_middle, + uint64_t* nulls_end, + uint64_t* temp_indices, + int64_t null_count) { + MergeNullsOnly(nulls_begin, nulls_middle, nulls_end, temp_indices, null_count); + } + + void MergeNullsOnly(uint64_t* nulls_begin, uint64_t* nulls_middle, uint64_t* nulls_end, + uint64_t* temp_indices, int64_t null_count) { + // Untyped implementation + auto& comparator = comparator_; + + std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // First column is always null + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + return comparator.Compare(left_loc, right_loc, 1); + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin); + } + + // + // Merge rows with a non-null in the first sort key + // + template + enable_if_t::value> MergeNonNulls(uint64_t* range_begin, + uint64_t* range_middle, + uint64_t* range_end, + uint64_t* temp_indices) { + using ArrayType = typename TypeTraits::ArrayType; + + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + + std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, + [&](uint64_t left, uint64_t right) { + // Both values are never null nor NaN. + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + auto chunk_left = first_sort_key.GetChunk(left_loc); + auto chunk_right = first_sort_key.GetChunk(right_loc); + DCHECK(!chunk_left.IsNull()); + DCHECK(!chunk_right.IsNull()); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left_loc, right_loc, 1); + } else { + auto compared = value_left < value_right; + if (first_sort_key.order == SortOrder::Ascending) { + return compared; + } else { + return !compared; + } + } + }); + // Copy back temp area into main buffer + std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); + } + + template + enable_if_null MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, + uint64_t* range_end, uint64_t* temp_indices) { + const int64_t null_count = range_end - range_begin; + MergeNullsOnly(range_begin, range_middle, range_end, temp_indices, null_count); + } + ExecContext* ctx_; const Table& table_; int64_t k_; Datum* output_; std::vector sort_keys_; + const RecordBatchVector batches_; + const ChunkResolver left_resolver_, right_resolver_; + Status status_; Comparator comparator_; + const SelectKOptions& options_; }; static Status CheckConsistency(const Schema& schema, diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index d8b024525c8..7742d366329 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -342,6 +342,11 @@ struct MergeImpl { std::function; + using MergeGeneralFunc = + std::function; + + MergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls, MergeNonNullsFunc&& merge_non_nulls) : null_placement_(null_placement),