diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d4c4a915999..3b5561e4423 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -111,6 +111,9 @@ static auto kSortOptionsType = GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot)); +static auto kSelectKOptionsType = GetFunctionOptionsType( + DataMember("k", &SelectKOptions::k), + DataMember("sort_keys", &SelectKOptions::sort_keys)); } // namespace } // namespace internal @@ -140,6 +143,29 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot) : FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {} constexpr char PartitionNthOptions::kTypeName[]; +SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) + : FunctionOptions(internal::kSelectKOptionsType), + k(k), + sort_keys(std::move(sort_keys)) {} + +bool SelectKOptions::is_top_k() const { + for (const auto& k : sort_keys) { + if (k.order != SortOrder::Descending) { + return false; + } + } + return true; +} +bool SelectKOptions::is_bottom_k() const { + for (const auto& k : sort_keys) { + if (k.order != SortOrder::Ascending) { + return false; + } + } + return true; +} +constexpr char SelectKOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -148,6 +174,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); } } // namespace internal @@ -162,6 +189,13 @@ Result> NthToIndices(const Array& values, int64_t n, return result.make_array(); } 
+Result> SelectKUnstable(const Datum& datum, SelectKOptions options, + ExecContext* ctx) { + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("select_k_unstable", {datum}, &options, ctx)); + return result.make_array(); +} + Result ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 5dc68fc5c83..95796b8026d 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -120,6 +120,46 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { std::vector sort_keys; }; +/// \brief SelectK options +class ARROW_EXPORT SelectKOptions : public FunctionOptions { + public: + explicit SelectKOptions(int64_t k = -1, std::vector sort_keys = {}); + + constexpr static char const kTypeName[] = "SelectKOptions"; + + static SelectKOptions Defaults() { return SelectKOptions{-1, {}}; } + + static SelectKOptions TopKDefault(int64_t k, std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Descending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Descending)); + } + return SelectKOptions{k, keys}; + } + static SelectKOptions BottomKDefault(int64_t k, + std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Ascending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Ascending)); + } + return SelectKOptions{k, keys}; + } + bool is_top_k() const; + + bool is_bottom_k() const; + + /// The number of `k` elements to keep. + int64_t k; + /// Column key(s) to order by and how to order by these sort keys. 
+ std::vector sort_keys; +}; + /// \brief Partitioning options for NthToIndices class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: @@ -252,6 +292,21 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); +/// \brief Returns the indices of the first k elements ordered by `options.sort_keys`. +/// +/// Return a sorted array with its elements rearranged in such +/// a way that the value of the element in k-th position (options.k) is in the position it +/// would be in a sorted datum ordered by `options.sort_keys`. Null-like values will not +/// be part of the output. Output is not guaranteed to be stable. +/// +/// \param[in] datum datum to be partitioned +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return indices of the selected elements, suitable for passing to Take() +ARROW_EXPORT +Result> SelectKUnstable(const Datum& datum, SelectKOptions options, + ExecContext* ctx = NULLPTR); + /// \brief Returns the indices that would sort an array in the /// specified order. /// diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index ab41887ca35..c08fdaca627 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -116,6 +116,8 @@ TEST(FunctionOptions, Equality) { {SortKey("key", SortOrder::Descending), SortKey("value", SortOrder::Descending)})); options.emplace_back(new PartitionNthOptions(/*pivot=*/0)); options.emplace_back(new PartitionNthOptions(/*pivot=*/42)); + options.emplace_back(new SelectKOptions(0, {})); + options.emplace_back(new SelectKOptions(5, {{SortKey("key", SortOrder::Ascending)}})); for (size_t i = 0; i < options.size(); i++) { const size_t prev_i = i == 0 ? 
options.size() - 1 : i - 1; diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 4096e497c0a..ce7a85f1557 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -50,11 +50,13 @@ add_arrow_compute_test(vector_test vector_replace_test.cc vector_selection_test.cc vector_sort_test.cc + select_k_test.cc test_util.cc) add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(vector_topk_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc new file mode 100644 index 00000000000..779a4f1fa3d --- /dev/null +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_decimal.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type_traits.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace compute { + +template +class SelectKComparator { + public: + template + bool operator()(const Type& lval, const Type& rval) { + if (order == SortOrder::Ascending) { + return lval <= rval; + } else { + return rval <= lval; + } + } +}; + +template +Result> SelectK(const Datum& values, int64_t k) { + if (order == SortOrder::Descending) { + return SelectKUnstable(values, SelectKOptions::TopKDefault(k)); + } else { + return SelectKUnstable(values, SelectKOptions::BottomKDefault(k)); + } +} + +template +Result> SelectK(const Datum& values, + const SelectKOptions& options) { + if (order == SortOrder::Descending) { + return SelectKUnstable(Datum(values), options); + } else { + return SelectKUnstable(Datum(values), options); + } +} + +void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order, + bool stable_sort = false) { + ASSERT_TRUE(datum.is_arraylike()); + ASSERT_OK_AND_ASSIGN(auto sorted_indices, + SortIndices(datum, SortOptions({SortKey("unused", order)}))); + + int64_t k = select_k_indices.length(); + // head(k) + auto head_k_indices = sorted_indices->Slice(0, k); + if (stable_sort) { + AssertDatumsEqual(*head_k_indices, select_k_indices); + } else { + ASSERT_OK_AND_ASSIGN(auto expected, + Take(datum, *head_k_indices, TakeOptions::NoBoundsCheck())); + ASSERT_OK_AND_ASSIGN(auto actual, + Take(datum, select_k_indices, TakeOptions::NoBoundsCheck())); + 
AssertDatumsEqual(Datum(expected), Datum(actual)); + } +} + +template +class TestSelectKBase : public TestBase { + using ArrayType = typename TypeTraits::ArrayType; + + protected: + template + void AssertSelectKArray(const std::shared_ptr values, int k) { + std::shared_ptr select_k; + ASSERT_OK_AND_ASSIGN(select_k, SelectK(Datum(*values), k)); + ASSERT_EQ(select_k->data()->null_count, 0); + ValidateOutput(*select_k); + ValidateSelectK(Datum(*values), *select_k, order); + } + + void AssertTopKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); + } + void AssertBottomKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); + } + + void AssertSelectKJson(const std::string& values, int n) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); + } + + virtual std::shared_ptr type_singleton() = 0; +}; + +template +class TestSelectK : public TestSelectKBase { + protected: + std::shared_ptr type_singleton() override { + return default_type_instance(); + } +}; + +template +class TestSelectKForReal : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes); + +template +class TestSelectKForIntegral : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes); + +template +class TestSelectKForBool : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types); + +template +class TestSelectKForTemporal : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes); + +template +class TestSelectKForDecimal : public TestSelectKBase { + std::shared_ptr type_singleton() override { + return std::make_shared(5, 2); + } +}; +TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes); + +template +class TestSelectKForStrings : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types); + +TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) 
{ + auto input = ArrayFromJSON(this->type_singleton(), "[null, 1, 3.3, null, 2, 5.3]"); + ASSERT_RAISES(Invalid, CallFunction("select_k_unstable", {input})); +} + +TYPED_TEST(TestSelectKForReal, Real) { + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6); + + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); + this->AssertSelectKJson("[100, 4, 2, 7, 8, 3, NaN, 3, 1]", 4); +} + +TYPED_TEST(TestSelectKForIntegral, Integral) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); + + this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5); +} + +TYPED_TEST(TestSelectKForBool, Bool) { + this->AssertSelectKJson("[null, false, true, null, false, true]", 0); + this->AssertSelectKJson("[null, false, true, null, false, true]", 2); + this->AssertSelectKJson("[null, false, true, null, false, true]", 5); + this->AssertSelectKJson("[null, false, true, null, false, true]", 6); +} + +TYPED_TEST(TestSelectKForTemporal, Temporal) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); +} + +TYPED_TEST(TestSelectKForDecimal, Decimal) { + const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])"; + 
this->AssertSelectKJson(values, 0); + this->AssertSelectKJson(values, 2); + this->AssertSelectKJson(values, 4); + this->AssertSelectKJson(values, 5); +} + +TYPED_TEST(TestSelectKForStrings, Strings) { + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 0); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 2); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 5); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 6); +} + +template +class TestSelectKRandom : public TestSelectKBase { + public: + std::shared_ptr type_singleton() override { + EXPECT_TRUE(0) << "shouldn't be used"; + return nullptr; + } +}; + +using SelectKableTypes = + ::testing::Types; + +TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes); + +TYPED_TEST(TestSelectKRandom, RandomValues) { + Random rand(0x61549225); + int length = 100; + for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { + auto array = rand.Generate(length, null_probability); + // Try n from 0 to out of bound + for (int n = 0; n <= length; ++n) { + this->AssertTopKArray(array, n); + this->AssertBottomKArray(array, n); + } + } +} + +// Test basic cases for chunked array + +template +struct TestSelectKWithChunkedArray : public ::testing::Test { + TestSelectKWithChunkedArray() {} + + // Slice `array` into multiple chunks along `offsets` + ArrayVector Slices(const std::shared_ptr& array, + const std::shared_ptr& offsets) { + ArrayVector slices(offsets->length() - 1); + for (int64_t i = 0; i != static_cast(slices.size()); ++i) { + slices[i] = + array->Slice(offsets->Value(i), offsets->Value(i + 1) - offsets->Value(i)); + } + return slices; + } + + template + void AssertSelectK(const std::shared_ptr& chunked_array, int64_t k) { + ASSERT_OK_AND_ASSIGN(auto select_k_array, SelectK(Datum(*chunked_array), k)); + ValidateSelectK(Datum(*chunked_array), *select_k_array, order); + } + + void AssertTopK(const std::shared_ptr& 
chunked_array, int64_t k) { + AssertSelectK(chunked_array, k); + } + void AssertBottomK(const std::shared_ptr& chunked_array, int64_t k) { + AssertSelectK(chunked_array, k); + } +}; + +TYPED_TEST_SUITE(TestSelectKWithChunkedArray, SelectKableTypes); + +TYPED_TEST(TestSelectKWithChunkedArray, RandomValuesWithSlices) { + Random rand(0x61549225); + int length = 100; + for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { + // Try n from 0 to out of bound + auto array = rand.Generate(length, null_probability); + auto offsets = rand.Offsets(length, 3); + auto slices = this->Slices(array, offsets); + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(slices)); + for (int k = 0; k <= length; k += 10) { + this->AssertTopK(chunked_array, k); + this->AssertBottomK(chunked_array, k); + } + } +} + +template +void ValidateSelectKIndices(const ArrayType& array) { + ValidateOutput(array); + + SelectKComparator compare; + for (uint64_t i = 1; i < static_cast(array.length()); i++) { + using ArrowType = typename ArrayType::TypeClass; + using GetView = internal::GetViewType; + + const auto lval = GetView::LogicalValue(array.GetView(i - 1)); + const auto rval = GetView::LogicalValue(array.GetView(i)); + ASSERT_TRUE(compare(lval, rval)); + } +} +// Base class for testing against random chunked array. +template +struct TestSelectKWithChunkedArrayRandomBase : public ::testing::Test { + void TestSelectK(int length) { + using ArrayType = typename TypeTraits::ArrayType; + // We can use INSTANTIATE_TEST_SUITE_P() instead of using fors in a test. 
+ for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) { + for (auto num_chunks : {1, 2, 5, 10, 40}) { + std::vector> arrays; + for (int i = 0; i < num_chunks; ++i) { + auto array = this->GenerateArray(length / num_chunks, null_probability); + arrays.push_back(array); + } + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays)); + ASSERT_OK_AND_ASSIGN(auto indices, SelectK(Datum(*chunked_array), 5)); + ASSERT_OK_AND_ASSIGN(auto actual, Take(Datum(chunked_array), Datum(indices), + TakeOptions::NoBoundsCheck())); + ASSERT_OK_AND_ASSIGN(auto sorted_k, + Concatenate(actual.chunked_array()->chunks())); + + ValidateSelectKIndices( + *checked_pointer_cast(sorted_k)); + } + } + } + + void SetUp() override { rand_ = new Random(0x5487655); } + + void TearDown() override { delete rand_; } + + protected: + std::shared_ptr GenerateArray(int length, double null_probability) { + return rand_->Generate(length, null_probability); + } + + private: + Random* rand_; +}; + +// Long array with big value range +template +class TestTopKChunkedArrayRandom + : public TestSelectKWithChunkedArrayRandomBase {}; + +TYPED_TEST_SUITE(TestTopKChunkedArrayRandom, SelectKableTypes); + +TYPED_TEST(TestTopKChunkedArrayRandom, TopK) { this->TestSelectK(1000); } + +template +class TestBottomKChunkedArrayRandom + : public TestSelectKWithChunkedArrayRandomBase {}; + +TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SelectKableTypes); + +TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { this->TestSelectK(1000); } + +// Test basic cases for record batch. 
+template +class TestSelectKWithRecordBatch : public ::testing::Test { + public: + void Check(const std::shared_ptr& schm, const std::string& batch_json, + const SelectKOptions& options, const std::string& expected_batch) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(schm, batch_json, options, &actual)); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + + Status DoSelectK(const std::shared_ptr& schm, const std::string& batch_json, + const SelectKOptions& options, std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*batch), options)); + + ValidateOutput(*indices); + ARROW_ASSIGN_OR_RAISE( + auto select_k, Take(Datum(batch), Datum(indices), TakeOptions::NoBoundsCheck())); + *out = select_k.record_batch(); + return Status::OK(); + } +}; + +struct TestTopKWithRecordBatch : TestSelectKWithRecordBatch {}; + +TEST_F(TestTopKWithRecordBatch, NoNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + auto batch_input = R"([ + {"a": 3, "b": 5}, + {"a": 30, "b": 3}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": 20, "b": 5}, + {"a": 10, "b": 5}, + {"a": 10, "b": 3} + ])"; + + auto options = SelectKOptions::TopKDefault(3, {"a"}); + + auto expected_batch = R"([ + {"a": 30, "b": 3}, + {"a": 20, "b": 5}, + {"a": 10, "b": 5} + ])"; + + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestTopKWithRecordBatch, Null) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3} + ])"; + + auto options = SelectKOptions::TopKDefault(3, {"a"}); + + auto expected_batch = R"([ + {"a": 30, "b": 3}, + {"a": 20, "b": 5}, + {"a": 10, "b": 3} + ])"; + + Check(schema, batch_input, options, 
expected_batch); +} + +TEST_F(TestTopKWithRecordBatch, OneColumnKey) { + auto schema = ::arrow::schema({ + {field("country", utf8())}, + {field("population", uint64())}, + }); + + auto batch_input = + R"([{"country": "Italy", "population": 59000000}, + {"country": "France", "population": 65000000}, + {"country": "Malta", "population": 434000}, + {"country": "Maldives", "population": 434000}, + {"country": "Brunei", "population": 434000}, + {"country": "Iceland", "population": 337000}, + {"country": "Nauru", "population": 11300}, + {"country": "Tuvalu", "population": 11300}, + {"country": "Anguilla", "population": 11300}, + {"country": "Montserrat", "population": 5200} + ])"; + + auto options = SelectKOptions::TopKDefault(3, {"population"}); + + auto expected_batch = + R"([{"country": "France", "population": 65000000}, + {"country": "Italy", "population": 59000000}, + {"country": "Malta", "population": 434000} + ])"; + this->Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestTopKWithRecordBatch, MultipleColumnKeys) { + auto schema = ::arrow::schema({{field("country", utf8())}, + {field("population", uint64())}, + {field("GDP", uint64())}}); + + auto batch_input = + R"([{"country": "Italy", "population": 59000000, "GDP": 1937894}, + {"country": "France", "population": 65000000, "GDP": 2583560}, + {"country": "Malta", "population": 434000, "GDP": 12011}, + {"country": "Maldives", "population": 434000, "GDP": 4520}, + {"country": "Brunei", "population": 434000, "GDP": 12128}, + {"country": "Iceland", "population": 337000, "GDP": 17036}, + {"country": "Nauru", "population": 337000, "GDP": 182}, + {"country": "Tuvalu", "population": 11300, "GDP": 38}, + {"country": "Anguilla", "population": 11300, "GDP": 311} + ])"; + auto options = SelectKOptions::TopKDefault(3, {"population", "GDP"}); + + auto expected_batch = + R"([{"country": "France", "population": 65000000, "GDP": 2583560}, + {"country": "Italy", "population": 59000000, "GDP": 1937894}, + 
{"country": "Brunei", "population": 434000, "GDP": 12128} + ])"; + this->Check(schema, batch_input, options, expected_batch); +} + +struct TestBottomKWithRecordBatch : TestSelectKWithRecordBatch {}; + +TEST_F(TestBottomKWithRecordBatch, NoNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + auto batch_input = R"([ + {"a": 3, "b": 5}, + {"a": 30, "b": 3}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": 20, "b": 5}, + {"a": 10, "b": 5}, + {"a": 10, "b": 3} + ])"; + + auto options = SelectKOptions::BottomKDefault(3, {"a"}); + + auto expected_batch = R"([ + {"a": 0, "b": 6}, + {"a": 3, "b": 4}, + {"a": 3, "b": 5} + ])"; + + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestBottomKWithRecordBatch, Null) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3} + ])"; + + auto options = SelectKOptions::BottomKDefault(3, {"a"}); + + auto expected_batch = R"([ + {"a": 10, "b": 3}, + {"a": 20, "b": 5}, + {"a": 30, "b": 3} + ])"; + + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestBottomKWithRecordBatch, OneColumnKey) { + auto schema = ::arrow::schema({ + {field("country", utf8())}, + {field("population", uint64())}, + }); + + auto batch_input = + R"([{"country": "Italy", "population": 59000000}, + {"country": "France", "population": 65000000}, + {"country": "Malta", "population": 434000}, + {"country": "Maldives", "population": 434000}, + {"country": "Brunei", "population": 434000}, + {"country": "Iceland", "population": 337000}, + {"country": "Nauru", "population": 11300}, + {"country": "Tuvalu", "population": 11300}, + {"country": "Anguilla", "population": 11300}, + {"country": "Montserrat", "population": 5200} + ])"; + + auto options = 
SelectKOptions::BottomKDefault(3, {"population"}); + + auto expected_batch = + R"([{"country": "Montserrat", "population": 5200}, + {"country": "Anguilla", "population": 11300}, + {"country": "Tuvalu", "population": 11300} + ])"; + this->Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestBottomKWithRecordBatch, MultipleColumnKeys) { + auto schema = ::arrow::schema({{field("country", utf8())}, + {field("population", uint64())}, + {field("GDP", uint64())}}); + + auto batch_input = + R"([{"country": "Italy", "population": 59000000, "GDP": 1937894}, + {"country": "France", "population": 65000000, "GDP": 2583560}, + {"country": "Malta", "population": 434000, "GDP": 12011}, + {"country": "Maldives", "population": 434000, "GDP": 4520}, + {"country": "Brunei", "population": 434000, "GDP": 12128}, + {"country": "Iceland", "population": 337000, "GDP": 17036}, + {"country": "Nauru", "population": 337000, "GDP": 182}, + {"country": "Tuvalu", "population": 11300, "GDP": 38}, + {"country": "Anguilla", "population": 11300, "GDP": 311} + ])"; + + auto options = SelectKOptions::BottomKDefault(3, {"population", "GDP"}); + + auto expected_batch = + R"([{"country": "Tuvalu", "population": 11300, "GDP": 38}, + {"country": "Anguilla", "population": 11300, "GDP": 311}, + {"country": "Nauru", "population": 337000, "GDP": 182} + ])"; + this->Check(schema, batch_input, options, expected_batch); +} + +// Test basic cases for table. +template +struct TestSelectKWithTable : public ::testing::Test { + void Check(const std::shared_ptr& schm, + const std::vector& input_json, const SelectKOptions& options, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(schm, input_json, options, &actual)); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected), *actual); + } + + Status DoSelectK(const std::shared_ptr& schm, + const std::vector& input_json, + const SelectKOptions& options, std::shared_ptr
* out) { + auto table = TableFromJSON(schm, input_json); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectK(Datum(*table), options)); + ValidateOutput(*indices); + + ARROW_ASSIGN_OR_RAISE( + auto select_k, Take(Datum(table), Datum(indices), TakeOptions::NoBoundsCheck())); + *out = select_k.table(); + return Status::OK(); + } +}; + +struct TestTopKWithTable : TestSelectKWithTable {}; + +TEST_F(TestTopKWithTable, OneColumnKey) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + auto options = SelectKOptions::TopKDefault(3, {"a"}); + + std::vector expected = {R"([{"a": 3, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 3} + ])"}; + Check(schema, input, options, expected); +} + +TEST_F(TestTopKWithTable, MultipleColumnKeys) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + auto options = SelectKOptions::TopKDefault(3, {"a", "b"}); + + std::vector expected = {R"([{"a": 3, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +struct TestBottomKWithTable : TestSelectKWithTable {}; + +TEST_F(TestBottomKWithTable, OneColumnKey) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 0, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + auto options = SelectKOptions::BottomKDefault(3, {"a"}); + + std::vector expected = {R"([{"a": 0, "b": 3}, + {"a": 1, "b": 5}, + {"a": 2, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + 
+TEST_F(TestBottomKWithTable, MultipleColumnKeys) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + auto options = SelectKOptions::BottomKDefault(3, {"a", "b"}); + + std::vector expected = {R"([{"a": 1, "b": 3}, + {"a": 1, "b": 5}, + {"a": 2, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index b685599f952..79745b05552 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -35,6 +35,7 @@ #include "arrow/testing/random.h" #include "arrow/testing/util.h" #include "arrow/type.h" +#include "arrow/util/checked_cast.h" // IWYU pragma: end_exports @@ -142,5 +143,84 @@ enable_if_decimal> default_type_instance() { return std::make_shared(5, 2); } +// Random Generator Helpers +class RandomImpl { + protected: + random::RandomArrayGenerator generator_; + std::shared_ptr type_; + + explicit RandomImpl(random::SeedType seed, std::shared_ptr type) + : generator_(seed), type_(std::move(type)) {} + + public: + std::shared_ptr Generate(uint64_t count, double null_prob) { + return generator_.ArrayOf(type_, count, null_prob); + } + + std::shared_ptr Offsets(int32_t length, int32_t slice_count) { + return arrow::internal::checked_pointer_cast( + generator_.Offsets(slice_count, 0, length)); + } +}; + +template +class Random : public RandomImpl { + public: + explicit Random(random::SeedType seed) + : RandomImpl(seed, TypeTraits::type_singleton()) {} +}; + +template <> +class Random : public RandomImpl { + using CType = float; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {} + + std::shared_ptr Generate(uint64_t count, double 
null_prob, double nan_prob = 0) { + return generator_.Float32(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + +template <> +class Random : public RandomImpl { + using CType = double; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {} + + std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { + return generator_.Float64(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + +template <> +class Random : public RandomImpl { + public: + explicit Random(random::SeedType seed, + std::shared_ptr type = decimal128(18, 5)) + : RandomImpl(seed, std::move(type)) {} +}; + +template +class RandomRange : public RandomImpl { + using CType = typename TypeTraits::CType; + + public: + explicit RandomRange(random::SeedType seed) + : RandomImpl(seed, TypeTraits::type_singleton()) {} + + std::shared_ptr Generate(uint64_t count, int range, double null_prob) { + CType min = std::numeric_limits::min(); + CType max = min + range; + if (sizeof(CType) < 4 && (range + min) > std::numeric_limits::max()) { + max = std::numeric_limits::max(); + } + return generator_.Numeric(count, min, max, null_prob); + } +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 7fa43e715d8..58a48aa9056 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -19,9 +19,11 @@ #include #include #include +#include #include #include +#include "arrow/array/concatenate.h" #include "arrow/array/data.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/kernels/common.h" @@ -98,7 +100,7 @@ struct ResolvedChunk { struct ChunkedArrayResolver { explicit ChunkedArrayResolver(const std::vector& chunks) : num_chunks_(static_cast(chunks.size())), - chunks_(chunks.data()), + chunks_(chunks), 
offsets_(MakeEndOffsets(chunks)), cached_chunk_(0) {} @@ -156,7 +158,7 @@ struct ChunkedArrayResolver { } int64_t num_chunks_; - const Array* const* chunks_; + const std::vector chunks_; std::vector offsets_; mutable int64_t cached_chunk_; @@ -1142,6 +1144,23 @@ class MultipleKeyComparator { return current_compared_ < 0; } + bool Equals(uint64_t left, uint64_t right, size_t start_sort_key_index) { + current_left_ = left; + current_right_ = right; + current_compared_ = 0; + auto num_sort_keys = sort_keys_.size(); + for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) { + current_sort_key_index_ = i; + status_ = VisitTypeInline(*sort_keys_[i].type, this); + // If the left value equals the right value for this key, we need to + // continue comparing with the next sort key. + if (current_compared_ != 0) { + break; + } + } + return current_compared_ == 0; + } + #define VISIT(TYPE) \ Status Visit(const TYPE& type) { \ current_compared_ = CompareType(); \ @@ -1250,7 +1269,7 @@ class MultipleKeyComparator { // Sort a batch using a single sort and multiple-key comparisons. class MultipleKeyRecordBatchSorter : public TypeVisitor { - private: + public: // Preprocessed sort key. struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& array, const SortOrder order) @@ -1272,6 +1291,7 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { int64_t null_count; }; + private: using Comparator = MultipleKeyComparator; public: @@ -1447,7 +1467,7 @@ class TableRadixSorter { // Sort a table using a single sort and multiple-key comparisons. class MultipleKeyTableSorter : public TypeVisitor { - private: + public: // TODO instead of resolving chunks for each column independently, we could // split the table into RecordBatches and pay the cost of chunked indexing // at the first column only. 
@@ -1778,6 +1798,621 @@ class SortIndicesMetaFunction : public MetaFunction { } }; +// ---------------------------------------------------------------------- +// TopK/BottomK implementations + +const auto kDefaultSelectKOptions = SelectKOptions::Defaults(); + +const FunctionDoc select_k_doc( + "Returns the first k elements ordered by `options.keys`", + ("This function computes the k elements of the input\n" + "array, record batch or table specified in the column names (`options.sort_keys`).\n" + "The columns that are not specified are returned as well, but not used for\n" + "ordering. Null values are considered greater than any other value and are\n" + "therefore sorted at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +Result> MakeMutableUInt64Array( + std::shared_ptr out_type, int64_t length, MemoryPool* memory_pool) { + auto buffer_size = length * sizeof(uint64_t); + ARROW_ASSIGN_OR_RAISE(auto data, AllocateBuffer(buffer_size, memory_pool)); + return ArrayData::Make(uint64(), length, {nullptr, std::move(data)}, /*null_count=*/0); +} + +template +class SelectKComparator { + public: + template + bool operator()(const Type& lval, const Type& rval); +}; + +template <> +class SelectKComparator { + public: + template + bool operator()(const Type& lval, const Type& rval) { + return lval < rval; + } +}; + +template <> +class SelectKComparator { + public: + template + bool operator()(const Type& lval, const Type& rval) { + return rval < lval; + } +}; + +template +class ArraySelecter : public TypeVisitor { + public: + ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + array_(array), + k_(options.k), + physical_type_(GetPhysicalType(array.type())), + output_(output) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + 
Status Visit(const TYPE& type) { return SelectKthInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template + Status SelectKthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType arr(array_.data()); + std::vector indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + if (k_ > arr.length()) { + k_ = arr.length(); + } + auto end_iter = PartitionNulls(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + k_; + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + SelectKComparator comparator; + auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + if (cmp(x_index, heap.top())) { + heap.pop(); + heap.push(x_index); + } + } + int64_t out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, + ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.pop(); + --out_cbegin; + } + *output_ = Datum(take_indices); + return Status::OK(); + } + + ExecContext* ctx_; + const Array& array_; + int64_t k_; + const std::shared_ptr physical_type_; + Datum* output_; +}; + +template +struct TypedHeapItem { + uint64_t index; + uint64_t offset; + ArrayType* array; +}; + +template +class ChunkedArraySelecter : public TypeVisitor { + public: + ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, + const 
SelectKOptions& options, Datum* output) + : TypeVisitor(), + chunked_array_(chunked_array), + physical_type_(GetPhysicalType(chunked_array.type())), + physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), + k_(options.k), + ctx_(ctx), + output_(output) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template + Status SelectKthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + using HeapItem = TypedHeapItem; + + const auto num_chunks = chunked_array_.num_chunks(); + if (num_chunks == 0) { + return Status::OK(); + } + if (k_ > chunked_array_.length()) { + k_ = chunked_array_.length(); + } + std::function cmp; + SelectKComparator comparator; + + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + + HeapContainer heap(cmp); + std::vector> chunks_holder; + uint64_t offset = 0; + for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + std::vector indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls( + indices_begin, indices_end, arr, 0); + auto kth_begin = indices_begin + k_; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { + heap.push(HeapItem{*iter, offset, &arr}); + 
} + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.pop(); + heap.push(HeapItem{x_index, offset, &arr}); + } + } + offset += chunk->length(); + } + + int64_t out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, + ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.pop(); + --out_cbegin; + } + *output_ = Datum(take_indices); + return Status::OK(); + } + + const ChunkedArray& chunked_array_; + const std::shared_ptr physical_type_; + const ArrayVector physical_chunks_; + int64_t k_; + ExecContext* ctx_; + Datum* output_; +}; + +class RecordBatchSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey; + using Comparator = MultipleKeyComparator; + + public: + RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + ctx_(ctx), + record_batch_(record_batch), + k_(options.k), + output_(output), + sort_keys_(ResolveSortKeys(record_batch, options.sort_keys)), + comparator_(sort_keys_) {} + + Status Run() { return sort_keys_[0].type->Accept(this); } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + if (sort_keys_[0].order == SortOrder::Descending) \ + return SelectKthInternal(); \ + return SelectKthInternal(); \ + } + VISIT_PHYSICAL_TYPES(VISIT) +#undef VISIT + + static std::vector ResolveSortKeys( + const RecordBatch& batch, const std::vector& sort_keys) { + std::vector resolved; + for (const auto& key : sort_keys) { + auto 
array = batch.GetColumnByName(key.name); + resolved.emplace_back(array, key.order); + } + return resolved; + } + + template + Status SelectKthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + const ArrayType& arr = checked_cast(first_sort_key.array); + + const auto num_rows = record_batch_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (k_ > record_batch_.num_rows()) { + k_ = record_batch_.num_rows(); + } + std::function cmp; + SelectKComparator select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + if (lval == rval) { + // If the left value equals to the right value, + // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } + return select_k_comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + + std::vector indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + k_; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + int64_t out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, + ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = 
heap.top(); + heap.pop(); + --out_cbegin; + } + *output_ = Datum(take_indices); + return Status::OK(); + } + + ExecContext* ctx_; + const RecordBatch& record_batch_; + int64_t k_; + Datum* output_; + std::vector sort_keys_; + Comparator comparator_; +}; + +class TableSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey; + using Comparator = MultipleKeyComparator; + + public: + TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + table_(table), + k_(options.k), + output_(output), + sort_keys_(ResolveSortKeys(table, options.sort_keys)), + comparator_(sort_keys_) {} + + Status Run() { return sort_keys_[0].type->Accept(this); } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + if (sort_keys_[0].order == SortOrder::Descending) \ + return SelectKthInternal(); \ + return SelectKthInternal(); \ + } + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + static std::vector ResolveSortKeys( + const Table& table, const std::vector& sort_keys) { + std::vector resolved; + for (const auto& key : sort_keys) { + auto chunked_array = table.GetColumnByName(key.name); + resolved.emplace_back(*chunked_array, key.order); + } + return resolved; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For non-float types. 
+ template + enable_if_t::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; + if (first_sort_key.null_count == 0) { + return indices_end; + } + StablePartitioner partitioner; + auto nulls_begin = + partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t index) { + const auto chunk = + first_sort_key.GetChunk(static_cast(index)); + return !chunk.IsNull(); + }); + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nulls_begin; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For float types. + template + enable_if_t::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; + StablePartitioner partitioner; + uint64_t* nulls_begin; + if (first_sort_key.null_count == 0) { + nulls_begin = indices_end; + } else { + nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk(index); + return !chunk.IsNull(); + }); + } + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + uint64_t* nans_begin = partitioner(indices_begin, nulls_begin, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk(index); + return !std::isnan(chunk.Value()); + }); + auto& comparator = comparator_; + // Sort all NaNs by the second and following sort keys. + std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + // Sort all nulls by the second and following sort keys. 
+ std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nans_begin; + } + + template + Status SelectKthInternal() { + using ArrayType = typename TypeTraits::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + + const auto num_rows = table_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (k_ > table_.num_rows()) { + k_ = table_.num_rows(); + } + std::function cmp; + SelectKComparator select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + auto chunk_left = first_sort_key.template GetChunk(left); + auto chunk_right = first_sort_key.template GetChunk(right); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + return comparator.Compare(left, right, 1); + } + return select_k_comparator(value_left, value_right); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + + std::vector indices(num_rows); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = + this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); + auto kth_begin = indices_begin + k_; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + int64_t out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(uint64(), out_size, + ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.pop(); + --out_cbegin; + } + 
*output_ = Datum(take_indices); + return Status::OK(); + } + + ExecContext* ctx_; + const Table& table_; + int64_t k_; + Datum* output_; + std::vector sort_keys_; + Comparator comparator_; +}; + +static Status CheckConsistency(const Schema& schema, + const std::vector& sort_keys) { + for (const auto& key : sort_keys) { + auto field = schema.GetFieldByName(key.name); + if (!field) { + return Status::Invalid("Nonexistent sort key column: ", key.name); + } + } + return Status::OK(); +} + +class SelectKUnstableMetaFunction : public MetaFunction { + public: + SelectKUnstableMetaFunction() + : MetaFunction("select_k_unstable", Arity::Unary(), &select_k_doc, + &kDefaultSelectKOptions) {} + + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const { + const SelectKOptions& select_k_options = static_cast(*options); + if (select_k_options.k < 0) { + return Status::Invalid("SelectK requires a nonnegative `k`, got ", + select_k_options.k); + } + switch (args[0].kind()) { + case Datum::ARRAY: { + if (select_k_options.is_top_k()) { + return SelectKth(*args[0].make_array(), select_k_options, + ctx); + } else { + return SelectKth(*args[0].make_array(), select_k_options, + ctx); + } + } break; + case Datum::CHUNKED_ARRAY: { + if (select_k_options.is_top_k()) { + return SelectKth(*args[0].chunked_array(), + select_k_options, ctx); + } else { + return SelectKth(*args[0].chunked_array(), + select_k_options, ctx); + } + } break; + case Datum::RECORD_BATCH: + return SelectKth(*args[0].record_batch(), select_k_options, ctx); + break; + case Datum::TABLE: + return SelectKth(*args[0].table(), select_k_options, ctx); + break; + default: + break; + } + return Status::NotImplemented( + "Unsupported types for select_k operation: " + "values=", + args[0].ToString()); + } + + private: + template + Result SelectKth(const Array& array, const SelectKOptions& options, + ExecContext* ctx) const { + Datum output; + ArraySelecter selecter(ctx, array, 
options, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; + } + + template + Result SelectKth(const ChunkedArray& chunked_array, + const SelectKOptions& options, ExecContext* ctx) const { + Datum output; + ChunkedArraySelecter selecter(ctx, chunked_array, options, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; + } + Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, + ExecContext* ctx) const { + ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys)); + Datum output; + RecordBatchSelecter selecter(ctx, record_batch, options, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; + } + Result SelectKth(const Table& table, const SelectKOptions& options, + ExecContext* ctx) const { + ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys)); + Datum output; + TableSelecter selecter(ctx, table, options, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; + } +}; + +// array documentation const auto kDefaultArraySortOptions = ArraySortOptions::Defaults(); const FunctionDoc array_sort_indices_doc( @@ -1829,6 +2464,9 @@ void RegisterVectorSort(FunctionRegistry* registry) { base.init = PartitionNthToIndicesState::Init; AddSortingKernels(base, part_indices.get()); DCHECK_OK(registry->AddFunction(std::move(part_indices))); + + // select_k_unstable + DCHECK_OK(registry->AddFunction(std::make_shared())); } #undef VISIT_PHYSICAL_TYPES diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 478f6ccac3a..131eeed5098 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -38,34 +38,6 @@ using internal::checked_cast; using internal::checked_pointer_cast; namespace compute { - -// Convert arrow::Type to arrow::DataType. If arrow::Type isn't -// parameter free, this returns an arrow::DataType with the default -// parameter. 
-template -enable_if_t::is_parameter_free, std::shared_ptr> -TypeToDataType() { - return TypeTraits::type_singleton(); -} - -template -enable_if_t::value, std::shared_ptr> -TypeToDataType() { - return timestamp(TimeUnit::MILLI); -} - -template -enable_if_t::value, std::shared_ptr> -TypeToDataType() { - return time32(TimeUnit::MILLI); -} - -template -enable_if_t::value, std::shared_ptr> -TypeToDataType() { - return time64(TimeUnit::NANO); -} - // ---------------------------------------------------------------------- // Tests for NthToIndices @@ -169,7 +141,9 @@ class TestNthToIndicesBase : public TestBase { template class TestNthToIndices : public TestNthToIndicesBase { protected: - std::shared_ptr GetType() override { return TypeToDataType(); } + std::shared_ptr GetType() override { + return default_type_instance(); + } }; template @@ -270,79 +244,6 @@ using NthToIndicesableTypes = Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type, StringType>; -class RandomImpl { - protected: - random::RandomArrayGenerator generator_; - std::shared_ptr type_; - - explicit RandomImpl(random::SeedType seed, std::shared_ptr type) - : generator_(seed), type_(std::move(type)) {} - - public: - std::shared_ptr Generate(uint64_t count, double null_prob) { - return generator_.ArrayOf(type_, count, null_prob); - } -}; - -template -class Random : public RandomImpl { - public: - explicit Random(random::SeedType seed) - : RandomImpl(seed, TypeTraits::type_singleton()) {} -}; - -template <> -class Random : public RandomImpl { - using CType = float; - - public: - explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {} - - std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { - return generator_.Float32(count, std::numeric_limits::min(), - std::numeric_limits::max(), null_prob, nan_prob); - } -}; - -template <> -class Random : public RandomImpl { - using CType = double; - - public: - explicit Random(random::SeedType seed) : RandomImpl(seed, 
float64()) {} - - std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { - return generator_.Float64(count, std::numeric_limits::min(), - std::numeric_limits::max(), null_prob, nan_prob); - } -}; - -template <> -class Random : public RandomImpl { - public: - explicit Random(random::SeedType seed, - std::shared_ptr type = decimal128(18, 5)) - : RandomImpl(seed, std::move(type)) {} -}; - -template -class RandomRange : public RandomImpl { - using CType = typename TypeTraits::CType; - - public: - explicit RandomRange(random::SeedType seed) - : RandomImpl(seed, TypeTraits::type_singleton()) {} - - std::shared_ptr Generate(uint64_t count, int range, double null_prob) { - CType min = std::numeric_limits::min(); - CType max = min + range; - if (sizeof(CType) < 4 && (range + min) > std::numeric_limits::max()) { - max = std::numeric_limits::max(); - } - return generator_.Numeric(count, min, max, null_prob); - } -}; - TYPED_TEST_SUITE(TestNthToIndicesRandom, NthToIndicesableTypes); TYPED_TEST(TestNthToIndicesRandom, RandomValues) { @@ -686,7 +587,7 @@ TEST_F(TestChunkedArraySortIndices, NaN) { template class TestChunkedArraySortIndicesForTemporal : public TestChunkedArraySortIndices { protected: - std::shared_ptr GetType() { return TypeToDataType(); } + std::shared_ptr GetType() { return default_type_instance(); } }; TYPED_TEST_SUITE(TestChunkedArraySortIndicesForTemporal, TemporalArrowTypes); @@ -1097,7 +998,7 @@ TEST_F(TestTableSortIndices, Decimal) { template class TestTableSortIndicesForTemporal : public TestTableSortIndices { protected: - std::shared_ptr GetType() { return TypeToDataType(); } + std::shared_ptr GetType() { return default_type_instance(); } }; TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes); diff --git a/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc new file mode 100644 index 00000000000..3f89eb6bea9 --- /dev/null +++ 
b/cpp/src/arrow/compute/kernels/vector_topk_benchmark.cc @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" + +namespace arrow { +namespace compute { +constexpr auto kSeed = 0x0ff1ce; + +static void SelectKBenchmark(benchmark::State& state, + const std::shared_ptr& values, int64_t k) { + for (auto _ : state) { + ABORT_NOT_OK(SelectKUnstable(*values, SelectKOptions::TopKDefault(k)).status()); + } + state.SetItemsProcessed(state.iterations() * values->length()); +} + +static void SelectKInt64(benchmark::State& state) { + RegressionArgs args(state); + + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + + auto min = std::numeric_limits::min(); + auto max = std::numeric_limits::max(); + auto values = rand.Int64(array_size, min, max, args.null_proportion); + + SelectKBenchmark(state, values, array_size / 8); +} + +BENCHMARK(SelectKInt64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 100}) + ->Args({1 << 23, 100}) + 
->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +} // namespace compute +} // namespace arrow diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 953ad22da05..21955575d61 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1352,21 +1352,17 @@ In these functions, nulls are considered greater than any other value Floating-point NaN values are considered greater than any other non-null value, but smaller than nulls. -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| Function name | Arity | Input types | Output type | Options class | Notes | -+=======================+============+=============================+===================+================================+================+ -| partition_nth_indices | Unary | Binary- and String-like | UInt64 | :struct:`PartitionNthOptions` | \(1) \(3) | -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| partition_nth_indices | Unary | Boolean, Numeric, Temporal | UInt64 | :struct:`PartitionNthOptions` | \(1) | -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| array_sort_indices | Unary | Binary- and String-like | UInt64 | :struct:`ArraySortOptions` | \(2) \(3) \(4) | -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| array_sort_indices | Unary | Boolean, Numeric, Temporal | UInt64 | :struct:`ArraySortOptions` | \(2) \(4) | -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| sort_indices | Unary | Binary- and String-like | UInt64 | :struct:`SortOptions` | \(2) \(3) \(5) | 
-+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ -| sort_indices | Unary | Boolean, Numeric, Temporal | UInt64 | :struct:`SortOptions` | \(2) \(5) | -+-----------------------+------------+-----------------------------+-------------------+--------------------------------+----------------+ ++-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++=======================+============+=========================================================+===================+================================+================+ +| partition_nth_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`PartitionNthOptions` | \(1) \(3) | ++-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ +| array_sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`ArraySortOptions` | \(2) \(4) \(3) | ++-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ +| select_k_unstable | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SelectKOptions` | \(5) \(6) \(3) | ++-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ +| sort_indices | Unary | Boolean, Numeric, Temporal, Binary- and String-like | UInt64 | :struct:`SortOptions` | \(2) \(5) \(3) | 
++-----------------------+------------+---------------------------------------------------------+-------------------+--------------------------------+----------------+ * \(1) The output is an array of indices into the input array, that define a partial non-stable sort such that the *N*'th index points to the *N*'th @@ -1387,6 +1383,9 @@ value, but smaller than nulls. table. If the input is a record batch or table, one or more sort keys must be specified. +* \(6) The output is an array of indices into the input, that define a + non-stable sort of the input. + .. _cpp-compute-vector-structural-transforms: Structural transforms diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index f7f740c24e5..ec3f4101447 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -349,6 +349,7 @@ Sorts and partitions :toctree: ../generated/ partition_nth_indices + select_k_unstable sort_indices Structural Transforms diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index e0eb41f4eaa..f69a6d6c73b 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1130,6 +1130,35 @@ class SortOptions(_SortOptions): self._set_options(sort_keys) +cdef class _SelectKOptions(FunctionOptions): + def _set_options(self, k, sort_keys): + cdef: + c_string c_name + vector[CSortKey] c_sort_keys + CSortOrder c_order + + for name, order in sort_keys: + if order == "ascending": + c_order = CSortOrder_Ascending + elif order == "descending": + c_order = CSortOrder_Descending + else: + raise ValueError( + "{!r} is not a valid order".format(order) + ) + c_name = tobytes(name) + c_sort_keys.push_back(CSortKey(c_name, c_order)) + + self.wrapped.reset(new CSelectKOptions(k, c_sort_keys)) + + +class SelectKOptions(_SelectKOptions): + def __init__(self, k, sort_keys=None): + if sort_keys is None: + sort_keys = [] + self._set_options(k, sort_keys) + + cdef class _QuantileOptions(FunctionOptions): def 
_set_options(self, quantiles, interp, skip_nulls, min_count): interp_dict = { diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 5ee9bb0a434..2f8882658fe 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -53,6 +53,7 @@ SetLookupOptions, SliceOptions, SortOptions, + SelectKOptions, SplitOptions, SplitPatternOptions, StrftimeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index b1e963c24e5..3a9137482e6 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2058,6 +2058,12 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CSortOptions(vector[CSortKey] sort_keys) vector[CSortKey] sort_keys + cdef cppclass CSelectKOptions \ + "arrow::compute::SelectKOptions"(CFunctionOptions): + CSelectKOptions(int64_t k, vector[CSortKey] sort_keys) + int64_t k + vector[CSortKey] sort_keys + enum CQuantileInterp \ "arrow::compute::QuantileOptions::Interpolation": CQuantileInterp_LINEAR "arrow::compute::QuantileOptions::LINEAR" diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 2349f84b55e..d4f147c73c1 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -123,6 +123,7 @@ def test_option_class_equality(): pc.SetLookupOptions(value_set=pa.array([1])), pc.SliceOptions(start=0, stop=1, step=1), pc.SplitPatternOptions(pattern="pattern"), + pc.SelectKOptions(k=0, sort_keys=[("b", "ascending")]), pc.StrptimeOptions("%Y", "s"), pc.TrimOptions(" "), pc.StrftimeOptions(), @@ -1780,6 +1781,69 @@ def test_partition_nth(): for i in range(pivot, len(data))) +def test_select_k_array(): + def validate_select_k(select_k_indices, arr, order, stable_sort=False): + sorted_indices = pc.sort_indices(arr, sort_keys=[("dummy", order)]) + head_k_indices = sorted_indices.slice(0, len(select_k_indices)) + if stable_sort: + assert select_k_indices == head_k_indices + 
else: + expected = pc.take(arr, head_k_indices) + actual = pc.take(arr, select_k_indices) + assert actual == expected + + arr = pa.array([1, 2, None, 0]) + for order in ["descending", "ascending"]: + for k in [0, 2, 4]: + result = pc.select_k_unstable( + arr, k=k, sort_keys=[("dummy", order)]) + validate_select_k(result, arr, order) + + result = pc.select_k_unstable(arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "descending")])) + validate_select_k(result, arr, "descending") + + result = pc.select_k_unstable(arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "ascending")])) + validate_select_k(result, arr, "ascending") + + +def test_select_k_table(): + table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) + + def validate_select_k(select_k_indices, table, sort_keys, + stable_sort=False): + sorted_indices = pc.sort_indices(table, sort_keys=sort_keys) + head_k_indices = sorted_indices.slice(0, len(select_k_indices)) + if stable_sort: + assert select_k_indices == head_k_indices + else: + expected = pc.take(table, head_k_indices) + actual = pc.take(table, select_k_indices) + assert actual == expected + + for k in [0, 2, 4]: + result = pc.select_k_unstable( + table, k=k, sort_keys=[("a", "ascending")]) + validate_select_k(result, table, sort_keys=[("a", "ascending")]) + + result = pc.select_k_unstable( + table, k=k, sort_keys=[("a", "ascending"), ("b", "ascending")] + ) + validate_select_k(result, table, sort_keys=[ + ("a", "ascending"), ("b", "ascending")]) + + with pytest.raises(ValueError, + match="SelectK requires a nonnegative `k`"): + pc.select_k_unstable(table) + + with pytest.raises(ValueError, match="not a valid order"): + pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")]) + + with pytest.raises(ValueError, match="Nonexistent sort key column"): + pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending")]) + + def test_array_sort_indices(): arr = pa.array([1, 2, None, 0]) result = pc.array_sort_indices(arr)