From 5a1f7051197db4fae4f6b774b00a3ddf9ac503c6 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Sep 2021 11:45:26 -0500 Subject: [PATCH 1/2] addressing edponce feedback --- cpp/src/arrow/compute/api_vector.cc | 20 +--- cpp/src/arrow/compute/api_vector.h | 14 ++- .../arrow/compute/kernels/select_k_test.cc | 49 ++++------ cpp/src/arrow/compute/kernels/vector_sort.cc | 93 ++++++++----------- 4 files changed, 64 insertions(+), 112 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 3b5561e4423..34ee0599c3d 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -147,23 +147,6 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) : FunctionOptions(internal::kSelectKOptionsType), k(k), sort_keys(std::move(sort_keys)) {} - -bool SelectKOptions::is_top_k() const { - for (const auto& k : sort_keys) { - if (k.order != SortOrder::Descending) { - return false; - } - } - return true; -} -bool SelectKOptions::is_bottom_k() const { - for (const auto& k : sort_keys) { - if (k.order != SortOrder::Ascending) { - return false; - } - } - return true; -} constexpr char SelectKOptions::kTypeName[]; namespace internal { @@ -189,7 +172,8 @@ Result> NthToIndices(const Array& values, int64_t n, return result.make_array(); } -Result> SelectKUnstable(const Datum& datum, SelectKOptions options, +Result> SelectKUnstable(const Datum& datum, + const SelectKOptions& options, ExecContext* ctx) { ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("select_k_unstable", {datum}, &options, ctx)); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 95796b8026d..e64934e4a7e 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -150,9 +150,6 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { } return SelectKOptions{k, keys}; } - bool is_top_k() const; - - bool is_bottom_k() const; /// The number of `k` elements to keep. int64_t k; @@ -292,19 +289,20 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); -/// \brief Returns the first k elements ordered by `options.keys`. +/// \brief Returns the indices of the first k elements ordered by `options.sort_keys`. /// -/// Return a sorted array with its elements rearranged in such +/// Select an array of indices of the sorted array with its elements rearranged in such /// a way that the value of the element in k-th position (options.k) is in the position it -/// would be in a sorted datum ordered by `options.keys`. Null like values will be not -/// part of the output. Output is not guaranteed to be stable. +/// would be in a sorted datum ordered by `options.sort_keys`. Null like values will be +/// not part of the output. Output is not guaranteed to be stable. /// /// \param[in] datum datum to be partitioned /// \param[in] options options /// \param[in] ctx the function execution context, optional /// \return a datum with the same schema as the input ARROW_EXPORT -Result> SelectKUnstable(const Datum& datum, SelectKOptions options, +Result> SelectKUnstable(const Datum& datum, + const SelectKOptions& options, ExecContext* ctx = NULLPTR); /// \brief Returns the indices that would sort an array in the diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 779a4f1fa3d..51df3184584 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -42,7 +42,7 @@ using internal::checked_pointer_cast; namespace compute { template -class SelectKComparator { +class SelectKCompareForResult { public: template bool operator()(const Type& lval, const Type& rval) { @@ -63,14 +63,9 @@ Result> SelectK(const Datum& values, int64_t k) { } } -template Result> SelectK(const Datum& values, const SelectKOptions& options) { - if (order == SortOrder::Descending) { - return SelectKUnstable(Datum(values), options); - } else { - return SelectKUnstable(Datum(values), options); - } + return SelectKUnstable(Datum(values), options); } void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order, @@ -298,7 +293,7 @@ template void ValidateSelectKIndices(const ArrayType& array) { ValidateOutput(array); - SelectKComparator compare; + SelectKCompareForResult compare; for (uint64_t i = 1; i < static_cast(array.length()); i++) { using ArrowType = typename ArrayType::TypeClass; using GetView = internal::GetViewType; @@ -365,7 +360,6 @@ TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SelectKableTypes); TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { this->TestSelectK(1000); } // // Test basic cases for record batch. -template class TestSelectKWithRecordBatch : public ::testing::Test { public: void Check(const std::shared_ptr& schm, const std::string& batch_json, @@ -378,7 +372,7 @@ class TestSelectKWithRecordBatch : public ::testing::Test { Status DoSelectK(const std::shared_ptr& schm, const std::string& batch_json, const SelectKOptions& options, std::shared_ptr* out) { auto batch = RecordBatchFromJSON(schm, batch_json); - ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*batch), options)); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*batch), options)); ValidateOutput(*indices); ARROW_ASSIGN_OR_RAISE( @@ -388,9 +382,7 @@ class TestSelectKWithRecordBatch : public ::testing::Test { } }; -struct TestTopKWithRecordBatch : TestSelectKWithRecordBatch {}; - -TEST_F(TestTopKWithRecordBatch, NoNull) { +TEST_F(TestSelectKWithRecordBatch, TopKNoNull) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -417,7 +409,7 @@ TEST_F(TestTopKWithRecordBatch, NoNull) { Check(schema, batch_input, options, expected_batch); } -TEST_F(TestTopKWithRecordBatch, Null) { +TEST_F(TestSelectKWithRecordBatch, TopKNull) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -444,7 +436,7 @@ TEST_F(TestTopKWithRecordBatch, Null) { Check(schema, batch_input, options, expected_batch); } -TEST_F(TestTopKWithRecordBatch, OneColumnKey) { +TEST_F(TestSelectKWithRecordBatch, TopKOneColumnKey) { auto schema = ::arrow::schema({ {field("country", utf8())}, {field("population", uint64())}, @@ -473,7 +465,7 @@ TEST_F(TestTopKWithRecordBatch, OneColumnKey) { this->Check(schema, batch_input, options, expected_batch); } -TEST_F(TestTopKWithRecordBatch, MultipleColumnKeys) { +TEST_F(TestSelectKWithRecordBatch, TopKMultipleColumnKeys) { auto schema = ::arrow::schema({{field("country", utf8())}, {field("population", uint64())}, {field("GDP", uint64())}}); @@ -499,9 +491,7 @@ TEST_F(TestTopKWithRecordBatch, MultipleColumnKeys) { this->Check(schema, batch_input, options, expected_batch); } -struct TestBottomKWithRecordBatch : TestSelectKWithRecordBatch {}; - -TEST_F(TestBottomKWithRecordBatch, NoNull) { +TEST_F(TestSelectKWithRecordBatch, BottomKNoNull) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -528,7 +518,7 @@ TEST_F(TestBottomKWithRecordBatch, NoNull) { Check(schema, batch_input, options, expected_batch); } -TEST_F(TestBottomKWithRecordBatch, Null) { +TEST_F(TestSelectKWithRecordBatch, BottomKNull) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -555,7 +545,7 @@ TEST_F(TestBottomKWithRecordBatch, Null) { Check(schema, batch_input, options, expected_batch); } -TEST_F(TestBottomKWithRecordBatch, OneColumnKey) { +TEST_F(TestSelectKWithRecordBatch, BottomKOneColumnKey) { auto schema = ::arrow::schema({ {field("country", utf8())}, {field("population", uint64())}, @@ -584,7 +574,7 @@ TEST_F(TestBottomKWithRecordBatch, OneColumnKey) { this->Check(schema, batch_input, options, expected_batch); } -TEST_F(TestBottomKWithRecordBatch, MultipleColumnKeys) { +TEST_F(TestSelectKWithRecordBatch, BottomKMultipleColumnKeys) { auto schema = ::arrow::schema({{field("country", utf8())}, {field("population", uint64())}, {field("GDP", uint64())}}); @@ -612,7 +602,6 @@ TEST_F(TestBottomKWithRecordBatch, MultipleColumnKeys) { } // Test basic cases for table. -template struct TestSelectKWithTable : public ::testing::Test { void Check(const std::shared_ptr& schm, const std::vector& input_json, const SelectKOptions& options, @@ -626,7 +615,7 @@ struct TestSelectKWithTable : public ::testing::Test { const std::vector& input_json, const SelectKOptions& options, std::shared_ptr* out) { auto table = TableFromJSON(schm, input_json); - ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*table), options)); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*table), options)); ValidateOutput(*indices); ARROW_ASSIGN_OR_RAISE( @@ -636,9 +625,7 @@ struct TestSelectKWithTable : public ::testing::Test { } }; -struct TestTopKWithTable : TestSelectKWithTable {}; - -TEST_F(TestTopKWithTable, OneColumnKey) { +TEST_F(TestSelectKWithTable, TopKOneColumnKey) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -661,7 +648,7 @@ TEST_F(TestTopKWithTable, OneColumnKey) { Check(schema, input, options, expected); } -TEST_F(TestTopKWithTable, MultipleColumnKeys) { +TEST_F(TestSelectKWithTable, TopKMultipleColumnKeys) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -684,9 +671,7 @@ TEST_F(TestTopKWithTable, MultipleColumnKeys) { Check(schema, input, options, expected); } -struct TestBottomKWithTable : TestSelectKWithTable {}; - -TEST_F(TestBottomKWithTable, OneColumnKey) { +TEST_F(TestSelectKWithTable, BottomKOneColumnKey) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, @@ -709,7 +694,7 @@ TEST_F(TestBottomKWithTable, OneColumnKey) { Check(schema, input, options, expected); } -TEST_F(TestBottomKWithTable, MultipleColumnKeys) { +TEST_F(TestSelectKWithTable, BottomKMultipleColumnKeys) { auto schema = ::arrow::schema({ {field("a", uint8())}, {field("b", uint32())}, diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 58a48aa9056..294ca4a9c2f 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -1155,7 +1155,7 @@ class MultipleKeyComparator { // If the left value equals to the right value, we need to // continue to sort. if (current_compared_ != 0) { - break; + return false; } } return current_compared_ == 0; @@ -1803,15 +1803,15 @@ class SortIndicesMetaFunction : public MetaFunction { const auto kDefaultSelectKOptions = SelectKOptions::Defaults(); -const FunctionDoc select_k_doc( - "Returns the first k elements ordered by `options.keys`", - ("This function computes the k elements of the input\n" - "array, record batch or table specified in the column names (`options.sort_keys`).\n" +const FunctionDoc select_k_unstable_doc( + "Selects the indices of the first `k` ordered elements from the input", + ("This function selects an array of indices of the first `k` ordered elements from\n" + "the input array, record batch or table specified in the column keys\n" + "(`options.sort_keys`). Output is not guaranteed to be stable.\n" "The columns that are not specified are returned as well, but not used for\n" "ordering. Null values are considered greater than any other value and are\n" - "therefore sorted at the end of the array.\n" - "For floating-point types, NaNs are considered greater than any\n" - "other non-null value, but smaller than null values."), + "therefore sorted at the end of the array. For floating-point types, ordering of\n" + "values is such that: Null > NaN > Inf > number."), {"input"}, "SelectKOptions"); Result> MakeMutableUInt64Array( @@ -1846,7 +1846,6 @@ class SelectKComparator { } }; -template class ArraySelecter : public TypeVisitor { public: ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -1855,19 +1854,25 @@ class ArraySelecter : public TypeVisitor { ctx_(ctx), array_(array), k_(options.k), + order_(options.sort_keys[0].order), physical_type_(GetPhysicalType(array.type())), output_(output) {} Status Run() { return physical_type_->Accept(this); } -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectKthInternal(); } +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + if (order_ == SortOrder::Ascending) { \ + return SelectKthInternal(); \ + } \ + return SelectKthInternal(); \ + } VISIT_PHYSICAL_TYPES(VISIT) #undef VISIT - template + template Status SelectKthInternal() { using GetView = GetViewType; using ArrayType = typename TypeTraits::ArrayType; @@ -1883,10 +1888,8 @@ class ArraySelecter : public TypeVisitor { } auto end_iter = PartitionNulls(indices_begin, indices_end, arr, 0); - auto kth_begin = indices_begin + k_; - if (kth_begin > end_iter) { - kth_begin = end_iter; - } + auto kth_begin = std::min(indices_begin + k_, end_iter); + SelectKComparator comparator; auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { const auto lval = GetView::LogicalValue(arr.GetView(left)); @@ -1920,6 +1923,7 @@ class ArraySelecter : public TypeVisitor { ExecContext* ctx_; const Array& array_; int64_t k_; + SortOrder order_; const std::shared_ptr physical_type_; Datum* output_; }; @@ -1931,7 +1935,6 @@ struct TypedHeapItem { ArrayType* array; }; -template class ChunkedArraySelecter : public TypeVisitor { public: ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, @@ -1941,19 +1944,24 @@ class ChunkedArraySelecter : public TypeVisitor { physical_type_(GetPhysicalType(chunked_array.type())), physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), k_(options.k), + order_(options.sort_keys[0].order), ctx_(ctx), output_(output) {} Status Run() { return physical_type_->Accept(this); } -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectKthInternal(); } +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + if (order_ == SortOrder::Ascending) { \ + return SelectKthInternal(); \ + } \ + return SelectKthInternal(); \ + } VISIT_PHYSICAL_TYPES(VISIT) - #undef VISIT - template + template Status SelectKthInternal() { using GetView = GetViewType; using ArrayType = typename TypeTraits::ArrayType; @@ -1992,11 +2000,7 @@ class ChunkedArraySelecter : public TypeVisitor { auto end_iter = PartitionNulls( indices_begin, indices_end, arr, 0); - auto kth_begin = indices_begin + k_; - - if (kth_begin > end_iter) { - kth_begin = end_iter; - } + auto kth_begin = std::min(indices_begin + k_, end_iter); uint64_t* iter = indices_begin; for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { heap.push(HeapItem{*iter, offset, &arr}); @@ -2033,9 +2037,10 @@ class ChunkedArraySelecter : public TypeVisitor { const std::shared_ptr physical_type_; const ArrayVector physical_chunks_; int64_t k_; + SortOrder order_; ExecContext* ctx_; Datum* output_; -}; +}; // namespace class RecordBatchSelecter : public TypeVisitor { private: @@ -2113,11 +2118,8 @@ class RecordBatchSelecter : public TypeVisitor { auto end_iter = PartitionNulls(indices_begin, indices_end, arr, 0); - auto kth_begin = indices_begin + k_; + auto kth_begin = std::min(indices_begin + k_, end_iter); - if (kth_begin > end_iter) { - kth_begin = end_iter; - } HeapContainer heap(indices_begin, kth_begin, cmp); for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { uint64_t x_index = *iter; @@ -2283,11 +2285,8 @@ class TableSelecter : public TypeVisitor { auto end_iter = this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - auto kth_begin = indices_begin + k_; + auto kth_begin = std::min(indices_begin + k_, end_iter); - if (kth_begin > end_iter) { - kth_begin = end_iter; - } HeapContainer heap(indices_begin, kth_begin, cmp); for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { uint64_t x_index = *iter; @@ -2332,7 +2331,7 @@ static Status CheckConsistency(const Schema& schema, class SelectKUnstableMetaFunction : public MetaFunction { public: SelectKUnstableMetaFunction() - : MetaFunction("select_k_unstable", Arity::Unary(), &select_k_doc, + : MetaFunction("select_k_unstable", Arity::Unary(), &select_k_unstable_doc, &kDefaultSelectKOptions) {} Result ExecuteImpl(const std::vector& args, @@ -2344,22 +2343,10 @@ class SelectKUnstableMetaFunction : public MetaFunction { } switch (args[0].kind()) { case Datum::ARRAY: { - if (select_k_options.is_top_k()) { - return SelectKth(*args[0].make_array(), select_k_options, - ctx); - } else { - return SelectKth(*args[0].make_array(), select_k_options, - ctx); - } + return SelectKth(*args[0].make_array(), select_k_options, ctx); } break; case Datum::CHUNKED_ARRAY: { - if (select_k_options.is_top_k()) { - return SelectKth(*args[0].chunked_array(), - select_k_options, ctx); - } else { - return SelectKth(*args[0].chunked_array(), - select_k_options, ctx); - } + return SelectKth(*args[0].chunked_array(), select_k_options, ctx); } break; case Datum::RECORD_BATCH: return SelectKth(*args[0].record_batch(), select_k_options, ctx); @@ -2377,20 +2364,18 @@ class SelectKUnstableMetaFunction : public MetaFunction { } private: - template Result SelectKth(const Array& array, const SelectKOptions& options, ExecContext* ctx) const { Datum output; - ArraySelecter selecter(ctx, array, options, &output); + ArraySelecter selecter(ctx, array, options, &output); ARROW_RETURN_NOT_OK(selecter.Run()); return output; } - template Result SelectKth(const ChunkedArray& chunked_array, const SelectKOptions& options, ExecContext* ctx) const { Datum output; - ChunkedArraySelecter selecter(ctx, chunked_array, options, &output); + ChunkedArraySelecter selecter(ctx, chunked_array, options, &output); ARROW_RETURN_NOT_OK(selecter.Run()); return output; } From a380d4a7a352178ad01e54c01089f649c60c8482 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Sep 2021 15:27:20 -0500 Subject: [PATCH 2/2] minor fixes --- cpp/src/arrow/compute/api_vector.h | 14 ++++++++------ cpp/src/arrow/compute/kernels/select_k_test.cc | 9 ++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index e64934e4a7e..a1c6f7959e1 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -289,12 +289,14 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); -/// \brief Returns the indices of the first k elements ordered by `options.sort_keys`. -/// -/// Select an array of indices of the sorted array with its elements rearranged in such -/// a way that the value of the element in k-th position (options.k) is in the position it -/// would be in a sorted datum ordered by `options.sort_keys`. Null like values will be -/// not part of the output. Output is not guaranteed to be stable. +/// \brief Returns the indices that would select the first `k` elements of the array in +/// the specified order. +/// +// Perform an indirect sort of the datum, keeping only the first `k` elements. The output +// array will contain indices such that the item indicated by the k-th index will be in +// the position it would be if the datum were sorted by `options.sort_keys`. However, +// indices of null values will not be part of the output. The sort is not guaranteed to be +// stable. /// /// \param[in] datum datum to be partitioned /// \param[in] options options diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 51df3184584..2d1d5cffe3d 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -63,11 +63,6 @@ Result> SelectK(const Datum& values, int64_t k) { } } -Result> SelectK(const Datum& values, - const SelectKOptions& options) { - return SelectKUnstable(Datum(values), options); -} - void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order, bool stable_sort = false) { ASSERT_TRUE(datum.is_arraylike()); @@ -372,7 +367,7 @@ class TestSelectKWithRecordBatch : public ::testing::Test { Status DoSelectK(const std::shared_ptr& schm, const std::string& batch_json, const SelectKOptions& options, std::shared_ptr* out) { auto batch = RecordBatchFromJSON(schm, batch_json); - ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*batch), options)); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*batch), options)); ValidateOutput(*indices); ARROW_ASSIGN_OR_RAISE( @@ -615,7 +610,7 @@ struct TestSelectKWithTable : public ::testing::Test { const std::vector& input_json, const SelectKOptions& options, std::shared_ptr
* out) { auto table = TableFromJSON(schm, input_json); - ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*table), options)); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*table), options)); ValidateOutput(*indices); ARROW_ASSIGN_OR_RAISE(