diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 83aa40a23a0..6276f96aef4 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -93,20 +93,26 @@ Result> FindSortKeys(const Schema& schema, return fields; } -template +template Result> ResolveSortKeys( - const TableOrBatch& table_or_batch, const std::vector& sort_keys) { - ARROW_ASSIGN_OR_RAISE(const auto fields, - FindSortKeys(*table_or_batch.schema(), sort_keys)); + const Schema& schema, const std::vector& sort_keys, + ResolvedSortKeyFactory&& factory) { + ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys)); std::vector resolved; resolved.reserve(fields.size()); - std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), - [&](const SortField& f) { - return ResolvedSortKey{table_or_batch.column(f.field_index), f.order}; - }); + std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), factory); return resolved; } +template +Result> ResolveSortKeys( + const TableOrBatch& table_or_batch, const std::vector& sort_keys) { + return ResolveSortKeys( + *table_or_batch.schema(), sort_keys, [&](const SortField& f) { + return ResolvedSortKey{table_or_batch.column(f.field_index), f.order}; + }); +} + // The target chunk in a chunked array. template struct ResolvedChunk { @@ -138,16 +144,18 @@ struct ResolvedChunk { bool IsNull() const { return array->IsNull(index); } }; +struct ChunkLocation { + int64_t chunk_index, index_in_chunk; +}; + // An object that resolves an array chunk depending on the index. 
-struct ChunkedArrayResolver { - explicit ChunkedArrayResolver(const std::vector& chunks) - : num_chunks_(static_cast(chunks.size())), - chunks_(chunks), - offsets_(MakeEndOffsets(chunks)), +struct ChunkResolver { + explicit ChunkResolver(std::vector lengths) + : num_chunks_(static_cast(lengths.size())), + offsets_(MakeEndOffsets(std::move(lengths))), cached_chunk_(0) {} - template - ResolvedChunk Resolve(int64_t index) const { + ChunkLocation Resolve(int64_t index) const { // It is common for the algorithms below to make consecutive accesses at // a relatively small distance from each other, hence often falling in // the same chunk. @@ -157,17 +165,22 @@ struct ChunkedArrayResolver { const bool cache_hit = (index >= offsets_[cached_chunk_] && index < offsets_[cached_chunk_ + 1]); if (ARROW_PREDICT_TRUE(cache_hit)) { - return ResolvedChunk( - checked_cast(chunks_[cached_chunk_]), - index - offsets_[cached_chunk_]); + return {cached_chunk_, index - offsets_[cached_chunk_]}; } else { - return ResolveMissBisect(index); + return ResolveMissBisect(index); } } - private: - template - ResolvedChunk ResolveMissBisect(int64_t index) const { + static ChunkResolver FromBatches(const RecordBatchVector& batches) { + std::vector lengths(batches.size()); + std::transform( + batches.begin(), batches.end(), lengths.begin(), + [](const std::shared_ptr& batch) { return batch->num_rows(); }); + return ChunkResolver(std::move(lengths)); + } + + protected: + ChunkLocation ResolveMissBisect(int64_t index) const { // Like std::upper_bound(), but hand-written as it can help the compiler. 
const int64_t* raw_offsets = offsets_.data(); // Search [lo, lo + n) @@ -183,29 +196,48 @@ struct ChunkedArrayResolver { } } cached_chunk_ = lo; - return ResolvedChunk(checked_cast(chunks_[lo]), - index - offsets_[lo]); + return {lo, index - offsets_[lo]}; } - static std::vector MakeEndOffsets(const std::vector& chunks) { - std::vector end_offsets(chunks.size() + 1); + static std::vector MakeEndOffsets(std::vector lengths) { int64_t offset = 0; - end_offsets[0] = 0; - std::transform(chunks.begin(), chunks.end(), end_offsets.begin() + 1, - [&](const Array* chunk) { - offset += chunk->length(); - return offset; - }); - return end_offsets; + for (auto& v : lengths) { + const auto this_length = v; + v = offset; + offset += this_length; + } + lengths.push_back(offset); + return lengths; } int64_t num_chunks_; - const std::vector chunks_; std::vector offsets_; mutable int64_t cached_chunk_; }; +struct ChunkedArrayResolver : protected ChunkResolver { + explicit ChunkedArrayResolver(const std::vector& chunks) + : ChunkResolver(MakeLengths(chunks)), chunks_(chunks) {} + + template + ResolvedChunk Resolve(int64_t index) const { + const auto loc = ChunkResolver::Resolve(index); + return ResolvedChunk( + checked_cast(chunks_[loc.chunk_index]), loc.index_in_chunk); + } + + protected: + static std::vector MakeLengths(const std::vector& chunks) { + std::vector lengths(chunks.size()); + std::transform(chunks.begin(), chunks.end(), lengths.begin(), + [](const Array* arr) { return arr->length(); }); + return lengths; + } + + const std::vector chunks_; +}; + // We could try to reproduce the concrete Array classes' facilities // (such as cached raw values pointer) in a separate hierarchy of // physical accessors, but doing so ends up too cumbersome. 
@@ -217,9 +249,8 @@ std::shared_ptr GetPhysicalArray(const Array& array, return MakeArray(std::move(new_data)); } -ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array, +ArrayVector GetPhysicalChunks(const ArrayVector& chunks, const std::shared_ptr& physical_type) { - const auto& chunks = chunked_array.chunks(); ArrayVector physical(chunks.size()); std::transform(chunks.begin(), chunks.end(), physical.begin(), [&](const std::shared_ptr& array) { @@ -228,6 +259,11 @@ ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array, return physical; } +ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array, + const std::shared_ptr& physical_type) { + return GetPhysicalChunks(chunked_array.chunks(), physical_type); +} + std::vector GetArrayPointers(const ArrayVector& arrays) { std::vector pointers(arrays.size()); std::transform(arrays.begin(), arrays.end(), pointers.begin(), @@ -235,6 +271,13 @@ std::vector GetArrayPointers(const ArrayVector& arrays) { return pointers; } +Result BatchesFromTable(const Table& table) { + RecordBatchVector batches; + TableBatchReader reader(table); + RETURN_NOT_OK(reader.ReadAll(&batches)); + return batches; +} + // NOTE: std::partition is usually faster than std::stable_partition. struct NonStablePartitioner { @@ -252,7 +295,65 @@ struct StablePartitioner { } }; -// TODO factor out value comparison and NaN checking? 
+// Compare two values, taking NaNs into account + +template +struct ValueComparator; + +template +struct ValueComparator::value>> { + template + static int Compare(const Value& left, const Value& right, SortOrder order, + NullPlacement null_placement) { + int compared; + if (left == right) { + compared = 0; + } else if (left > right) { + compared = 1; + } else { + compared = -1; + } + if (order == SortOrder::Descending) { + compared = -compared; + } + return compared; + } +}; + +template +struct ValueComparator::value>> { + template + static int Compare(const Value& left, const Value& right, SortOrder order, + NullPlacement null_placement) { + const bool is_nan_left = std::isnan(left); + const bool is_nan_right = std::isnan(right); + if (is_nan_left && is_nan_right) { + return 0; + } else if (is_nan_left) { + return null_placement == NullPlacement::AtStart ? -1 : 1; + } else if (is_nan_right) { + return null_placement == NullPlacement::AtStart ? 1 : -1; + } + int compared; + if (left == right) { + compared = 0; + } else if (left > right) { + compared = 1; + } else { + compared = -1; + } + if (order == SortOrder::Descending) { + compared = -compared; + } + return compared; + } +}; + +template +int CompareTypeValues(const Value& left, const Value& right, SortOrder order, + NullPlacement null_placement) { + return ValueComparator::Compare(left, right, order, null_placement); +} template struct NullTraits { @@ -1365,73 +1466,125 @@ class RadixRecordBatchSorter { uint64_t* indices_end_; }; +// Compare two records in a single column (either from a batch or table) +template +struct ColumnComparator { + using Location = typename ResolvedSortKey::LocationType; + + ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement null_placement) + : sort_key_(sort_key), null_placement_(null_placement) {} + + virtual ~ColumnComparator() = default; + + virtual int Compare(const Location& left, const Location& right) const = 0; + + ResolvedSortKey sort_key_; + NullPlacement 
null_placement_; +}; + +template +struct ConcreteColumnComparator : public ColumnComparator { + using ArrayType = typename TypeTraits::ArrayType; + using Location = typename ResolvedSortKey::LocationType; + + using ColumnComparator::ColumnComparator; + + int Compare(const Location& left, const Location& right) const override { + const auto& sort_key = this->sort_key_; + + const auto chunk_left = sort_key.template GetChunk(left); + const auto chunk_right = sort_key.template GetChunk(right); + if (sort_key.null_count > 0) { + const bool is_null_left = chunk_left.IsNull(); + const bool is_null_right = chunk_right.IsNull(); + if (is_null_left && is_null_right) { + return 0; + } else if (is_null_left) { + return this->null_placement_ == NullPlacement::AtStart ? -1 : 1; + } else if (is_null_right) { + return this->null_placement_ == NullPlacement::AtStart ? 1 : -1; + } + } + return CompareTypeValues(chunk_left.Value(), chunk_right.Value(), + sort_key.order, this->null_placement_); + } +}; + +template +struct ConcreteColumnComparator + : public ColumnComparator { + using Location = typename ResolvedSortKey::LocationType; + + using ColumnComparator::ColumnComparator; + + int Compare(const Location& left, const Location& right) const override { return 0; } +}; + // Compare two records in the same RecordBatch or Table // (indexing is handled through ResolvedSortKey) template class MultipleKeyComparator { public: + using Location = typename ResolvedSortKey::LocationType; + MultipleKeyComparator(const std::vector& sort_keys, NullPlacement null_placement) - : sort_keys_(sort_keys), null_placement_(null_placement) {} + : sort_keys_(sort_keys), null_placement_(null_placement) { + status_ &= MakeComparators(); + } Status status() const { return status_; } // Returns true if the left-th value should be ordered before the // right-th value, false otherwise. The start_sort_key_index-th // sort key and subsequent sort keys are used for comparison. 
- bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) { - current_left_ = left; - current_right_ = right; - current_compared_ = 0; - auto num_sort_keys = sort_keys_.size(); - for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { - current_sort_key_index_ = i; - status_ = VisitTypeInline(*sort_keys_[i].type, this); - // If the left value equals to the right value, we need to - // continue to sort. - if (current_compared_ != 0) { - break; - } - } - return current_compared_ < 0; + bool Compare(const Location& left, const Location& right, size_t start_sort_key_index) { + return CompareInternal(left, right, start_sort_key_index) < 0; } - bool Equals(uint64_t left, uint64_t right, size_t start_sort_key_index) { - current_left_ = left; - current_right_ = right; - current_compared_ = 0; - auto num_sort_keys = sort_keys_.size(); - for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { - current_sort_key_index_ = i; - status_ = VisitTypeInline(*sort_keys_[i].type, this); - // If the left value equals to the right value, we need to - // continue to sort. 
- if (current_compared_ != 0) { - return false; - } - } - return current_compared_ == 0; + bool Equals(const Location& left, const Location& right, size_t start_sort_key_index) { + return CompareInternal(left, right, start_sort_key_index) == 0; } -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { \ - current_compared_ = CompareType(); \ - return Status::OK(); \ - } + private: + struct ColumnComparatorFactory { +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return VisitGeneric(type); } - VISIT_PHYSICAL_TYPES(VISIT) + VISIT_PHYSICAL_TYPES(VISIT) + VISIT(NullType) #undef VISIT - Status Visit(const NullType& type) { return Status::OK(); } + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported type for batch or table sorting: ", + type.ToString()); + } + + template + Status VisitGeneric(const Type& type) { + res.reset( + new ConcreteColumnComparator{sort_key, null_placement}); + return Status::OK(); + } + + const ResolvedSortKey& sort_key; + NullPlacement null_placement; + std::unique_ptr> res; + }; + + Status MakeComparators() { + column_comparators_.reserve(sort_keys_.size()); - Status Visit(const DataType& type) { - return Status::TypeError("Unsupported type for RecordBatch sorting: ", - type.ToString()); + for (const auto& sort_key : sort_keys_) { + ColumnComparatorFactory factory{sort_key, null_placement_, nullptr}; + RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory)); + column_comparators_.push_back(std::move(factory.res)); + } + return Status::OK(); } - private: - // Compares two records in the same table and returns -1, 0 or 1. + // Compare two records in the same table and return -1, 0 or 1. // // -1: The left is less than the right. // 0: The left equals to the right. @@ -1439,87 +1592,22 @@ class MultipleKeyComparator { // - // This supports null and NaN. Null is processed in this and NaN - // is processed in CompareTypeValue(). + // This supports null and NaN. Nulls are handled by the per-column + // comparators and NaNs by CompareTypeValues().
- template - int32_t CompareType() { - using ArrayType = typename TypeTraits::ArrayType; - const auto& sort_key = sort_keys_[current_sort_key_index_]; - auto order = sort_key.order; - const auto chunk_left = sort_key.template GetChunk(current_left_); - const auto chunk_right = sort_key.template GetChunk(current_right_); - if (sort_key.null_count > 0) { - const bool is_null_left = chunk_left.IsNull(); - const bool is_null_right = chunk_right.IsNull(); - if (is_null_left && is_null_right) { - return 0; - } else if (is_null_left) { - return null_placement_ == NullPlacement::AtStart ? -1 : 1; - } else if (is_null_right) { - return null_placement_ == NullPlacement::AtStart ? 1 : -1; + int CompareInternal(const Location& left, const Location& right, + size_t start_sort_key_index) { + const auto num_sort_keys = sort_keys_.size(); + for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { + const int r = column_comparators_[i]->Compare(left, right); + if (r != 0) { + return r; } } - return CompareTypeValue(chunk_left, chunk_right, order); - } - - // For non-float types. Value is never NaN. - template - enable_if_t::value, int32_t> CompareTypeValue( - const ResolvedChunk::ArrayType>& chunk_left, - const ResolvedChunk::ArrayType>& chunk_right, - const SortOrder order) { - const auto left = chunk_left.Value(); - const auto right = chunk_right.Value(); - int32_t compared; - if (left == right) { - compared = 0; - } else if (left > right) { - compared = 1; - } else { - compared = -1; - } - if (order == SortOrder::Descending) { - compared = -compared; - } - return compared; - } - - // For float types. Value may be NaN. 
- template - enable_if_t::value, int32_t> CompareTypeValue( - const ResolvedChunk::ArrayType>& chunk_left, - const ResolvedChunk::ArrayType>& chunk_right, - const SortOrder order) { - const auto left = chunk_left.Value(); - const auto right = chunk_right.Value(); - const bool is_nan_left = std::isnan(left); - const bool is_nan_right = std::isnan(right); - if (is_nan_left && is_nan_right) { - return 0; - } else if (is_nan_left) { - return null_placement_ == NullPlacement::AtStart ? -1 : 1; - } else if (is_nan_right) { - return null_placement_ == NullPlacement::AtStart ? 1 : -1; - } - int32_t compared; - if (left == right) { - compared = 0; - } else if (left > right) { - compared = 1; - } else { - compared = -1; - } - if (order == SortOrder::Descending) { - compared = -compared; - } - return compared; + return 0; } const std::vector& sort_keys_; const NullPlacement null_placement_; + std::vector>> column_comparators_; Status status_; - int64_t current_left_; - int64_t current_right_; - size_t current_sort_key_index_; - int32_t current_compared_; }; // Sort a batch using a single sort and multiple-key comparisons. @@ -1534,6 +1622,8 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { order(order), null_count(array->null_count()) {} + using LocationType = int64_t; + template ResolvedChunk GetChunk(int64_t index) const { return {&checked_cast(array), index}; @@ -1676,47 +1766,68 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { // that batch columns are contiguous and therefore have less indexing // overhead), then sorted batches are merged recursively. class TableSorter { - // TODO instead of resolving chunks for each column independently, we could - // rely on RecordBatches and resolve the first column only. - public: // Preprocessed sort key. 
struct ResolvedSortKey { - ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order) - : order(order), - type(GetPhysicalType(chunked_array->type())), - chunks(GetPhysicalChunks(*chunked_array, type)), - chunk_pointers(GetArrayPointers(chunks)), - null_count(chunked_array->null_count()), - num_chunks(chunked_array->num_chunks()), - resolver(chunk_pointers) {} + ResolvedSortKey(const std::shared_ptr& type, ArrayVector chunks, + SortOrder order, int64_t null_count) + : type(GetPhysicalType(type)), + owned_chunks(std::move(chunks)), + chunks(GetArrayPointers(owned_chunks)), + order(order), + null_count(null_count) {} + + using LocationType = ChunkLocation; - // Find the target chunk and index in the target chunk from an - // index in chunked array. template - ResolvedChunk GetChunk(int64_t index) const { - return resolver.Resolve(index); + ResolvedChunk GetChunk(ChunkLocation loc) const { + return {checked_cast(chunks[loc.chunk_index]), + loc.index_in_chunk}; + } + + // Make a vector of ResolvedSortKeys for the sort keys and the given table. + // `batches` must be a chunking of `table`. 
+ static Result> Make( + const Table& table, const RecordBatchVector& batches, + const std::vector& sort_keys) { + auto factory = [&](const SortField& f) { + const auto& type = table.schema()->field(f.field_index)->type(); + // We must expose a homogenous chunking for all ResolvedSortKey, + // so we can't simply pass `table.column(f.field_index)` + ArrayVector chunks(batches.size()); + std::transform(batches.begin(), batches.end(), chunks.begin(), + [&](const std::shared_ptr& batch) { + return batch->column(f.field_index); + }); + return ResolvedSortKey(type, std::move(chunks), f.order, + table.column(f.field_index)->null_count()); + }; + + return ::arrow::compute::internal::ResolveSortKeys( + *table.schema(), sort_keys, factory); } - const SortOrder order; - const std::shared_ptr type; - const ArrayVector chunks; - const std::vector chunk_pointers; - const int64_t null_count; - const int num_chunks; - const ChunkedArrayResolver resolver; + std::shared_ptr type; + ArrayVector owned_chunks; + std::vector chunks; + SortOrder order; + int64_t null_count; }; + // TODO make all methods const and defer initialization into a Init() method? + TableSorter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Table& table, const SortOptions& options) : ctx_(ctx), table_(table), + batches_(MakeBatches(table, &status_)), options_(options), + null_placement_(options.null_placement), + left_resolver_(ChunkResolver::FromBatches(batches_)), + right_resolver_(ChunkResolver::FromBatches(batches_)), + sort_keys_(ResolveSortKeys(table, batches_, options.sort_keys, &status_)), indices_begin_(indices_begin), indices_end_(indices_end), - sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - null_placement_(options.null_placement), comparator_(sort_keys_, null_placement_) {} // This is optimized for null partitioning and merging along the first sort key. 
@@ -1729,10 +1840,19 @@ class TableSorter { private: using Comparator = MultipleKeyComparator; + static RecordBatchVector MakeBatches(const Table& table, Status* status) { + auto maybe_batches = BatchesFromTable(table); // non-const so the value can be moved out below + if (!maybe_batches.ok()) { + *status = maybe_batches.status(); + return {}; + } + return *std::move(maybe_batches); + } + static std::vector ResolveSortKeys( - const Table& table, const std::vector& sort_keys, Status* status) { - const auto maybe_resolved = - ::arrow::compute::internal::ResolveSortKeys(table, sort_keys); + const Table& table, const RecordBatchVector& batches, + const std::vector& sort_keys, Status* status) { + const auto maybe_resolved = ResolvedSortKey::Make(table, batches, sort_keys); if (!maybe_resolved.ok()) { *status = maybe_resolved.status(); return {}; @@ -1851,12 +1971,14 @@ class TableSorter { std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, [&](uint64_t left, uint64_t right) { // First column is either null or nan - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + auto chunk_left = first_sort_key.GetChunk(left_loc); + auto chunk_right = first_sort_key.GetChunk(right_loc); const auto left_is_null = chunk_left.IsNull(); const auto right_is_null = chunk_right.IsNull(); if (left_is_null == right_is_null) { - return comparator.Compare(left, right, 1); + return comparator.Compare(left_loc, right_loc, 1); } else if (options_.null_placement == NullPlacement::AtEnd) { return right_is_null; } else { @@ -1884,7 +2006,9 @@ class TableSorter { std::merge(nulls_begin, nulls_middle, nulls_middle, nulls_end, temp_indices, [&](uint64_t left, uint64_t right) { // First column is always null - return comparator.Compare(left, right, 1); + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc =
right_resolver_.Resolve(right); + return comparator.Compare(left_loc, right_loc, 1); }); // Copy back temp area into main buffer std::copy(temp_indices, temp_indices + (nulls_end - nulls_begin), nulls_begin); @@ -1906,8 +2030,10 @@ class TableSorter { std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, [&](uint64_t left, uint64_t right) { // Both values are never null nor NaN. - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); + const auto left_loc = left_resolver_.Resolve(left); + const auto right_loc = right_resolver_.Resolve(right); + auto chunk_left = first_sort_key.GetChunk(left_loc); + auto chunk_right = first_sort_key.GetChunk(right_loc); DCHECK(!chunk_left.IsNull()); DCHECK(!chunk_right.IsNull()); auto value_left = chunk_left.Value(); @@ -1916,7 +2042,7 @@ class TableSorter { // If the left value equals to the right value, // we need to compare the second and following // sort keys. - return comparator.Compare(left, right, 1); + return comparator.Compare(left_loc, right_loc, 1); } else { auto compared = value_left < value_right; if (first_sort_key.order == SortOrder::Ascending) { @@ -1933,24 +2059,21 @@ class TableSorter { template enable_if_null MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, uint64_t* range_end, uint64_t* temp_indices) { - auto& comparator = comparator_; - - std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, - [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); + const int64_t null_count = range_end - range_begin; + MergeNullsOnly(range_begin, range_middle, range_end, temp_indices, null_count); } ExecContext* ctx_; const Table& table_; + Status status_; // Keep before batches_ and sort_keys_: their initializers write to it + const RecordBatchVector batches_; const SortOptions& options_; + const NullPlacement null_placement_; + const ChunkResolver left_resolver_, right_resolver_; + const
std::vector sort_keys_; uint64_t* indices_begin_; uint64_t* indices_end_; - Status status_; - std::vector sort_keys_; - const NullPlacement null_placement_; Comparator comparator_; }; // ---------------------------------------------------------------------- @@ -2472,7 +2595,32 @@ class RecordBatchSelecter : public TypeVisitor { class TableSelecter : public TypeVisitor { private: - using ResolvedSortKey = TableSorter::ResolvedSortKey; + struct ResolvedSortKey { + ResolvedSortKey(const std::shared_ptr& chunked_array, + const SortOrder order) + : order(order), + type(GetPhysicalType(chunked_array->type())), + chunks(GetPhysicalChunks(*chunked_array, type)), + chunk_pointers(GetArrayPointers(chunks)), + null_count(chunked_array->null_count()), + resolver(chunk_pointers) {} + + using LocationType = int64_t; + + // Find the target chunk and index in the target chunk from an + // index in chunked array. + template + ResolvedChunk GetChunk(int64_t index) const { + return resolver.Resolve(index); + } + + const SortOrder order; + const std::shared_ptr type; + const ArrayVector chunks; + const std::vector chunk_pointers; + const int64_t null_count; + const ChunkedArrayResolver resolver; + }; using Comparator = MultipleKeyComparator; public: @@ -2538,6 +2686,10 @@ class TableSelecter : public TypeVisitor { return q; } + // XXX this implementation is rather inefficient as it computes chunk indices + // at every comparison. Instead we should iterate over individual batches + // and remember ChunkLocation entries in the max-heap.
+ template Status SelectKthInternal() { using ArrayType = typename TypeTraits::ArrayType; diff --git a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc index d8e3b9b8081..6ab0bcfde97 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc @@ -127,6 +127,8 @@ struct RecordBatchSortIndicesArgs { state_(state) {} ~RecordBatchSortIndicesArgs() { + state_.counters["columns"] = static_cast(num_columns); + state_.counters["null_percent"] = null_proportion * 100; state_.SetItemsProcessed(state_.iterations() * num_records); } @@ -149,6 +151,8 @@ struct TableSortIndicesArgs : public RecordBatchSortIndicesArgs { // Extract args explicit TableSortIndicesArgs(benchmark::State& state) : RecordBatchSortIndicesArgs(state), num_chunks(state.range(3)) {} + + ~TableSortIndicesArgs() { state_.counters["chunks"] = static_cast(num_chunks); } }; struct BatchOrTableBenchmarkData { @@ -266,7 +270,7 @@ BENCHMARK(ChunkedArraySortIndicesInt64Wide) BENCHMARK(RecordBatchSortIndicesInt64Narrow) ->ArgsProduct({ {1 << 20}, // the number of records - {100, 0}, // inverse null proportion + {100, 4, 0}, // inverse null proportion {16, 8, 2, 1}, // the number of columns }) ->Unit(benchmark::TimeUnit::kNanosecond); @@ -274,7 +278,7 @@ BENCHMARK(RecordBatchSortIndicesInt64Narrow) BENCHMARK(RecordBatchSortIndicesInt64Wide) ->ArgsProduct({ {1 << 20}, // the number of records - {100, 0}, // inverse null proportion + {100, 4, 0}, // inverse null proportion {16, 8, 2, 1}, // the number of columns }) ->Unit(benchmark::TimeUnit::kNanosecond); @@ -282,7 +286,7 @@ BENCHMARK(RecordBatchSortIndicesInt64Wide) BENCHMARK(TableSortIndicesInt64Narrow) ->ArgsProduct({ {1 << 20}, // the number of records - {100, 0}, // inverse null proportion + {100, 4, 0}, // inverse null proportion {16, 8, 2, 1}, // the number of columns {32, 4, 1}, // the number of chunks }) @@ -291,7 +295,7 @@ 
BENCHMARK(TableSortIndicesInt64Narrow) BENCHMARK(TableSortIndicesInt64Wide) ->ArgsProduct({ {1 << 20}, // the number of records - {100, 0}, // inverse null proportion + {100, 4, 0}, // inverse null proportion {16, 8, 2, 1}, // the number of columns {32, 4, 1}, // the number of chunks }) diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 811060b660f..9e41e966eb3 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -24,6 +24,8 @@ #include #include +#include + #include "arrow/array/array_decimal.h" #include "arrow/array/concatenate.h" #include "arrow/compute/api_vector.h" @@ -1261,6 +1263,45 @@ TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) { // Test basic cases for table. class TestTableSortIndices : public ::testing::Test {}; +TEST_F(TestTableSortIndices, EmptyTable) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + + auto table = TableFromJSON(schema, {"[]"}); + auto chunked_table = TableFromJSON(schema, {"[]", "[]"}); + + SortOptions options(sort_keys, NullPlacement::AtEnd); + AssertSortIndices(table, options, "[]"); + AssertSortIndices(chunked_table, options, "[]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[]"); + AssertSortIndices(chunked_table, options, "[]"); +} + +TEST_F(TestTableSortIndices, EmptySortKeys) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + const std::vector sort_keys{}; + const SortOptions options(sort_keys, NullPlacement::AtEnd); + + auto table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])"}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Must specify one or more sort keys"), + CallFunction("sort_indices", {table}, &options)); + + // 
Several chunks + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])", R"([{"a": 0, "b": 6}])"}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("Must specify one or more sort keys"), + CallFunction("sort_indices", {table}, &options)); +} + TEST_F(TestTableSortIndices, Null) { auto schema = ::arrow::schema({ {field("a", uint8())}, @@ -1516,6 +1557,32 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } +TEST_F(TestTableSortIndices, HeterogenousChunking) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + + // Same logical data as in "NaNAndNull" test above + auto col_a = + ChunkedArrayFromJSON(float32(), {"[null, 1]", "[]", "[3, null, NaN, NaN, NaN, 1]"}); + auto col_b = ChunkedArrayFromJSON(float64(), + {"[5]", "[3, null, null]", "[null, NaN, 5]", "[5]"}); + auto table = Table::Make(schema, {col_a, col_b}); + + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); + + options = SortOptions( + {SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}); + AssertSortIndices(table, options, "[1, 7, 6, 0, 5, 2, 4, 3]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[3, 4, 2, 5, 1, 0, 6, 7]"); +} + // Tests for temporal types template class TestTableSortIndicesForTemporal : public TestTableSortIndices { @@ -1712,22 +1779,48 @@ TEST_P(TestTableSortIndicesRandom, Sort) { {field("decimal128", decimal128(25, 3))}, {field("decimal256", decimal256(42, 6))}, }; - const auto length = 200; - ArrayVector columns = { - rng.UInt8(length, 0, 10, null_probability), - rng.Int16(length, -1000, 12000, /*null_probability=*/0.0), - rng.Int32(length, -123456789, 987654321, 
null_probability), - rng.UInt64(length, 1, 1234567890123456789ULL, /*null_probability=*/0.0), - rng.Float32(length, -1.0f, 1.0f, null_probability, nan_probability), - rng.Boolean(length, /*true_probability=*/0.3, null_probability), - rng.StringWithRepeats(length, /*unique=*/length / 10, /*min_length=*/5, - /*max_length=*/15, null_probability), - rng.LargeString(length, /*min_length=*/5, /*max_length=*/15, - /*null_probability=*/0.0), - rng.Decimal128(fields[8]->type(), length, null_probability), - rng.Decimal256(fields[9]->type(), length, /*null_probability=*/0.0), + const auto schema = ::arrow::schema(fields); + const int64_t length = 80; + + using ArrayFactory = std::function(int64_t length)>; + + std::vector column_factories{ + [&](int64_t length) { return rng.UInt8(length, 0, 10, null_probability); }, + [&](int64_t length) { + return rng.Int16(length, -1000, 12000, /*null_probability=*/0.0); + }, + [&](int64_t length) { + return rng.Int32(length, -123456789, 987654321, null_probability); + }, + [&](int64_t length) { + return rng.UInt64(length, 1, 1234567890123456789ULL, /*null_probability=*/0.0); + }, + [&](int64_t length) { + return rng.Float32(length, -1.0f, 1.0f, null_probability, nan_probability); + }, + [&](int64_t length) { + return rng.Boolean(length, /*true_probability=*/0.3, null_probability); + }, + [&](int64_t length) { + if (length > 0) { + return rng.StringWithRepeats(length, /*unique=*/1 + length / 10, + /*min_length=*/5, + /*max_length=*/15, null_probability); + } else { + return *MakeArrayOfNull(utf8(), 0); + } + }, + [&](int64_t length) { + return rng.LargeString(length, /*min_length=*/5, /*max_length=*/15, + /*null_probability=*/0.0); + }, + [&](int64_t length) { + return rng.Decimal128(fields[8]->type(), length, null_probability); + }, + [&](int64_t length) { + return rng.Decimal256(fields[9]->type(), length, /*null_probability=*/0.0); + }, }; - const auto table = Table::Make(schema(fields), columns, length); // Generate random sort keys, 
making sure no column is included twice std::default_random_engine engine(seed); @@ -1758,30 +1851,53 @@ TEST_P(TestTableSortIndicesRandom, Sort) { SortOptions options(sort_keys); - // Test with different table chunkings - for (const int64_t num_chunks : {1, 2, 20}) { - ARROW_SCOPED_TRACE("Table sorting: num_chunks = ", num_chunks); - TableBatchReader reader(*table); - reader.set_chunksize((length + num_chunks - 1) / num_chunks); - ASSERT_OK_AND_ASSIGN(auto chunked_table, Table::FromRecordBatchReader(&reader)); + // Test with different, heterogenous table chunkings + for (const int64_t max_num_chunks : {1, 3, 15}) { + ARROW_SCOPED_TRACE("Table sorting: max chunks per column = ", max_num_chunks); + std::uniform_int_distribution num_chunk_dist(1 + max_num_chunks / 2, + max_num_chunks); + ChunkedArrayVector columns; + columns.reserve(fields.size()); + + // Chunk each column independently, and make sure they consist of + // physically non-contiguous chunks. + for (const auto& factory : column_factories) { + const int64_t num_chunks = num_chunk_dist(engine); + ArrayVector chunks(num_chunks); + const auto offsets = + checked_pointer_cast(rng.Offsets(num_chunks + 1, 0, length)); + for (int64_t i = 0; i < num_chunks; ++i) { + const auto chunk_len = offsets->Value(i + 1) - offsets->Value(i); + chunks[i] = factory(chunk_len); + } + columns.push_back(std::make_shared(std::move(chunks))); + ASSERT_EQ(columns.back()->length(), length); + } + + auto table = Table::Make(schema, std::move(columns)); for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); options.null_placement = null_placement; - ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options)); + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } } // Also validate RecordBatch sorting - TableBatchReader reader(*table); - RecordBatchVector batches; - 
ASSERT_OK(reader.ReadAll(&batches)); - ASSERT_EQ(batches.size(), 1); ARROW_SCOPED_TRACE("Record batch sorting"); + ArrayVector columns; + columns.reserve(fields.size()); + for (const auto& factory : column_factories) { + columns.push_back(factory(length)); + } + auto batch = RecordBatch::Make(schema, length, std::move(columns)); + ASSERT_OK(batch->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatches(schema, {batch})); + for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); options.null_placement = null_placement; - ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*batches[0]), options)); + ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(batch), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } }