From 3be07c339ceac90cd0301e419c979887ce70585a Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 19 Aug 2021 15:06:57 -0500 Subject: [PATCH 01/19] Improve NthToIndices - Add optional parameter to choose order (ascending or descending) - Add optional parameter to choose algorithm strategy (IntroSelect[default] or HeapBased ) - Implement HeapBased partition_nth - Run benchmarks comparing IntroSelect and HeapBased - Prepare APIs to support NthToIndices for DATUM update minor update topk update --- cpp/src/arrow/compute/api_vector.cc | 30 ++ cpp/src/arrow/compute/api_vector.h | 30 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/select_k_test.cc | 364 ++++++++++++++++++ cpp/src/arrow/compute/kernels/vector_sort.cc | 270 +++++++++++++ cpp/src/arrow/util/heap.h | 72 ++++ 6 files changed, 767 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/select_k_test.cc create mode 100644 cpp/src/arrow/util/heap.h diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d4c4a915999..e4787fb7047 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -111,6 +111,10 @@ static auto kSortOptionsType = GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot)); +static auto kSelectKOptionsType = GetFunctionOptionsType( + DataMember("k", &SelectKOptions::k), DataMember("keys", &SelectKOptions::keys), + DataMember("order", &SelectKOptions::order), + DataMember("keep", &SelectKOptions::keep)); } // namespace } // namespace internal @@ -140,6 +144,15 @@ PartitionNthOptions::PartitionNthOptions(int64_t pivot) : FunctionOptions(internal::kPartitionNthOptionsType), pivot(pivot) {} constexpr char PartitionNthOptions::kTypeName[]; +SelectKOptions::SelectKOptions(int64_t k, std::vector keys, std::string keep, + SortOrder order) + : FunctionOptions(internal::kSelectKOptionsType), + k(k), + keys(std::move(keys)), + keep(keep), + order(order) {} +constexpr char SelectKOptions::kTypeName[]; + namespace internal { void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType)); @@ -148,6 +161,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kArraySortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); } } // namespace internal @@ -162,6 +176,22 @@ Result> NthToIndices(const Array& values, int64_t n, return result.make_array(); } +Result> TopK(const Array& values, int64_t k, + const std::string& keep, ExecContext* ctx) { + SelectKOptions options(k, {}, keep); + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("array_top_k", {Datum(values)}, &options, ctx)); + return result.make_array(); +} + +Result TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx) { + options.k = k; + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("top_k", {Datum(datum)}, &options, ctx)); + return result; +} + Result ReplaceWithMask(const Datum& values, const Datum& mask, const Datum& replacements, ExecContext* ctx) { return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 5dc68fc5c83..848facd8b81 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -120,6 +120,25 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { std::vector sort_keys; }; +/// \brief Partitioning options for NthToIndices +class ARROW_EXPORT SelectKOptions : public FunctionOptions { + public: + explicit SelectKOptions(int64_t pivot = 0, std::vector keys = {}, + std::string keep = "first", + SortOrder order = SortOrder::Ascending); + constexpr static char const kTypeName[] = "SelectKOptions"; + static SelectKOptions TopKDefault() { + return SelectKOptions{0, {}, "first", SortOrder::Ascending}; + } + static SelectKOptions BottomKDefault() { + return SelectKOptions{0, {}, "first", SortOrder::Descending}; + } + int64_t k; + std::vector keys; + std::string keep; + SortOrder order; +}; + /// \brief Partitioning options for NthToIndices class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { public: @@ -252,6 +271,17 @@ ARROW_EXPORT Result> NthToIndices(const Array& values, int64_t n, ExecContext* ctx = NULLPTR); +/// @TODO +ARROW_EXPORT +Result> TopK(const Array& values, int64_t k, + const std::string& keep = "first", + ExecContext* ctx = NULLPTR); + +/// @TODO +ARROW_EXPORT +Result TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx = NULLPTR); + /// \brief Returns the indices that would sort an array in the /// specified order. /// diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 4096e497c0a..dc347b008c0 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -50,6 +50,7 @@ add_arrow_compute_test(vector_test vector_replace_test.cc vector_selection_test.cc vector_sort_test.cc + select_k_test.cc test_util.cc) add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc new file mode 100644 index 00000000000..0def147d29f --- /dev/null +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -0,0 +1,364 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "arrow/array/array_decimal.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type_traits.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace compute { + +namespace { + +// Convert arrow::Type to arrow::DataType. If arrow::Type isn't +// parameter free, this returns an arrow::DataType with the default +// parameter. +template +enable_if_t::is_parameter_free, std::shared_ptr> +TypeToDataType() { + return TypeTraits::type_singleton(); +} + +template +enable_if_t::value, std::shared_ptr> +TypeToDataType() { + return timestamp(TimeUnit::MILLI); +} + +template +enable_if_t::value, std::shared_ptr> +TypeToDataType() { + return time32(TimeUnit::MILLI); +} + +template +enable_if_t::value, std::shared_ptr> +TypeToDataType() { + return time64(TimeUnit::NANO); +} + +// ---------------------------------------------------------------------- +// Tests for SelectK + +template +auto GetLogicalValue(const ArrayType& array, uint64_t index) + -> decltype(array.GetView(index)) { + return array.GetView(index); +} + +Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) { + return Decimal128(array.Value(index)); +} + +Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) { + return Decimal256(array.Value(index)); +} + +} // namespace +template +class NthComparator { + public: + bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) { + if (array.IsNull(rhs)) return true; + if (array.IsNull(lhs)) return false; + const auto lval = GetLogicalValue(array, lhs); + const auto rval = GetLogicalValue(array, rhs); + if (is_floating_type::value) { + // NaNs ordered after non-NaNs + if (rval != rval) return true; + if (lval != lval) return false; + } + return lval <= rval; + } +}; + +template +class SortComparator { + public: + bool operator()(const ArrayType& array, SortOrder order, uint64_t lhs, uint64_t rhs) { + if (array.IsNull(rhs) && array.IsNull(lhs)) return lhs < rhs; + if (array.IsNull(rhs)) return true; + if (array.IsNull(lhs)) return false; + const auto lval = GetLogicalValue(array, lhs); + const auto rval = GetLogicalValue(array, rhs); + if (is_floating_type::value) { + const bool lhs_isnan = lval != lval; + const bool rhs_isnan = rval != rval; + if (lhs_isnan && rhs_isnan) return lhs < rhs; + if (rhs_isnan) return true; + if (lhs_isnan) return false; + } + if (lval == rval) return lhs < rhs; + if (order == SortOrder::Ascending) { + return lval < rval; + } else { + return lval > rval; + } + } +}; + +template +class TestSelectKBase : public TestBase { + using ArrayType = typename TypeTraits::ArrayType; + + protected: + void Validate(const ArrayType& array, int n, UInt64Array& offsets) { + if (n >= array.length()) { + for (int i = 0; i < array.length(); ++i) { + ASSERT_TRUE(offsets.Value(i) == (uint64_t)i); + } + } else { + NthComparator compare; + uint64_t nth = offsets.Value(n); + + for (int i = 0; i < n; ++i) { + uint64_t lhs = offsets.Value(i); + ASSERT_TRUE(compare(array, lhs, nth)); + } + for (int i = n + 1; i < array.length(); ++i) { + uint64_t rhs = offsets.Value(i); + ASSERT_TRUE(compare(array, nth, rhs)); + } + } + } + + void AssertSelectKArray(const std::shared_ptr values, int n) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, TopK(*values, n)); + // null_count field should have been initialized to 0, for convenience + ASSERT_EQ(offsets->data()->null_count, 0); + ValidateOutput(*offsets); + Validate(*checked_pointer_cast(values), n, + *checked_pointer_cast(offsets)); + } + + void AssertSelectKJson(const std::string& values, int n) { + AssertSelectKArray(ArrayFromJSON(GetType(), values), n); + } + + virtual std::shared_ptr GetType() = 0; +}; + +template +class TestSelectK : public TestSelectKBase { + protected: + std::shared_ptr GetType() override { return TypeToDataType(); } +}; + +template +class TestSelectKForReal : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes); + +template +class TestSelectKForIntegral : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes); + +template +class TestSelectKForBool : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types); + +template +class TestSelectKForTemporal : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes); + +template +class TestSelectKForDecimal : public TestSelectKBase { + std::shared_ptr GetType() override { + return std::make_shared(5, 2); + } +}; +TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes); + +template +class TestSelectKForStrings : public TestSelectK {}; +TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types); + +TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) { + auto input = ArrayFromJSON(this->GetType(), "[null, 1, 3.3, null, 2, 5.3]"); + ASSERT_RAISES(Invalid, CallFunction("top_k", {input})); +} + +TYPED_TEST(TestSelectKForReal, Real) { + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6); + + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); +} + +TYPED_TEST(TestSelectKForIntegral, Integral) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); +} + +TYPED_TEST(TestSelectKForBool, Bool) { + this->AssertSelectKJson("[null, false, true, null, false, true]", 0); + this->AssertSelectKJson("[null, false, true, null, false, true]", 2); + this->AssertSelectKJson("[null, false, true, null, false, true]", 5); + this->AssertSelectKJson("[null, false, true, null, false, true]", 6); +} + +TYPED_TEST(TestSelectKForTemporal, Temporal) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); +} + +TYPED_TEST(TestSelectKForDecimal, Decimal) { + const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])"; + this->AssertSelectKJson(values, 0); + this->AssertSelectKJson(values, 2); + this->AssertSelectKJson(values, 4); + this->AssertSelectKJson(values, 5); +} + +TYPED_TEST(TestSelectKForStrings, Strings) { + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 0); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 2); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 5); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 6); +} + +template +class TestSelectKRandom : public TestSelectKBase { + public: + std::shared_ptr GetType() override { + EXPECT_TRUE(0) << "shouldn't be used"; + return nullptr; + } +}; + +using SelectKableTypes = + ::testing::Types; + +class RandomImpl { + protected: + random::RandomArrayGenerator generator_; + std::shared_ptr type_; + + explicit RandomImpl(random::SeedType seed, std::shared_ptr type) + : generator_(seed), type_(std::move(type)) {} + + public: + std::shared_ptr Generate(uint64_t count, double null_prob) { + return generator_.ArrayOf(type_, count, null_prob); + } +}; + +template +class Random : public RandomImpl { + public: + explicit Random(random::SeedType seed) + : RandomImpl(seed, TypeTraits::type_singleton()) {} +}; + +template <> +class Random : public RandomImpl { + using CType = float; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {} + + std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { + return generator_.Float32(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + +template <> +class Random : public RandomImpl { + using CType = double; + + public: + explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {} + + std::shared_ptr Generate(uint64_t count, double null_prob, double nan_prob = 0) { + return generator_.Float64(count, std::numeric_limits::min(), + std::numeric_limits::max(), null_prob, nan_prob); + } +}; + +template <> +class Random : public RandomImpl { + public: + explicit Random(random::SeedType seed, + std::shared_ptr type = decimal128(18, 5)) + : RandomImpl(seed, std::move(type)) {} +}; + +template +class RandomRange : public RandomImpl { + using CType = typename TypeTraits::CType; + + public: + explicit RandomRange(random::SeedType seed) + : RandomImpl(seed, TypeTraits::type_singleton()) {} + + std::shared_ptr Generate(uint64_t count, int range, double null_prob) { + CType min = std::numeric_limits::min(); + CType max = min + range; + if (sizeof(CType) < 4 && (range + min) > std::numeric_limits::max()) { + max = std::numeric_limits::max(); + } + return generator_.Numeric(count, min, max, null_prob); + } +}; + +TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes); + +TYPED_TEST(TestSelectKRandom, RandomValues) { + Random rand(0x61549225); + int length = 100; + for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { + // Try n from 0 to out of bound + for (int n = 0; n <= length; ++n) { + auto array = rand.Generate(length, null_probability); + this->AssertSelectKArray(array, n); + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 7fa43e715d8..838255b6956 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -32,6 +32,7 @@ #include "arrow/util/bitmap.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/heap.h" #include "arrow/util/optional.h" #include "arrow/visitor_inline.h" @@ -360,6 +361,254 @@ struct PartitionNthToIndices { } }; +// ---------------------------------------------------------------------- +// TopK/BottomK implementations + +using SelectKOptionsState = internal::OptionsWrapper; +const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); +const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); + +template +struct ArraySelectNth { + using ArrayType = typename TypeTraits::ArrayType; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + using GetView = GetViewType; + + if (ctx->state() == nullptr) { + return Status::Invalid("NthToIndices requires PartitionNthOptions"); + } + + ArrayType arr(batch[0].array()); + + int64_t pivot = SelectKOptionsState::Get(ctx).k; + SortOrder order = SelectKOptionsState::Get(ctx).order; + std::string keep = SelectKOptionsState::Get(ctx).keep; + if (pivot > arr.length()) { + return Status::IndexError("NthToIndices index out of bound"); + } + ArrayData* out_arr = out->mutable_array(); + uint64_t* out_begin = out_arr->GetMutableValues(1); + uint64_t* out_end = out_begin + arr.length(); + std::iota(out_begin, out_end, 0); + if (pivot == arr.length()) { + return Status::OK(); + } + auto nulls_begin = + PartitionNulls(out_begin, out_end, arr, 0); + auto nth_begin = out_begin + pivot; + if (nth_begin < nulls_begin) { + std::function cmp; + if (order == SortOrder::Ascending) { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return lval < rval; + }; + } else { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return rval < lval; + }; + } + arrow::internal::Heap heap(cmp); + + for (uint64_t* iter = out_begin; iter != nth_begin; ++iter) { + heap.Push(*iter); + } + for (uint64_t* iter = out_begin; iter != nulls_begin; ++iter) { + uint64_t x_index = *iter; + const auto lval = GetView::LogicalValue(arr.GetView(x_index)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.Top())); + if (order == SortOrder::Ascending) { + if (keep == "first") { + if (lval < rval) { + heap.ReplaceTop(x_index); + } + } + } else { + if (rval < lval) { + heap.ReplaceTop(x_index); + } + } + } + std::copy(heap.Data(), heap.Data() + pivot, out_begin); + } + return Status::OK(); + } +}; + +const FunctionDoc top_k_doc( + "Return the indices that would partition an array array, record batch or table\n" + "around a pivot", + ("@TODO"), {"input"}, "PartitionNthOptions"); + +const FunctionDoc bottom_k_doc( + "Return the indices that would partition an array array, record batch or table\n" + "around a pivot", + ("@TODO"), {"input"}, "PartitionNthOptions"); + +class ChunkedArraySelecter : public TypeVisitor { + public: + ChunkedArraySelecter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + int64_t pivot, const ChunkedArray& chunked_array, + const SortOrder order) + : TypeVisitor(), + indices_begin_(indices_begin), + indices_end_(indices_end), + pivot_(pivot), + chunked_array_(chunked_array), + physical_type_(GetPhysicalType(chunked_array.type())), + physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), + order_(order), + ctx_(ctx) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectNthInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template + Status SelectNthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + + const auto num_chunks = chunked_array_.num_chunks(); + if (num_chunks == 0) { + return Status::OK(); + } + + for (const auto& chunk : physical_chunks_) { + ArrayType arr(chunk->data()); + + auto nulls_begin = PartitionNulls( + indices_begin_, indices_end_, arr, 0); + + auto nth_begin = indices_begin_ + pivot_; + if (nth_begin < nulls_begin) { + std::function cmp; + if (order_ == SortOrder::Ascending) { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return lval < rval; + }; + } else { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return rval < lval; + }; + } + arrow::internal::Heap heap(cmp); + + for (uint64_t* iter = indices_begin_; iter != nth_begin; ++iter) { + heap.Push(*iter); + } + for (uint64_t* iter = indices_begin_; iter != nulls_begin; ++iter) { + auto x = *iter; + if (x < heap.Top()) { + heap.ReplaceTop(x); + } + } + std::copy(heap.Data(), heap.Data() + pivot_, indices_begin_); + } + } + return Status::OK(); + } + + uint64_t* indices_begin_; + uint64_t* indices_end_; + int64_t pivot_; + const ChunkedArray& chunked_array_; + const std::shared_ptr physical_type_; + const ArrayVector physical_chunks_; + const SortOrder order_; + ExecContext* ctx_; +}; +template +class SelectNthMetaFunction { + public: + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const { + const SelectKOptions& partition_options = + static_cast(*options); + + switch (args[0].kind()) { + case Datum::ARRAY: + return SelectNth(*args[0].make_array(), partition_options, ctx); + break; + case Datum::CHUNKED_ARRAY: + return SelectNth(*args[0].chunked_array(), partition_options, ctx); + break; + default: + break; + } + return Status::NotImplemented( + "Unsupported types for sort_indices operation: " + "values=", + args[0].ToString()); + } + + private: + Result SelectNth(const Array& values, const SelectKOptions& options, + ExecContext* ctx) const { + return CallFunction("partition_nth_indices_doc", {values}, &options, ctx); + } + + Result SelectNth(const ChunkedArray& chunked_array, + const SelectKOptions& options, ExecContext* ctx) const { + int64_t pivot = options.k; + SortOrder order = sort_order; + auto out_type = uint64(); + auto length = chunked_array.length(); + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast(out_type)->bit_width()); + std::vector> buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], + AllocateResizableBuffer(buffer_size, ctx->memory_pool())); + auto out = std::make_shared(out_type, length, buffers, 0); + auto out_begin = out->GetMutableValues(1); + auto out_end = out_begin + length; + std::iota(out_begin, out_end, 0); + + ChunkedArraySelecter partitioner(ctx, out_begin, out_end, pivot, chunked_array, + order); + ARROW_RETURN_NOT_OK(partitioner.Run()); + return Datum(out); + } +}; + +class TopKMetaFunction : public MetaFunction { + public: + TopKMetaFunction() + : MetaFunction("top_k", Arity::Unary(), &top_k_doc, &kDefaultTopKOptions) {} + + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + SelectNthMetaFunction impl; + return impl.ExecuteImpl(args, options, ctx); + } +}; + +class BottomKMetaFunction : public MetaFunction { + public: + BottomKMetaFunction() + : MetaFunction("bottom_k", Arity::Unary(), &top_k_doc, &kDefaultBottomKOptions) {} + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + SelectNthMetaFunction impl; + return impl.ExecuteImpl(args, options, ctx); + } +}; + // ---------------------------------------------------------------------- // Array sorting implementations @@ -1806,6 +2055,11 @@ const FunctionDoc partition_nth_indices_doc( "The pivot index `N` must be given in PartitionNthOptions."), {"array"}, "PartitionNthOptions"); +const FunctionDoc array_top_k_doc("Return the indices that would sort an array", + ("@TODO."), {"array"}, "ArraySortOptions"); + +const FunctionDoc array_bottom_k_doc("Return the indices that would sort an array", + ("@TODO"), {"array"}, "ArraySortOptions"); } // namespace void RegisterVectorSort(FunctionRegistry* registry) { @@ -1829,6 +2083,22 @@ void RegisterVectorSort(FunctionRegistry* registry) { base.init = PartitionNthToIndicesState::Init; AddSortingKernels(base, part_indices.get()); DCHECK_OK(registry->AddFunction(std::move(part_indices))); + + // top_k + auto part_topk = std::make_shared("array_top_k", Arity::Unary(), + &array_bottom_k_doc); + base.init = SelectKOptionsState::Init; + AddSortingKernels(base, part_topk.get()); + DCHECK_OK(registry->AddFunction(std::move(part_topk))); + DCHECK_OK(registry->AddFunction(std::make_shared())); + + // bottom_k + auto part_bottomk = + std::make_shared("array_bottom_k", Arity::Unary(), &bottom_k_doc); + base.init = SelectKOptionsState::Init; + AddSortingKernels(base, part_topk.get()); + DCHECK_OK(registry->AddFunction(std::move(part_topk))); + DCHECK_OK(registry->AddFunction(std::make_shared())); } #undef VISIT_PHYSICAL_TYPES diff --git a/cpp/src/arrow/util/heap.h b/cpp/src/arrow/util/heap.h new file mode 100644 index 00000000000..afcb62ea3f6 --- /dev/null +++ b/cpp/src/arrow/util/heap.h @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include "arrow/util/macros.h" + +namespace arrow { +namespace internal { + +// A Heap class, is a simple wrapper to make heap operation simpler. +// This class is immutable by design +template > +class ARROW_EXPORT Heap { + public: + explicit Heap() : values_(), comp_() {} + explicit Heap(const Compare& compare) : values_(), comp_(compare) {} + + Heap(Heap&&) = default; + Heap& operator=(Heap&&) = default; + + T* Data() { return values_.data(); } + + const T& Top() const { return values_.front(); } + + bool Empty() const { return values_.empty(); } + + size_t Size() const { return values_.size(); } + + void Push(const T& value) { + values_.push_back(value); + std::push_heap(values_.begin(), values_.end(), comp_); + } + + void Pop() { + std::pop_heap(values_.begin(), values_.end(), comp_); + values_.pop_back(); + } + + void ReplaceTop(const T& value) { + std::pop_heap(values_.begin(), values_.end(), comp_); + values_.back() = value; + std::push_heap(values_.begin(), values_.end(), comp_); + } + + protected: + ARROW_DISALLOW_COPY_AND_ASSIGN(Heap); + + std::vector values_; + + Compare comp_; +}; + +} // namespace internal +} // namespace arrow \ No newline at end of file From 5af5512e6f76c830af21a8bb52d73c374d106487 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 19 Aug 2021 15:25:41 -0500 Subject: [PATCH 02/19] minor fix --- cpp/src/arrow/compute/kernels/vector_sort.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 838255b6956..e2bad24ec07 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -2096,8 +2096,8 @@ void RegisterVectorSort(FunctionRegistry* registry) { auto part_bottomk = std::make_shared("array_bottom_k", Arity::Unary(), &bottom_k_doc); base.init = SelectKOptionsState::Init; - AddSortingKernels(base, part_topk.get()); - DCHECK_OK(registry->AddFunction(std::move(part_topk))); + AddSortingKernels(base, part_bottomk.get()); + DCHECK_OK(registry->AddFunction(std::move(part_bottomk))); DCHECK_OK(registry->AddFunction(std::make_shared())); } From 1084be3dcf6f05cbed685355bd377c3ae671dc6a Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 26 Aug 2021 18:50:09 -0500 Subject: [PATCH 03/19] second version --- cpp/src/arrow/compute/api_vector.cc | 46 ++- cpp/src/arrow/compute/api_vector.h | 28 +- .../arrow/compute/kernels/select_k_test.cc | 309 ++++++++++++++--- cpp/src/arrow/compute/kernels/vector_sort.cc | 325 ++++++++++-------- cpp/src/arrow/testing/gtest_util.h | 3 +- cpp/src/arrow/util/heap.h | 8 +- 6 files changed, 504 insertions(+), 215 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index e4787fb7047..3b3a92a24a6 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -178,18 +178,52 @@ Result> NthToIndices(const Array& values, int64_t n, Result> TopK(const Array& values, int64_t k, const std::string& keep, ExecContext* ctx) { - SelectKOptions options(k, {}, keep); + SelectKOptions options(k, {}, keep, SortOrder::Ascending); ARROW_ASSIGN_OR_RAISE(Datum result, - CallFunction("array_top_k", {Datum(values)}, &options, ctx)); + CallFunction("top_k", {values, Datum(k)}, &options, ctx)); return result.make_array(); } -Result TopK(const Datum& datum, int64_t k, SelectKOptions options, - ExecContext* ctx) { +Result> TopK(const ChunkedArray& values, int64_t k, + const std::string& keep, ExecContext* ctx) { + SelectKOptions options(k, {}, keep, SortOrder::Ascending); + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("top_k", {Datum(values), Datum(k)}, &options, ctx)); + return result.make_array(); +} + +Result> TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx) { + options.k = k; + options.order = SortOrder::Ascending; + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("top_k", {datum, Datum(k)}, &options, ctx)); + return result.make_array(); +} + +Result> BottomK(const Array& values, int64_t k, + const std::string& keep, ExecContext* ctx) { + SelectKOptions options(k, {}, keep, SortOrder::Ascending); + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("bottom_k", {values, Datum(k)}, &options, ctx)); + return result.make_array(); +} + +Result> BottomK(const ChunkedArray& values, int64_t k, + const std::string& keep, ExecContext* ctx) { + SelectKOptions options(k, {}, keep, SortOrder::Ascending); + ARROW_ASSIGN_OR_RAISE( + Datum result, CallFunction("bottom_k", {Datum(values), Datum(k)}, &options, ctx)); + return result.make_array(); +} + +Result> BottomK(const Datum& datum, int64_t k, + SelectKOptions options, ExecContext* ctx) { options.k = k; + options.order = SortOrder::Ascending; ARROW_ASSIGN_OR_RAISE(Datum result, - CallFunction("top_k", {Datum(datum)}, &options, ctx)); - return result; + CallFunction("bottom_k", {datum, Datum(k)}, &options, ctx)); + return result.make_array(); } Result ReplaceWithMask(const Datum& values, const Datum& mask, diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 848facd8b81..62381e14fe1 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -279,8 +279,32 @@ Result> TopK(const Array& values, int64_t k, /// @TODO ARROW_EXPORT -Result TopK(const Datum& datum, int64_t k, SelectKOptions options, - ExecContext* ctx = NULLPTR); +Result> TopK(const ChunkedArray& values, int64_t k, + const std::string& keep = "first", + ExecContext* ctx = NULLPTR); + +/// @TODO +ARROW_EXPORT +Result> TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx = NULLPTR); + +/// @TODO +ARROW_EXPORT +Result> BottomK(const Array& values, int64_t k, + const std::string& keep = "first", + ExecContext* ctx = NULLPTR); + +/// @TODO +ARROW_EXPORT +Result> BottomK(const ChunkedArray& values, int64_t k, + const std::string& keep = "first", + ExecContext* ctx = NULLPTR); + +/// @TODO +ARROW_EXPORT +Result> BottomK(const Datum& datum, int64_t k, + SelectKOptions options, + ExecContext* ctx = NULLPTR); /// \brief Returns the indices that would sort an array in the /// specified order. diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 0def147d29f..3fd1392451c 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -86,84 +86,83 @@ Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) { } } // namespace -template -class NthComparator { + +template +class SelectKComparator { public: - bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) { - if (array.IsNull(rhs)) return true; - if (array.IsNull(lhs)) return false; - const auto lval = GetLogicalValue(array, lhs); - const auto rval = GetLogicalValue(array, rhs); + template + bool operator()(const Type& lval, const Type& rval) { if (is_floating_type::value) { // NaNs ordered after non-NaNs if (rval != rval) return true; if (lval != lval) return false; } - return lval <= rval; - } -}; - -template -class SortComparator { - public: - bool operator()(const ArrayType& array, SortOrder order, uint64_t lhs, uint64_t rhs) { - if (array.IsNull(rhs) && array.IsNull(lhs)) return lhs < rhs; - if (array.IsNull(rhs)) return true; - if (array.IsNull(lhs)) return false; - const auto lval = GetLogicalValue(array, lhs); - const auto rval = GetLogicalValue(array, rhs); - if (is_floating_type::value) { - const bool lhs_isnan = lval != lval; - const bool rhs_isnan = rval != rval; - if (lhs_isnan && rhs_isnan) return lhs < rhs; - if (rhs_isnan) return true; - if (lhs_isnan) return false; - } - if (lval == rval) return lhs < rhs; if (order == SortOrder::Ascending) { - return lval < rval; + return lval <= rval; } else { - return lval > rval; + return rval <= lval; } } }; +template +Result> SelectK(const ChunkedArray& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} + +template +Result> SelectK(const Array& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} template class TestSelectKBase : public TestBase { using ArrayType = typename TypeTraits::ArrayType; protected: - void Validate(const ArrayType& array, int n, UInt64Array& offsets) { - if (n >= array.length()) { - for (int i = 0; i < array.length(); ++i) { - ASSERT_TRUE(offsets.Value(i) == (uint64_t)i); - } - } else { - NthComparator compare; - uint64_t nth = offsets.Value(n); - - for (int i = 0; i < n; ++i) { - uint64_t lhs = offsets.Value(i); - ASSERT_TRUE(compare(array, lhs, nth)); - } - for (int i = n + 1; i < array.length(); ++i) { - uint64_t rhs = offsets.Value(i); - ASSERT_TRUE(compare(array, nth, rhs)); + void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder order) { + ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order)); + ASSERT_OK_AND_ASSIGN(Datum sorted_datum, + Take(array, sorted_indices, TakeOptions::NoBoundsCheck())); + std::shared_ptr sorted_array_out = sorted_datum.make_array(); + + const ArrayType& sorted_array = *checked_pointer_cast(sorted_array_out); + + if (k < array.length()) { + for (uint64_t i = 0; i < (uint64_t)select_k.length(); ++i) { + const auto lval = GetLogicalValue(select_k, i); + const auto rval = GetLogicalValue(sorted_array, i); + ASSERT_TRUE(lval == rval); } } } - + template void AssertSelectKArray(const std::shared_ptr values, int n) { - ASSERT_OK_AND_ASSIGN(std::shared_ptr offsets, TopK(*values, n)); - // null_count field should have been initialized to 0, for convenience - ASSERT_EQ(offsets->data()->null_count, 0); - ValidateOutput(*offsets); + std::shared_ptr select_k; + ASSERT_OK_AND_ASSIGN(select_k, SelectK(*values, n)); + ASSERT_EQ(select_k->data()->null_count, 0); + ValidateOutput(*select_k); Validate(*checked_pointer_cast(values), n, - *checked_pointer_cast(offsets)); + *checked_pointer_cast(select_k), order); + } + + void AssertTopKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); + } + void AssertBottomKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); } void AssertSelectKJson(const std::string& values, int n) { - AssertSelectKArray(ArrayFromJSON(GetType(), values), n); + AssertTopKArray(ArrayFromJSON(GetType(), values), n); + AssertBottomKArray(ArrayFromJSON(GetType(), values), n); } virtual std::shared_ptr GetType() = 0; @@ -228,6 +227,8 @@ TYPED_TEST(TestSelectKForIntegral, Integral) { this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); + + this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5); } TYPED_TEST(TestSelectKForBool, Bool) { @@ -355,10 +356,210 @@ TYPED_TEST(TestSelectKRandom, RandomValues) { // Try n from 0 to out of bound for (int n = 0; n <= length; ++n) { auto array = rand.Generate(length, null_probability); - this->AssertSelectKArray(array, n); + this->AssertTopKArray(array, n); + this->AssertBottomKArray(array, n); } } } +template +struct SelectKWithChunkedArray : public ::testing::Test { + SelectKWithChunkedArray() + : sizes_({0, 1, 2, 4, 16, 31, 1234}), + null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} + + void Check(const std::shared_ptr& type, + const std::vector& values, int64_t k, + const std::string& expected) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(type, values, k, &actual)); + ValidateOutput(actual); + + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + void Check(const std::shared_ptr& type, + const std::shared_ptr& values, int64_t k, + const std::string& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, SelectK(*values, k)); + ValidateOutput(actual); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + Status DoSelectK(const std::shared_ptr& type, + const std::vector& values, int64_t k, + std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(*out, SelectK(*(ChunkedArrayFromJSON(type, values)), k)); + return Status::OK(); + } + std::vector sizes_; + std::vector null_probabilities_; +}; + +struct TopKWithChunkedArray : public SelectKWithChunkedArray {}; + +TEST_F(TopKWithChunkedArray, Int8) { + this->Check(int8(), {"[0, 1, 9]", "[3, 7, 2, 4, 10]"}, 3, "[10, 9, 7]"); + this->Check(int8(), {"[]", "[]"}, 0, "[]"); + this->Check(float32(), {"[]"}, 0, "[]"); +} + +TEST_F(TopKWithChunkedArray, Null) { + this->Check(int8(), {"[null]", "[8, null]"}, 1, "[8]"); + + this->Check(int8(), {"[null]", "[null, null]"}, 0, "[]"); + this->Check(int32(), {"[0, null, 9]", "[3, null, 2, null, 10]"}, 3, "[10, 9, 3]"); + this->Check(int8(), {"[null]", "[]"}, 0, "[]"); +} + +TEST_F(TopKWithChunkedArray, NaN) { + this->Check(float32(), {"[NaN]", "[8, NaN]"}, 1, "[8]"); + + this->Check(float32(), {"[NaN]", "[NaN, NaN]"}, 0, "[]"); + this->Check(float32(), {"[0, NaN, 9]", "[3, NaN, 2, NaN, 10]"}, 3, "[10, 9, 3]"); + this->Check(float32(), {"[NaN]", "[]"}, 0, "[]"); +} + +struct BottomKWithChunkedArray : public SelectKWithChunkedArray {}; + +TEST_F(BottomKWithChunkedArray, Int8) { + this->Check(int8(), {"[0, 1, 9]", "[3, 7, 2, 4, 10]"}, 3, "[0, 1, 2]"); + this->Check(int8(), {"[]", "[]"}, 0, "[]"); + this->Check(float32(), {"[]"}, 0, "[]"); +} + +TEST_F(BottomKWithChunkedArray, Null) { + this->Check(int8(), {"[null]", "[8, null]"}, 1, "[8]"); + + this->Check(int8(), {"[null]", "[null, null]"}, 0, "[]"); + this->Check(int32(), {"[0, null, 9]", "[3, null, 2, null, 10]"}, 3, "[0, 2, 3]"); + this->Check(int8(), {"[null]", "[]"}, 0, "[]"); +} + +TEST_F(BottomKWithChunkedArray, NaN) { + this->Check(float32(), {"[NaN]", "[8, NaN]"}, 1, "[8]"); + + this->Check(float32(), {"[NaN]", "[NaN, NaN]"}, 0, "[]"); + this->Check(float32(), {"[0, NaN, 9]", "[3, NaN, 2, NaN, 10]"}, 3, "[0, 2, 3]"); + this->Check(float32(), {"[NaN]", "[]"}, 0, "[]"); +} + +template +class TopKWithChunkedArrayForTemporal : public TopKWithChunkedArray { + protected: + std::shared_ptr GetType() { return TypeToDataType(); } +}; +TYPED_TEST_SUITE(TopKWithChunkedArrayForTemporal, TemporalArrowTypes); + +TYPED_TEST(TopKWithChunkedArrayForTemporal, NoNull) { + auto type = this->GetType(); + auto chunked_array = ChunkedArrayFromJSON(type, { + "[0, 1]", + "[3, 2, 1]", + "[5, 0]", + }); + this->Check(type, chunked_array, 3, "[5, 3, 2]"); +} + +template +class BottomKWithChunkedArrayForTemporal : public TopKWithChunkedArray { + protected: + std::shared_ptr GetType() { return TypeToDataType(); } +}; +TYPED_TEST_SUITE(BottomKWithChunkedArrayForTemporal, TemporalArrowTypes); + +TYPED_TEST(BottomKWithChunkedArrayForTemporal, NoNull) { + auto type = this->GetType(); + auto chunked_array = ChunkedArrayFromJSON(type, { + "[0, 1]", + "[3, 2, 1]", + "[5, 0]", + }); + this->Check(type, chunked_array, 3, "[0, 1, 1]"); +} + +// Tests for decimal types +template +class TopKWithChunkedArrayForDecimal : public TopKWithChunkedArray { + protected: + std::shared_ptr GetType() { return std::make_shared(5, 2); } +}; +TYPED_TEST_SUITE(TopKWithChunkedArrayForDecimal, DecimalArrowTypes); + +TYPED_TEST(TopKWithChunkedArrayForDecimal, Basics) { + auto type = this->GetType(); + auto chunked_array = ChunkedArrayFromJSON( + type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"}); + this->Check(type, chunked_array, 3, R"(["456.78", "123.45", "-123.45"])"); +} + +using SortIndicesableTypes = + ::testing::Types; + +template +void ValidateSelectK(const ArrayType& array) { + ValidateOutput(array); + SelectKComparator compare; + for (int i = 1; i < array.length(); i++) { + const auto lval = GetLogicalValue(array, i - 1); + const auto rval = GetLogicalValue(array, i); + ASSERT_TRUE(compare(lval, rval)); + } +} +// Base class for testing against random chunked array. +template +struct SelectKWithChunkedArrayRandomBase : public ::testing::Test { + void TestChunkedArraySelectK(int length) { + using ArrayType = typename TypeTraits::ArrayType; + // We can use INSTANTIATE_TEST_SUITE_P() instead of using fors in a test. + for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) { + for (auto num_chunks : {1, 2, 5, 10, 40}) { + std::vector> arrays; + for (int i = 0; i < num_chunks; ++i) { + auto array = this->GenerateArray(length / num_chunks, null_probability); + arrays.push_back(array); + } + ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make(arrays)); + ASSERT_OK_AND_ASSIGN(auto top_k, SelectK(*chunked_array, 5)); + // Concatenates chunks to use existing ValidateSorted() for array. + ValidateSelectK(*checked_pointer_cast(top_k)); + } + } + } + + void SetUp() override { rand_ = new Random(0x5487655); } + + void TearDown() override { delete rand_; } + + protected: + std::shared_ptr GenerateArray(int length, double null_probability) { + return rand_->Generate(length, null_probability); + } + + private: + Random* rand_; +}; + +// Long array with big value range +template +class TestTopKChunkedArrayRandom + : public SelectKWithChunkedArrayRandomBase {}; + +TYPED_TEST_SUITE(TestTopKChunkedArrayRandom, SortIndicesableTypes); + +TYPED_TEST(TestTopKChunkedArrayRandom, TopK) { this->TestChunkedArraySelectK(1000); } + +template +class TestBottomKChunkedArrayRandom + : public SelectKWithChunkedArrayRandomBase {}; + +TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SortIndicesableTypes); + +TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { + this->TestChunkedArraySelectK(1000); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index e2bad24ec07..289b8d52c91 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -17,8 +17,10 @@ #include #include +#include #include #include +#include #include #include @@ -26,6 +28,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" +#include "arrow/pretty_print.h" #include "arrow/table.h" #include "arrow/type_traits.h" #include "arrow/util/bit_block_counter.h" @@ -368,183 +371,233 @@ using SelectKOptionsState = internal::OptionsWrapper; const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); -template -struct ArraySelectNth { - using ArrayType = typename TypeTraits::ArrayType; +const FunctionDoc top_k_doc( + "Return the indices that would partition an array array, record batch or table\n" + "around a pivot", + ("@TODO"), {"input", "k"}, "PartitionNthOptions"); - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - using GetView = GetViewType; +const FunctionDoc bottom_k_doc( + "Return the indices that would partition an array array, record batch or table\n" + "around a pivot", + ("@TODO"), {"input", "k"}, "PartitionNthOptions"); + +Result> MakeMutableArray(std::shared_ptr out_type, + int64_t length, + MemoryPool* memory_pool) { + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast(out_type)->bit_width()); + std::vector> buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool)); + auto out = std::make_shared(out_type, length, buffers, 0); + return out; +} - if (ctx->state() == nullptr) { - return Status::Invalid("NthToIndices requires PartitionNthOptions"); - } +class ArraySelecter : public TypeVisitor { + public: + ArraySelecter(ExecContext* ctx, const Array& array, int64_t k, const SortOrder order, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + array_(array), + k_(k), + physical_type_(GetPhysicalType(array.type())), + order_(order), + output_(output) {} - ArrayType arr(batch[0].array()); + Status Run() { return physical_type_->Accept(this); } - int64_t pivot = SelectKOptionsState::Get(ctx).k; - SortOrder order = SelectKOptionsState::Get(ctx).order; - std::string keep = SelectKOptionsState::Get(ctx).keep; - if (pivot > arr.length()) { - return Status::IndexError("NthToIndices index out of bound"); +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template + Status SelectKthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + + ArrayType arr(array_.data()); + std::vector indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + if (k_ > arr.length()) { + k_ = arr.length(); } - ArrayData* out_arr = out->mutable_array(); - uint64_t* out_begin = out_arr->GetMutableValues(1); - uint64_t* out_end = out_begin + arr.length(); - std::iota(out_begin, out_end, 0); - if (pivot == arr.length()) { - return Status::OK(); + auto end_iter = PartitionNulls(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + k_; + if (kth_begin > end_iter) { + kth_begin = end_iter; } - auto nulls_begin = - PartitionNulls(out_begin, out_end, arr, 0); - auto nth_begin = out_begin + pivot; - if (nth_begin < nulls_begin) { - std::function cmp; - if (order == SortOrder::Ascending) { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return lval < rval; - }; + std::function cmp; + if (order_ == SortOrder::Ascending) { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return lval < rval; + }; + } else { + cmp = [&arr](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return rval < lval; + }; + } + arrow::internal::Heap heap(cmp); + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && heap.Size() > 0; ++iter) { + uint64_t x_index = *iter; + const auto lval = GetView::LogicalValue(arr.GetView(x_index)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.Top())); + if (order_ == SortOrder::Ascending) { + if (lval < rval) { + heap.ReplaceTop(x_index); + } } else { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return rval < lval; - }; - } - arrow::internal::Heap heap(cmp); - - for (uint64_t* iter = out_begin; iter != nth_begin; ++iter) { - heap.Push(*iter); - } - for (uint64_t* iter = out_begin; iter != nulls_begin; ++iter) { - uint64_t x_index = *iter; - const auto lval = GetView::LogicalValue(arr.GetView(x_index)); - const auto rval = GetView::LogicalValue(arr.GetView(heap.Top())); - if (order == SortOrder::Ascending) { - if (keep == "first") { - if (lval < rval) { - heap.ReplaceTop(x_index); - } - } - } else { - if (rval < lval) { - heap.ReplaceTop(x_index); - } + if (rval < lval) { + heap.ReplaceTop(x_index); } } - std::copy(heap.Data(), heap.Data() + pivot, out_begin); } + + int64_t out_size = static_cast(heap.Size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableArray(uint64(), out_size, ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.Size() > 0) { + *out_cbegin = heap.Top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(array_, Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); return Status::OK(); } -}; -const FunctionDoc top_k_doc( - "Return the indices that would partition an array array, record batch or table\n" - "around a pivot", - ("@TODO"), {"input"}, "PartitionNthOptions"); - -const FunctionDoc bottom_k_doc( - "Return the indices that would partition an array array, record batch or table\n" - "around a pivot", - ("@TODO"), {"input"}, "PartitionNthOptions"); + ExecContext* ctx_; + const Array& array_; + int64_t k_; + const std::shared_ptr physical_type_; + SortOrder order_; + Datum* output_; +}; class ChunkedArraySelecter : public TypeVisitor { public: - ChunkedArraySelecter(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, - int64_t pivot, const ChunkedArray& chunked_array, - const SortOrder order) + ChunkedArraySelecter(ExecContext* ctx, int64_t k, const ChunkedArray& chunked_array, + const SortOrder order, Datum* output) : TypeVisitor(), - indices_begin_(indices_begin), - indices_end_(indices_end), - pivot_(pivot), + k_(k), chunked_array_(chunked_array), physical_type_(GetPhysicalType(chunked_array.type())), physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), order_(order), - ctx_(ctx) {} + ctx_(ctx), + output_(output) {} Status Run() { return physical_type_->Accept(this); } #define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectNthInternal(); } + Status Visit(const TYPE& type) { return SelectKthInternal(); } VISIT_PHYSICAL_TYPES(VISIT) #undef VISIT template - Status SelectNthInternal() { + Status SelectKthInternal() { using GetView = GetViewType; + using T = typename GetView::T; using ArrayType = typename TypeTraits::ArrayType; const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { return Status::OK(); } + if (k_ > chunked_array_.length()) { + k_ = chunked_array_.length(); + } + arrow::internal::Heap> heap; for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; ArrayType arr(chunk->data()); + std::vector indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); - auto nulls_begin = PartitionNulls( - indices_begin_, indices_end_, arr, 0); + auto end_iter = PartitionNulls( + indices_begin, indices_end, arr, 0); + auto kth_begin = indices_begin + k_; - auto nth_begin = indices_begin_ + pivot_; - if (nth_begin < nulls_begin) { - std::function cmp; + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { + const T xval = GetView::LogicalValue(arr.GetView(*iter)); + heap.Push(xval); + } + for (; iter != end_iter && heap.Size() > 0; ++iter) { + uint64_t x_index = *iter; + const T xval = GetView::LogicalValue(arr.GetView(x_index)); + const T& top_value = heap.Top(); if (order_ == SortOrder::Ascending) { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return lval < rval; - }; + if (xval < top_value) { + heap.ReplaceTop(xval); + } } else { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return rval < lval; - }; - } - arrow::internal::Heap heap(cmp); - - for (uint64_t* iter = indices_begin_; iter != nth_begin; ++iter) { - heap.Push(*iter); - } - for (uint64_t* iter = indices_begin_; iter != nulls_begin; ++iter) { - auto x = *iter; - if (x < heap.Top()) { - heap.ReplaceTop(x); + if (top_value < xval) { + heap.ReplaceTop(xval); } } - std::copy(heap.Data(), heap.Data() + pivot_, indices_begin_); } } + int64_t out_size = static_cast(heap.Size()); + ARROW_ASSIGN_OR_RAISE( + auto out_array, + MakeMutableArray(chunked_array_.type(), out_size, ctx_->memory_pool())); + auto* out_cbegin = out_array->GetMutableValues(1) + out_size - 1; + while (heap.Size() > 0) { + *out_cbegin = heap.Top(); + heap.Pop(); + --out_cbegin; + } + *output_ = Datum(out_array); return Status::OK(); } - uint64_t* indices_begin_; - uint64_t* indices_end_; - int64_t pivot_; + int64_t k_; const ChunkedArray& chunked_array_; const std::shared_ptr physical_type_; const ArrayVector physical_chunks_; const SortOrder order_; ExecContext* ctx_; + Datum* output_; }; template -class SelectNthMetaFunction { +class SelectKthMetaFunction { public: Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { - const SelectKOptions& partition_options = - static_cast(*options); + const SelectKOptions& select_k_options = static_cast(*options); switch (args[0].kind()) { case Datum::ARRAY: - return SelectNth(*args[0].make_array(), partition_options, ctx); + return SelectKth(*args[0].make_array(), select_k_options, ctx); break; case Datum::CHUNKED_ARRAY: - return SelectNth(*args[0].chunked_array(), partition_options, ctx); + return SelectKth(*args[0].chunked_array(), select_k_options, ctx); break; default: break; @@ -556,43 +609,32 @@ class SelectNthMetaFunction { } private: - Result SelectNth(const Array& values, const SelectKOptions& options, + Result SelectKth(const Array& array, const SelectKOptions& options, ExecContext* ctx) const { - return CallFunction("partition_nth_indices_doc", {values}, &options, ctx); + Datum output; + ArraySelecter selecter(ctx, array, options.k, sort_order, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; } - Result SelectNth(const ChunkedArray& chunked_array, + Result SelectKth(const ChunkedArray& chunked_array, const SelectKOptions& options, ExecContext* ctx) const { - int64_t pivot = options.k; - SortOrder order = sort_order; - auto out_type = uint64(); - auto length = chunked_array.length(); - auto buffer_size = BitUtil::BytesForBits( - length * std::static_pointer_cast(out_type)->bit_width()); - std::vector> buffers(2); - ARROW_ASSIGN_OR_RAISE(buffers[1], - AllocateResizableBuffer(buffer_size, ctx->memory_pool())); - auto out = std::make_shared(out_type, length, buffers, 0); - auto out_begin = out->GetMutableValues(1); - auto out_end = out_begin + length; - std::iota(out_begin, out_end, 0); - - ChunkedArraySelecter partitioner(ctx, out_begin, out_end, pivot, chunked_array, - order); + Datum output; + ChunkedArraySelecter partitioner(ctx, options.k, chunked_array, sort_order, &output); ARROW_RETURN_NOT_OK(partitioner.Run()); - return Datum(out); + return output; } }; class TopKMetaFunction : public MetaFunction { public: TopKMetaFunction() - : MetaFunction("top_k", Arity::Unary(), &top_k_doc, &kDefaultTopKOptions) {} + : MetaFunction("top_k", Arity::Binary(), &top_k_doc, &kDefaultTopKOptions) {} Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override { - SelectNthMetaFunction impl; + SelectKthMetaFunction impl; return impl.ExecuteImpl(args, options, ctx); } }; @@ -600,11 +642,11 @@ class TopKMetaFunction : public MetaFunction { class BottomKMetaFunction : public MetaFunction { public: BottomKMetaFunction() - : MetaFunction("bottom_k", Arity::Unary(), &top_k_doc, &kDefaultBottomKOptions) {} + : MetaFunction("bottom_k", Arity::Binary(), &top_k_doc, &kDefaultBottomKOptions) {} Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override { - SelectNthMetaFunction impl; + SelectKthMetaFunction impl; return impl.ExecuteImpl(args, options, ctx); } }; @@ -2055,11 +2097,6 @@ const FunctionDoc partition_nth_indices_doc( "The pivot index `N` must be given in PartitionNthOptions."), {"array"}, "PartitionNthOptions"); -const FunctionDoc array_top_k_doc("Return the indices that would sort an array", - ("@TODO."), {"array"}, "ArraySortOptions"); - -const FunctionDoc array_bottom_k_doc("Return the indices that would sort an array", - ("@TODO"), {"array"}, "ArraySortOptions"); } // namespace void RegisterVectorSort(FunctionRegistry* registry) { @@ -2085,19 +2122,9 @@ void RegisterVectorSort(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(part_indices))); // top_k - auto part_topk = std::make_shared("array_top_k", Arity::Unary(), - &array_bottom_k_doc); - base.init = SelectKOptionsState::Init; - AddSortingKernels(base, part_topk.get()); - DCHECK_OK(registry->AddFunction(std::move(part_topk))); DCHECK_OK(registry->AddFunction(std::make_shared())); // bottom_k - auto part_bottomk = - std::make_shared("array_bottom_k", Arity::Unary(), &bottom_k_doc); - base.init = SelectKOptionsState::Init; - AddSortingKernels(base, part_bottomk.get()); - DCHECK_OK(registry->AddFunction(std::move(part_bottomk))); DCHECK_OK(registry->AddFunction(std::make_shared())); } diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 13bab35a66d..5b8127f7c20 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -158,8 +158,7 @@ using NumericArrowTypes = using RealArrowTypes = ::testing::Types; -using IntegralArrowTypes = ::testing::Types; +using IntegralArrowTypes = ::testing::Types; using TemporalArrowTypes = ::testing::Types; diff --git a/cpp/src/arrow/util/heap.h b/cpp/src/arrow/util/heap.h index afcb62ea3f6..e1f16682eca 100644 --- a/cpp/src/arrow/util/heap.h +++ b/cpp/src/arrow/util/heap.h @@ -38,7 +38,9 @@ class ARROW_EXPORT Heap { T* Data() { return values_.data(); } - const T& Top() const { return values_.front(); } + // const T& Top() const { return values_.front(); } + + T Top() const { return values_.front(); } bool Empty() const { return values_.empty(); } @@ -60,7 +62,9 @@ class ARROW_EXPORT Heap { std::push_heap(values_.begin(), values_.end(), comp_); } - protected: + void SetComparator(const Compare& comp) { comp_ = comp; } + + public: ARROW_DISALLOW_COPY_AND_ASSIGN(Heap); std::vector values_; From ed28db2290519116911df3a93babb8821c6ce4b8 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 27 Aug 2021 10:16:35 -0500 Subject: [PATCH 04/19] next version with HeapItem and Take --- cpp/src/arrow/compute/api_vector.cc | 12 +- cpp/src/arrow/compute/api_vector.h | 9 +- .../arrow/compute/kernels/select_k_test.cc | 103 ++++++--- cpp/src/arrow/compute/kernels/vector_sort.cc | 210 ++++++++++++++++-- 4 files changed, 270 insertions(+), 64 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 3b3a92a24a6..43acfc1a54b 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -192,13 +192,13 @@ Result> TopK(const ChunkedArray& values, int64_t k, return result.make_array(); } -Result> TopK(const Datum& datum, int64_t k, SelectKOptions options, - ExecContext* ctx) { +Result TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx) { options.k = k; options.order = SortOrder::Ascending; ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("top_k", {datum, Datum(k)}, &options, ctx)); - return result.make_array(); + return result; } Result> BottomK(const Array& values, int64_t k, @@ -217,13 +217,13 @@ Result> BottomK(const ChunkedArray& values, int64_t k, return result.make_array(); } -Result> BottomK(const Datum& datum, int64_t k, - SelectKOptions options, ExecContext* ctx) { +Result BottomK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx) { options.k = k; options.order = SortOrder::Ascending; ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("bottom_k", {datum, Datum(k)}, &options, ctx)); - return result.make_array(); + return result; } Result ReplaceWithMask(const Datum& values, const Datum& mask, diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 62381e14fe1..942df8ebde6 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -285,8 +285,8 @@ Result> TopK(const ChunkedArray& values, int64_t k, /// @TODO ARROW_EXPORT -Result> TopK(const Datum& datum, int64_t k, SelectKOptions options, - ExecContext* ctx = NULLPTR); +Result TopK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx = NULLPTR); /// @TODO ARROW_EXPORT @@ -302,9 +302,8 @@ Result> BottomK(const ChunkedArray& values, int64_t k, /// @TODO ARROW_EXPORT -Result> BottomK(const Datum& datum, int64_t k, - SelectKOptions options, - ExecContext* ctx = NULLPTR); +Result BottomK(const Datum& datum, int64_t k, SelectKOptions options, + ExecContext* ctx = NULLPTR); /// \brief Returns the indices that would sort an array in the /// specified order. diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 3fd1392451c..cd5024fa96c 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include #include @@ -363,8 +364,8 @@ TYPED_TEST(TestSelectKRandom, RandomValues) { } template -struct SelectKWithChunkedArray : public ::testing::Test { - SelectKWithChunkedArray() +struct TestSelectKWithChunkedArray : public ::testing::Test { + TestSelectKWithChunkedArray() : sizes_({0, 1, 2, 4, 16, 31, 1234}), null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} @@ -390,21 +391,23 @@ struct SelectKWithChunkedArray : public ::testing::Test { const std::vector& values, int64_t k, std::shared_ptr* out) { ARROW_ASSIGN_OR_RAISE(*out, SelectK(*(ChunkedArrayFromJSON(type, values)), k)); + PrettyPrint(**out, {}, &std::cerr); return Status::OK(); } std::vector sizes_; std::vector null_probabilities_; }; -struct TopKWithChunkedArray : public SelectKWithChunkedArray {}; +struct TestTopKWithChunkedArray + : public TestSelectKWithChunkedArray {}; -TEST_F(TopKWithChunkedArray, Int8) { - this->Check(int8(), {"[0, 1, 9]", "[3, 7, 2, 4, 10]"}, 3, "[10, 9, 7]"); - this->Check(int8(), {"[]", "[]"}, 0, "[]"); - this->Check(float32(), {"[]"}, 0, "[]"); +TEST_F(TestTopKWithChunkedArray, Int32) { + this->Check(int32(), {"[0, 1, 9]", "[3, 7, 2, 4, 10]"}, 3, "[10, 9, 7]"); + this->Check(int32(), {"[]", "[]"}, 0, "[]"); + this->Check(int32(), {"[]"}, 0, "[]"); } -TEST_F(TopKWithChunkedArray, Null) { +TEST_F(TestTopKWithChunkedArray, Null) { this->Check(int8(), {"[null]", "[8, null]"}, 1, "[8]"); this->Check(int8(), {"[null]", "[null, null]"}, 0, "[]"); @@ -412,7 +415,7 @@ TEST_F(TopKWithChunkedArray, Null) { this->Check(int8(), {"[null]", "[]"}, 0, "[]"); } -TEST_F(TopKWithChunkedArray, NaN) { +TEST_F(TestTopKWithChunkedArray, NaN) { this->Check(float32(), {"[NaN]", "[8, NaN]"}, 1, "[8]"); this->Check(float32(), {"[NaN]", "[NaN, NaN]"}, 0, "[]"); @@ -420,15 +423,16 @@ TEST_F(TopKWithChunkedArray, NaN) { this->Check(float32(), {"[NaN]", "[]"}, 0, "[]"); } -struct BottomKWithChunkedArray : public SelectKWithChunkedArray {}; +struct TestBottomKWithChunkedArray + : public TestSelectKWithChunkedArray {}; -TEST_F(BottomKWithChunkedArray, Int8) { +TEST_F(TestBottomKWithChunkedArray, Int8) { this->Check(int8(), {"[0, 1, 9]", "[3, 7, 2, 4, 10]"}, 3, "[0, 1, 2]"); this->Check(int8(), {"[]", "[]"}, 0, "[]"); this->Check(float32(), {"[]"}, 0, "[]"); } -TEST_F(BottomKWithChunkedArray, Null) { +TEST_F(TestBottomKWithChunkedArray, Null) { this->Check(int8(), {"[null]", "[8, null]"}, 1, "[8]"); this->Check(int8(), {"[null]", "[null, null]"}, 0, "[]"); @@ -436,7 +440,7 @@ TEST_F(BottomKWithChunkedArray, Null) { this->Check(int8(), {"[null]", "[]"}, 0, "[]"); } -TEST_F(BottomKWithChunkedArray, NaN) { +TEST_F(TestBottomKWithChunkedArray, NaN) { this->Check(float32(), {"[NaN]", "[8, NaN]"}, 1, "[8]"); this->Check(float32(), {"[NaN]", "[NaN, NaN]"}, 0, "[]"); @@ -445,13 +449,13 @@ TEST_F(BottomKWithChunkedArray, NaN) { } template -class TopKWithChunkedArrayForTemporal : public TopKWithChunkedArray { +class TestTopKWithChunkedArrayForTemporal : public TestTopKWithChunkedArray { protected: std::shared_ptr GetType() { return TypeToDataType(); } }; -TYPED_TEST_SUITE(TopKWithChunkedArrayForTemporal, TemporalArrowTypes); +TYPED_TEST_SUITE(TestTopKWithChunkedArrayForTemporal, TemporalArrowTypes); -TYPED_TEST(TopKWithChunkedArrayForTemporal, NoNull) { +TYPED_TEST(TestTopKWithChunkedArrayForTemporal, NoNull) { auto type = this->GetType(); auto chunked_array = ChunkedArrayFromJSON(type, { "[0, 1]", @@ -462,37 +466,53 @@ TYPED_TEST(TopKWithChunkedArrayForTemporal, NoNull) { } template -class BottomKWithChunkedArrayForTemporal : public TopKWithChunkedArray { +class TestBottomKWithChunkedArrayForTemporal : public TestBottomKWithChunkedArray { protected: std::shared_ptr GetType() { return TypeToDataType(); } }; -TYPED_TEST_SUITE(BottomKWithChunkedArrayForTemporal, TemporalArrowTypes); +TYPED_TEST_SUITE(TestBottomKWithChunkedArrayForTemporal, TemporalArrowTypes); -TYPED_TEST(BottomKWithChunkedArrayForTemporal, NoNull) { +TYPED_TEST(TestBottomKWithChunkedArrayForTemporal, NoNull) { auto type = this->GetType(); auto chunked_array = ChunkedArrayFromJSON(type, { "[0, 1]", "[3, 2, 1]", "[5, 0]", }); - this->Check(type, chunked_array, 3, "[0, 1, 1]"); + this->Check(type, chunked_array, 3, "[0, 0, 1]"); } // Tests for decimal types template -class TopKWithChunkedArrayForDecimal : public TopKWithChunkedArray { +class TestTopKWithChunkedArrayForDecimal : public TestTopKWithChunkedArray { protected: std::shared_ptr GetType() { return std::make_shared(5, 2); } }; -TYPED_TEST_SUITE(TopKWithChunkedArrayForDecimal, DecimalArrowTypes); +TYPED_TEST_SUITE(TestTopKWithChunkedArrayForDecimal, DecimalArrowTypes); -TYPED_TEST(TopKWithChunkedArrayForDecimal, Basics) { +TYPED_TEST(TestTopKWithChunkedArrayForDecimal, Basics) { auto type = this->GetType(); auto chunked_array = ChunkedArrayFromJSON( - type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"}); + type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", + null])"}); this->Check(type, chunked_array, 3, R"(["456.78", "123.45", "-123.45"])"); } +template +class TestBottomKWithChunkedArrayForDecimal : public TestBottomKWithChunkedArray { + protected: + std::shared_ptr GetType() { return std::make_shared(5, 2); } +}; +TYPED_TEST_SUITE(TestBottomKWithChunkedArrayForDecimal, DecimalArrowTypes); + +TYPED_TEST(TestBottomKWithChunkedArrayForDecimal, Basics) { + auto type = this->GetType(); + auto chunked_array = ChunkedArrayFromJSON( + type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", + null])"}); + this->Check(type, chunked_array, 3, R"(["-456.78", "-123.45", "123.45"])"); +} + using SortIndicesableTypes = ::testing::Types -struct SelectKWithChunkedArrayRandomBase : public ::testing::Test { - void TestChunkedArraySelectK(int length) { +struct TestSelectKWithChunkedArrayRandomBase : public ::testing::Test { + void TestSelectK(int length) { using ArrayType = typename TypeTraits::ArrayType; // We can use INSTANTIATE_TEST_SUITE_P() instead of using fors in a test. for (auto null_probability : {0.0, 0.1, 0.5, 0.9, 1.0}) { @@ -545,20 +565,41 @@ struct SelectKWithChunkedArrayRandomBase : public ::testing::Test { // Long array with big value range template class TestTopKChunkedArrayRandom - : public SelectKWithChunkedArrayRandomBase {}; + : public TestSelectKWithChunkedArrayRandomBase {}; TYPED_TEST_SUITE(TestTopKChunkedArrayRandom, SortIndicesableTypes); -TYPED_TEST(TestTopKChunkedArrayRandom, TopK) { this->TestChunkedArraySelectK(1000); } +TYPED_TEST(TestTopKChunkedArrayRandom, TopK) { this->TestSelectK(1000); } template class TestBottomKChunkedArrayRandom - : public SelectKWithChunkedArrayRandomBase {}; + : public TestSelectKWithChunkedArrayRandomBase {}; TYPED_TEST_SUITE(TestBottomKChunkedArrayRandom, SortIndicesableTypes); -TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { - this->TestChunkedArraySelectK(1000); +TYPED_TEST(TestBottomKChunkedArrayRandom, BottomK) { this->TestSelectK(1000); } + +// Test basic cases for record batch. +class TestTopKWithRecordBatch : public ::testing::Test {}; + +TEST_F(TestTopKWithRecordBatch, NoNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + SortOptions options( + {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); + + auto batch = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 4}, + {"a": 0, "b": 6}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3} + ])"); + // AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 289b8d52c91..989aa53c090 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -24,6 +24,7 @@ #include #include +#include "arrow/array/concatenate.h" #include "arrow/array/data.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/kernels/common.h" @@ -381,9 +382,8 @@ const FunctionDoc bottom_k_doc( "around a pivot", ("@TODO"), {"input", "k"}, "PartitionNthOptions"); -Result> MakeMutableArray(std::shared_ptr out_type, - int64_t length, - MemoryPool* memory_pool) { +Result> MakeMutableArrayForFixedSizedType( + std::shared_ptr out_type, int64_t length, MemoryPool* memory_pool) { auto buffer_size = BitUtil::BytesForBits( length * std::static_pointer_cast(out_type)->bit_width()); std::vector> buffers(2); @@ -468,8 +468,9 @@ class ArraySelecter : public TypeVisitor { } int64_t out_size = static_cast(heap.Size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableArray(uint64(), out_size, ctx_->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; while (heap.Size() > 0) { @@ -490,6 +491,13 @@ class ArraySelecter : public TypeVisitor { Datum* output_; }; +template +struct TypedHeapItem { + uint64_t index; + uint64_t offset; + ArrayType* array; +}; + class ChunkedArraySelecter : public TypeVisitor { public: ChunkedArraySelecter(ExecContext* ctx, int64_t k, const ChunkedArray& chunked_array, @@ -515,8 +523,8 @@ class ChunkedArraySelecter : public TypeVisitor { template Status SelectKthInternal() { using GetView = GetViewType; - using T = typename GetView::T; using ArrayType = typename TypeTraits::ArrayType; + using HeapItem = TypedHeapItem; const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { @@ -525,11 +533,28 @@ class ChunkedArraySelecter : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } - arrow::internal::Heap> heap; - + std::function cmp; + if (order_ == SortOrder::Ascending) { + cmp = [](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return lval < rval; + }; + } else { + cmp = [](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return rval < lval; + }; + } + arrow::internal::Heap heap(cmp); + std::vector> chunks_holder; + uint64_t offset = 0; for (const auto& chunk : physical_chunks_) { if (chunk->length() == 0) continue; - ArrayType arr(chunk->data()); + chunks_holder.emplace_back(std::make_shared(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + std::vector indices(arr.length()); uint64_t* indices_begin = indices.data(); @@ -545,35 +570,44 @@ class ChunkedArraySelecter : public TypeVisitor { } uint64_t* iter = indices_begin; for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { - const T xval = GetView::LogicalValue(arr.GetView(*iter)); - heap.Push(xval); + heap.Push(HeapItem{*iter, offset, &arr}); } for (; iter != end_iter && heap.Size() > 0; ++iter) { uint64_t x_index = *iter; - const T xval = GetView::LogicalValue(arr.GetView(x_index)); - const T& top_value = heap.Top(); + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.Top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); if (order_ == SortOrder::Ascending) { if (xval < top_value) { - heap.ReplaceTop(xval); + heap.ReplaceTop(HeapItem{x_index, offset, &arr}); } } else { if (top_value < xval) { - heap.ReplaceTop(xval); + heap.ReplaceTop(HeapItem{x_index, offset, &arr}); } } } + offset += chunk->length(); } int64_t out_size = static_cast(heap.Size()); ARROW_ASSIGN_OR_RAISE( - auto out_array, - MakeMutableArray(chunked_array_.type(), out_size, ctx_->memory_pool())); - auto* out_cbegin = out_array->GetMutableValues(1) + out_size - 1; + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; while (heap.Size() > 0) { - *out_cbegin = heap.Top(); + auto top_item = heap.Top(); + *out_cbegin = top_item.index + top_item.offset; heap.Pop(); --out_cbegin; } - *output_ = Datum(out_array); + ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, + Take(Datum(chunked_array_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + ARROW_ASSIGN_OR_RAISE( + auto select_k, + Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); + *output_ = Datum(select_k); return Status::OK(); } @@ -585,6 +619,128 @@ class ChunkedArraySelecter : public TypeVisitor { ExecContext* ctx_; Datum* output_; }; + +class RecordBatchSelecter : public TypeVisitor { + public: + RecordBatchSelecter(ExecContext* ctx, int64_t k, const RecordBatch& record_batch, + const SortOrder order, Datum* output) + : TypeVisitor(), + k_(k), + record_batch_(record_batch), + order_(order), + ctx_(ctx), + output_(output) {} + + Status Run() { + // return physical_type_->Accept(this); + return Status::OK(); + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template + Status SelectKthInternal() { + // using GetView = GetViewType; + // using ArrayType = typename TypeTraits::ArrayType; + // using HeapItem = TypedHeapItem; + + // const auto num_chunks = chunked_array_.num_chunks(); + // if (num_chunks == 0) { + // return Status::OK(); + // } + // if (k_ > chunked_array_.length()) { + // k_ = chunked_array_.length(); + // } + // std::function cmp; + // if (order_ == SortOrder::Ascending) { + // cmp = [](const HeapItem& left, const HeapItem& right) -> bool { + // const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + // const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + // return lval < rval; + // }; + // } else { + // cmp = [](const HeapItem& left, const HeapItem& right) -> bool { + // const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + // const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + // return rval < lval; + // }; + // } + // arrow::internal::Heap heap(cmp); + // std::vector> chunks_holder; + // uint64_t offset = 0; + // for (const auto& chunk : physical_chunks_) { + // if (chunk->length() == 0) continue; + // chunks_holder.emplace_back(std::make_shared(chunk->data())); + // ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + // std::vector indices(arr.length()); + + // uint64_t* indices_begin = indices.data(); + // uint64_t* indices_end = indices_begin + indices.size(); + // std::iota(indices_begin, indices_end, 0); + + // auto end_iter = PartitionNulls( + // indices_begin, indices_end, arr, 0); + // auto kth_begin = indices_begin + k_; + + // if (kth_begin > end_iter) { + // kth_begin = end_iter; + // } + // uint64_t* iter = indices_begin; + // for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { + // heap.Push(HeapItem{*iter, offset, &arr}); + // } + // for (; iter != end_iter && heap.Size() > 0; ++iter) { + // uint64_t x_index = *iter; + // const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + // auto top_item = heap.Top(); + // const auto& top_value = + // GetView::LogicalValue(top_item.array->GetView(top_item.index)); + // if (order_ == SortOrder::Ascending) { + // if (xval < top_value) { + // heap.ReplaceTop(HeapItem{x_index, offset, &arr}); + // } + // } else { + // if (top_value < xval) { + // heap.ReplaceTop(HeapItem{x_index, offset, &arr}); + // } + // } + // } + // offset += chunk->length(); + // } + // int64_t out_size = static_cast(heap.Size()); + // ARROW_ASSIGN_OR_RAISE( + // auto take_indices, + // MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + // auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + // while (heap.Size() > 0) { + // auto top_item = heap.Top(); + // *out_cbegin = top_item.index + top_item.offset; + // heap.Pop(); + // --out_cbegin; + // } + // ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, + // Take(Datum(chunked_array_), Datum(std::move(take_indices)), + // TakeOptions::NoBoundsCheck(), ctx_)); + // ARROW_ASSIGN_OR_RAISE( + // auto select_k, + // Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); + // *output_ = Datum(select_k); + return Status::OK(); + } + + int64_t k_; + const RecordBatch& record_batch_; + const SortOrder order_; + ExecContext* ctx_; + Datum* output_; +}; + template class SelectKthMetaFunction { public: @@ -599,6 +755,9 @@ class SelectKthMetaFunction { case Datum::CHUNKED_ARRAY: return SelectKth(*args[0].chunked_array(), select_k_options, ctx); break; + case Datum::RECORD_BATCH: + return SelectKth(*args[0].record_batch(), select_k_options, ctx); + break; default: break; } @@ -620,8 +779,15 @@ class SelectKthMetaFunction { Result SelectKth(const ChunkedArray& chunked_array, const SelectKOptions& options, ExecContext* ctx) const { Datum output; - ChunkedArraySelecter partitioner(ctx, options.k, chunked_array, sort_order, &output); - ARROW_RETURN_NOT_OK(partitioner.Run()); + ChunkedArraySelecter selecter(ctx, options.k, chunked_array, sort_order, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); + return output; + } + Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, + ExecContext* ctx) const { + Datum output; + RecordBatchSelecter selecter(ctx, options.k, record_batch, sort_order, &output); + ARROW_RETURN_NOT_OK(selecter.Run()); return output; } }; From 613495216fc616fb466a04b6dcf217504f0b189f Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 30 Aug 2021 21:53:48 -0500 Subject: [PATCH 05/19] TopK for RecordBatch and Table minor fix --- cpp/src/arrow/compute/api_vector.cc | 6 +- cpp/src/arrow/compute/api_vector.h | 4 +- .../arrow/compute/kernels/select_k_test.cc | 59 +- cpp/src/arrow/compute/kernels/vector_sort.cc | 3330 +++++++++-------- 4 files changed, 1850 insertions(+), 1549 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 43acfc1a54b..81477de3ecf 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -178,7 +178,7 @@ Result> NthToIndices(const Array& values, int64_t n, Result> TopK(const Array& values, int64_t k, const std::string& keep, ExecContext* ctx) { - SelectKOptions options(k, {}, keep, SortOrder::Ascending); + SelectKOptions options(k, {}, keep, SortOrder::Descending); ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("top_k", {values, Datum(k)}, &options, ctx)); return result.make_array(); @@ -186,7 +186,7 @@ Result> TopK(const Array& values, int64_t k, Result> TopK(const ChunkedArray& values, int64_t k, const std::string& keep, ExecContext* ctx) { - SelectKOptions options(k, {}, keep, SortOrder::Ascending); + SelectKOptions options(k, {}, keep, SortOrder::Descending); ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("top_k", {Datum(values), Datum(k)}, &options, ctx)); return result.make_array(); @@ -195,7 +195,7 @@ Result> TopK(const ChunkedArray& values, int64_t k, Result TopK(const Datum& datum, int64_t k, SelectKOptions options, ExecContext* ctx) { options.k = k; - options.order = SortOrder::Ascending; + options.order = SortOrder::Descending; ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("top_k", {datum, Datum(k)}, &options, ctx)); return result; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 942df8ebde6..aace18147fa 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -128,10 +128,10 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { SortOrder order = SortOrder::Ascending); constexpr static char const kTypeName[] = "SelectKOptions"; static SelectKOptions TopKDefault() { - return SelectKOptions{0, {}, "first", SortOrder::Ascending}; + return SelectKOptions{0, {}, "first", SortOrder::Descending}; } static SelectKOptions BottomKDefault() { - return SelectKOptions{0, {}, "first", SortOrder::Descending}; + return SelectKOptions{0, {}, "first", SortOrder::Ascending}; } int64_t k; std::vector keys; diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index cd5024fa96c..f910470e354 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -587,20 +587,67 @@ TEST_F(TestTopKWithRecordBatch, NoNull) { {field("a", uint8())}, {field("b", uint32())}, }); - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); auto batch = RecordBatchFromJSON(schema, R"([{"a": 3, "b": 5}, - {"a": 1, "b": 3}, + {"a": 30, "b": 3}, {"a": 3, "b": 4}, {"a": 0, "b": 6}, - {"a": 2, "b": 5}, - {"a": 1, "b": 5}, - {"a": 1, "b": 3} + {"a": 20, "b": 5}, + {"a": 10, "b": 5}, + {"a": 10, "b": 3} ])"); + PrettyPrint(*batch, {}, &std::cerr); + auto options = SelectKOptions::TopKDefault(); + options.keys = {"a"}; + ASSERT_OK_AND_ASSIGN(auto top_k, TopK(Datum(batch), 3, options)); + PrettyPrint(*top_k.record_batch(), {}, &std::cerr); + // AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } +// Test basic cases for table. +class TestTopKWithTable : public ::testing::Test {}; + +TEST_F(TestTopKWithTable, Null) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::shared_ptr table; + + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}); + // AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + + PrettyPrint(*table, {}, &std::cerr); + auto options = SelectKOptions::TopKDefault(); + options.keys = {"a"}; + ASSERT_OK_AND_ASSIGN(auto top_k, TopK(Datum(table), 3, options)); + PrettyPrint(*top_k.table(), {}, &std::cerr); + + // Same data, several chunks + table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5} + ])"}); + // AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]"); + + PrettyPrint(*table, {}, &std::cerr); + options = SelectKOptions::TopKDefault(); + options.keys = {"a", "b"}; + ASSERT_OK_AND_ASSIGN(top_k, TopK(Datum(table), 3, options)); + PrettyPrint(*top_k.table(), {}, &std::cerr); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 989aa53c090..b00b69e3c99 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -103,7 +103,7 @@ struct ResolvedChunk { struct ChunkedArrayResolver { explicit ChunkedArrayResolver(const std::vector& chunks) : num_chunks_(static_cast(chunks.size())), - chunks_(chunks.data()), + chunks_(chunks), offsets_(MakeEndOffsets(chunks)), cached_chunk_(0) {} @@ -161,7 +161,7 @@ struct ChunkedArrayResolver { } int64_t num_chunks_; - const Array* const* chunks_; + const std::vector chunks_; std::vector offsets_; mutable int64_t cached_chunk_; @@ -366,835 +366,1142 @@ struct PartitionNthToIndices { }; // ---------------------------------------------------------------------- -// TopK/BottomK implementations +// Array sorting implementations -using SelectKOptionsState = internal::OptionsWrapper; -const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); -const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); +template +inline void VisitRawValuesInline(const ArrayType& values, + VisitorNotNull&& visitor_not_null, + VisitorNull&& visitor_null) { + const auto data = values.raw_values(); + VisitBitBlocksVoid( + values.null_bitmap(), values.offset(), values.length(), + [&](int64_t i) { visitor_not_null(data[i]); }, [&]() { visitor_null(); }); +} -const FunctionDoc top_k_doc( - "Return the indices that would partition an array array, record batch or table\n" - "around a pivot", - ("@TODO"), {"input", "k"}, "PartitionNthOptions"); +template +inline void VisitRawValuesInline(const BooleanArray& values, + VisitorNotNull&& visitor_not_null, + VisitorNull&& visitor_null) { + if (values.null_count() != 0) { + const uint8_t* data = values.data()->GetValues(1, 0); + VisitBitBlocksVoid( + values.null_bitmap(), values.offset(), values.length(), + [&](int64_t i) { visitor_not_null(BitUtil::GetBit(data, values.offset() + i)); }, + [&]() { visitor_null(); }); + } else { + // Can avoid GetBit() overhead in the no-nulls case + VisitBitBlocksVoid( + values.data()->buffers[1], values.offset(), values.length(), + [&](int64_t i) { visitor_not_null(true); }, [&]() { visitor_not_null(false); }); + } +} -const FunctionDoc bottom_k_doc( - "Return the indices that would partition an array array, record batch or table\n" - "around a pivot", - ("@TODO"), {"input", "k"}, "PartitionNthOptions"); +template +class ArrayCompareSorter { + using ArrayType = typename TypeTraits::ArrayType; + using GetView = GetViewType; -Result> MakeMutableArrayForFixedSizedType( - std::shared_ptr out_type, int64_t length, MemoryPool* memory_pool) { - auto buffer_size = BitUtil::BytesForBits( - length * std::static_pointer_cast(out_type)->bit_width()); - std::vector> buffers(2); - ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool)); - auto out = std::make_shared(out_type, length, buffers, 0); - return out; -} + public: + // Returns where null starts. + // + // `offset` is used when this is called on a chunk of a chunked array + uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, + int64_t offset, const ArraySortOptions& options) { + auto nulls_begin = PartitionNulls( + indices_begin, indices_end, values, offset); + if (options.order == SortOrder::Ascending) { + std::stable_sort( + indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) { + const auto lhs = GetView::LogicalValue(values.GetView(left - offset)); + const auto rhs = GetView::LogicalValue(values.GetView(right - offset)); + return lhs < rhs; + }); + } else { + std::stable_sort( + indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) { + const auto lhs = GetView::LogicalValue(values.GetView(left - offset)); + const auto rhs = GetView::LogicalValue(values.GetView(right - offset)); + // We don't use 'left > right' here to reduce required operator. + // If we use 'right < left' here, '<' is only required. + return rhs < lhs; + }); + } + return nulls_begin; + } +}; + +template +class ArrayCountSorter { + using ArrayType = typename TypeTraits::ArrayType; + using c_type = typename ArrowType::c_type; -class ArraySelecter : public TypeVisitor { public: - ArraySelecter(ExecContext* ctx, const Array& array, int64_t k, const SortOrder order, - Datum* output) - : TypeVisitor(), - ctx_(ctx), - array_(array), - k_(k), - physical_type_(GetPhysicalType(array.type())), - order_(order), - output_(output) {} + ArrayCountSorter() = default; - Status Run() { return physical_type_->Accept(this); } + explicit ArrayCountSorter(c_type min, c_type max) { SetMinMax(min, max); } -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectKthInternal(); } + // Assume: max >= min && (max - min) < 4Gi + void SetMinMax(c_type min, c_type max) { + min_ = min; + value_range_ = static_cast(max - min) + 1; + } - VISIT_PHYSICAL_TYPES(VISIT) + // Returns where null starts. + uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, + int64_t offset, const ArraySortOptions& options) { + // 32bit counter performs much better than 64bit one + if (values.length() < (1LL << 32)) { + return SortInternal(indices_begin, indices_end, values, offset, options); + } else { + return SortInternal(indices_begin, indices_end, values, offset, options); + } + } -#undef VISIT + private: + c_type min_{0}; + uint32_t value_range_{0}; - template - Status SelectKthInternal() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; + // Returns where null starts. + // + // `offset` is used when this is called on a chunk of a chunked array + template + uint64_t* SortInternal(uint64_t* indices_begin, uint64_t* indices_end, + const ArrayType& values, int64_t offset, + const ArraySortOptions& options) { + const uint32_t value_range = value_range_; - ArrayType arr(array_.data()); - std::vector indices(arr.length()); + // first slot reserved for prefix sum + std::vector counts(1 + value_range); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); - if (k_ > arr.length()) { - k_ = arr.length(); - } - auto end_iter = PartitionNulls(indices_begin, - indices_end, arr, 0); - auto kth_begin = indices_begin + k_; - if (kth_begin > end_iter) { - kth_begin = end_iter; - } - std::function cmp; - if (order_ == SortOrder::Ascending) { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return lval < rval; - }; + if (options.order == SortOrder::Ascending) { + VisitRawValuesInline( + values, [&](c_type v) { ++counts[v - min_ + 1]; }, []() {}); + for (uint32_t i = 1; i <= value_range; ++i) { + counts[i] += counts[i - 1]; + } + auto null_position = counts[value_range]; + auto nulls_begin = indices_begin + null_position; + int64_t index = offset; + VisitRawValuesInline( + values, [&](c_type v) { indices_begin[counts[v - min_]++] = index++; }, + [&]() { indices_begin[null_position++] = index++; }); + return nulls_begin; } else { - cmp = [&arr](uint64_t left, uint64_t right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return rval < lval; - }; - } - arrow::internal::Heap heap(cmp); - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { - heap.Push(*iter); - } - for (; iter != end_iter && heap.Size() > 0; ++iter) { - uint64_t x_index = *iter; - const auto lval = GetView::LogicalValue(arr.GetView(x_index)); - const auto rval = GetView::LogicalValue(arr.GetView(heap.Top())); - if (order_ == SortOrder::Ascending) { - if (lval < rval) { - heap.ReplaceTop(x_index); - } - } else { - if (rval < lval) { - heap.ReplaceTop(x_index); - } + VisitRawValuesInline( + values, [&](c_type v) { ++counts[v - min_]; }, []() {}); + for (uint32_t i = value_range; i >= 1; --i) { + counts[i - 1] += counts[i]; } + auto null_position = counts[0]; + auto nulls_begin = indices_begin + null_position; + int64_t index = offset; + VisitRawValuesInline( + values, [&](c_type v) { indices_begin[counts[v - min_ + 1]++] = index++; }, + [&]() { indices_begin[null_position++] = index++; }); + return nulls_begin; } - - int64_t out_size = static_cast(heap.Size()); - ARROW_ASSIGN_OR_RAISE( - auto take_indices, - MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); - - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.Size() > 0) { - *out_cbegin = heap.Top(); - heap.Pop(); - --out_cbegin; - } - ARROW_ASSIGN_OR_RAISE(*output_, Take(array_, Datum(std::move(take_indices)), - TakeOptions::NoBoundsCheck(), ctx_)); - return Status::OK(); } - - ExecContext* ctx_; - const Array& array_; - int64_t k_; - const std::shared_ptr physical_type_; - SortOrder order_; - Datum* output_; }; -template -struct TypedHeapItem { - uint64_t index; - uint64_t offset; - ArrayType* array; -}; +using ::arrow::internal::Bitmap; -class ChunkedArraySelecter : public TypeVisitor { +template <> +class ArrayCountSorter { public: - ChunkedArraySelecter(ExecContext* ctx, int64_t k, const ChunkedArray& chunked_array, - const SortOrder order, Datum* output) - : TypeVisitor(), - k_(k), - chunked_array_(chunked_array), - physical_type_(GetPhysicalType(chunked_array.type())), - physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), - order_(order), - ctx_(ctx), - output_(output) {} - - Status Run() { return physical_type_->Accept(this); } + ArrayCountSorter() = default; -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectKthInternal(); } + // Returns where null starts. + // `offset` is used when this is called on a chunk of a chunked array + uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, + const BooleanArray& values, int64_t offset, + const ArraySortOptions& options) { + std::array counts{0, 0}; - VISIT_PHYSICAL_TYPES(VISIT) - -#undef VISIT + const int64_t nulls = values.null_count(); + const int64_t ones = values.true_count(); + const int64_t zeros = values.length() - ones - nulls; - template - Status SelectKthInternal() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - using HeapItem = TypedHeapItem; + int64_t null_position = values.length() - nulls; + int64_t index = offset; + const auto nulls_begin = indices_begin + null_position; - const auto num_chunks = chunked_array_.num_chunks(); - if (num_chunks == 0) { - return Status::OK(); - } - if (k_ > chunked_array_.length()) { - k_ = chunked_array_.length(); - } - std::function cmp; - if (order_ == SortOrder::Ascending) { - cmp = [](const HeapItem& left, const HeapItem& right) -> bool { - const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - return lval < rval; - }; + if (options.order == SortOrder::Ascending) { + // ones start after zeros + counts[1] = zeros; } else { - cmp = [](const HeapItem& left, const HeapItem& right) -> bool { - const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - return rval < lval; - }; + // zeros start after ones + counts[0] = ones; } - arrow::internal::Heap heap(cmp); - std::vector> chunks_holder; - uint64_t offset = 0; - for (const auto& chunk : physical_chunks_) { - if (chunk->length() == 0) continue; - chunks_holder.emplace_back(std::make_shared(chunk->data())); - ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; - - std::vector indices(arr.length()); + VisitRawValuesInline( + values, [&](bool v) { indices_begin[counts[v]++] = index++; }, + [&]() { indices_begin[null_position++] = index++; }); + return nulls_begin; + } +}; - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); +// Sort integers with counting sort or comparison based sorting algorithm +// - Use O(n) counting sort if values are in a small range +// - Use O(nlogn) std::stable_sort otherwise +template +class ArrayCountOrCompareSorter { + using ArrayType = typename TypeTraits::ArrayType; + using c_type = typename ArrowType::c_type; - auto end_iter = PartitionNulls( - indices_begin, indices_end, arr, 0); - auto kth_begin = indices_begin + k_; + public: + // Returns where null starts. + // + // `offset` is used when this is called on a chunk of a chunked array + uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, const ArrayType& values, + int64_t offset, const ArraySortOptions& options) { + if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) { + c_type min, max; + std::tie(min, max) = GetMinMax(*values.data()); - if (kth_begin > end_iter) { - kth_begin = end_iter; - } - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { - heap.Push(HeapItem{*iter, offset, &arr}); - } - for (; iter != end_iter && heap.Size() > 0; ++iter) { - uint64_t x_index = *iter; - const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); - auto top_item = heap.Top(); - const auto& top_value = - GetView::LogicalValue(top_item.array->GetView(top_item.index)); - if (order_ == SortOrder::Ascending) { - if (xval < top_value) { - heap.ReplaceTop(HeapItem{x_index, offset, &arr}); - } - } else { - if (top_value < xval) { - heap.ReplaceTop(HeapItem{x_index, offset, &arr}); - } - } + // For signed int32/64, (max - min) may overflow and trigger UBSAN. + // Cast to largest unsigned type(uint64_t) before subtraction. + if (static_cast(max) - static_cast(min) <= + countsort_max_range_) { + count_sorter_.SetMinMax(min, max); + return count_sorter_.Sort(indices_begin, indices_end, values, offset, options); } - offset += chunk->length(); - } - int64_t out_size = static_cast(heap.Size()); - ARROW_ASSIGN_OR_RAISE( - auto take_indices, - MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.Size() > 0) { - auto top_item = heap.Top(); - *out_cbegin = top_item.index + top_item.offset; - heap.Pop(); - --out_cbegin; } - ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, - Take(Datum(chunked_array_), Datum(std::move(take_indices)), - TakeOptions::NoBoundsCheck(), ctx_)); - ARROW_ASSIGN_OR_RAISE( - auto select_k, - Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); - *output_ = Datum(select_k); - return Status::OK(); + + return compare_sorter_.Sort(indices_begin, indices_end, values, offset, options); } - int64_t k_; - const ChunkedArray& chunked_array_; - const std::shared_ptr physical_type_; - const ArrayVector physical_chunks_; - const SortOrder order_; - ExecContext* ctx_; - Datum* output_; -}; + private: + ArrayCompareSorter compare_sorter_; + ArrayCountSorter count_sorter_; -class RecordBatchSelecter : public TypeVisitor { - public: - RecordBatchSelecter(ExecContext* ctx, int64_t k, const RecordBatch& record_batch, - const SortOrder order, Datum* output) - : TypeVisitor(), - k_(k), - record_batch_(record_batch), - order_(order), - ctx_(ctx), - output_(output) {} + // Cross point to prefer counting sort than stl::stable_sort(merge sort) + // - array to be sorted is longer than "count_min_len_" + // - value range (max-min) is within "count_max_range_" + // + // The optimal setting depends heavily on running CPU. Below setting is + // conservative to adapt to various hardware and keep code simple. + // It's possible to decrease array-len and/or increase value-range to cover + // more cases, or setup a table for best array-len/value-range combinations. + // See https://issues.apache.org/jira/browse/ARROW-1571 for detailed analysis. + static const uint32_t countsort_min_len_ = 1024; + static const uint32_t countsort_max_range_ = 4096; +}; - Status Run() { - // return physical_type_->Accept(this); - return Status::OK(); - } +template +struct ArraySorter; -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { return SelectKthInternal(); } +template <> +struct ArraySorter { + ArrayCountSorter impl; +}; - VISIT_PHYSICAL_TYPES(VISIT) +template <> +struct ArraySorter { + ArrayCountSorter impl; + ArraySorter() : impl(0, 255) {} +}; -#undef VISIT +template <> +struct ArraySorter { + ArrayCountSorter impl; + ArraySorter() : impl(-128, 127) {} +}; - template - Status SelectKthInternal() { - // using GetView = GetViewType; - // using ArrayType = typename TypeTraits::ArrayType; - // using HeapItem = TypedHeapItem; - - // const auto num_chunks = chunked_array_.num_chunks(); - // if (num_chunks == 0) { - // return Status::OK(); - // } - // if (k_ > chunked_array_.length()) { - // k_ = chunked_array_.length(); - // } - // std::function cmp; - // if (order_ == SortOrder::Ascending) { - // cmp = [](const HeapItem& left, const HeapItem& right) -> bool { - // const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - // const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - // return lval < rval; - // }; - // } else { - // cmp = [](const HeapItem& left, const HeapItem& right) -> bool { - // const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - // const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - // return rval < lval; - // }; - // } - // arrow::internal::Heap heap(cmp); - // std::vector> chunks_holder; - // uint64_t offset = 0; - // for (const auto& chunk : physical_chunks_) { - // if (chunk->length() == 0) continue; - // chunks_holder.emplace_back(std::make_shared(chunk->data())); - // ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; - - // std::vector indices(arr.length()); - - // uint64_t* indices_begin = indices.data(); - // uint64_t* indices_end = indices_begin + indices.size(); - // std::iota(indices_begin, indices_end, 0); - - // auto end_iter = PartitionNulls( - // indices_begin, indices_end, arr, 0); - // auto kth_begin = indices_begin + k_; - - // if (kth_begin > end_iter) { - // kth_begin = end_iter; - // } - // uint64_t* iter = indices_begin; - // for (; iter != kth_begin && heap.Size() < static_cast(k_); ++iter) { - // heap.Push(HeapItem{*iter, offset, &arr}); - // } - // for (; iter != end_iter && heap.Size() > 0; ++iter) { - // uint64_t x_index = *iter; - // const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); - // auto top_item = heap.Top(); - // const auto& top_value = - // GetView::LogicalValue(top_item.array->GetView(top_item.index)); - // if (order_ == SortOrder::Ascending) { - // if (xval < top_value) { - // heap.ReplaceTop(HeapItem{x_index, offset, &arr}); - // } - // } else { - // if (top_value < xval) { - // heap.ReplaceTop(HeapItem{x_index, offset, &arr}); - // } - // } - // } - // offset += chunk->length(); - // } - // int64_t out_size = static_cast(heap.Size()); - // ARROW_ASSIGN_OR_RAISE( - // auto take_indices, - // MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); - // auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - // while (heap.Size() > 0) { - // auto top_item = heap.Top(); - // *out_cbegin = top_item.index + top_item.offset; - // heap.Pop(); - // --out_cbegin; - // } - // ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, - // Take(Datum(chunked_array_), Datum(std::move(take_indices)), - // TakeOptions::NoBoundsCheck(), ctx_)); - // ARROW_ASSIGN_OR_RAISE( - // auto select_k, - // Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); - // *output_ = Datum(select_k); - return Status::OK(); - } +template +struct ArraySorter::value && + (sizeof(typename Type::c_type) > 1)) || + is_temporal_type::value>> { + ArrayCountOrCompareSorter impl; +}; - int64_t k_; - const RecordBatch& record_batch_; - const SortOrder order_; - ExecContext* ctx_; - Datum* output_; +template +struct ArraySorter< + Type, enable_if_t::value || is_base_binary_type::value || + is_fixed_size_binary_type::value>> { + ArrayCompareSorter impl; }; -template -class SelectKthMetaFunction { - public: - Result ExecuteImpl(const std::vector& args, - const FunctionOptions* options, ExecContext* ctx) const { - const SelectKOptions& select_k_options = static_cast(*options); +using ArraySortIndicesState = internal::OptionsWrapper; - switch (args[0].kind()) { - case Datum::ARRAY: - return SelectKth(*args[0].make_array(), select_k_options, ctx); - break; - case Datum::CHUNKED_ARRAY: - return SelectKth(*args[0].chunked_array(), select_k_options, ctx); - break; - case Datum::RECORD_BATCH: - return SelectKth(*args[0].record_batch(), select_k_options, ctx); - break; - default: - break; - } - return Status::NotImplemented( - "Unsupported types for sort_indices operation: " - "values=", - args[0].ToString()); - } +template +struct ArraySortIndices { + using ArrayType = typename TypeTraits::ArrayType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = ArraySortIndicesState::Get(ctx); - private: - Result SelectKth(const Array& array, const SelectKOptions& options, - ExecContext* ctx) const { - Datum output; - ArraySelecter selecter(ctx, array, options.k, sort_order, &output); - ARROW_RETURN_NOT_OK(selecter.Run()); - return output; - } + ArrayType arr(batch[0].array()); + ArrayData* out_arr = out->mutable_array(); + uint64_t* out_begin = out_arr->GetMutableValues(1); + uint64_t* out_end = out_begin + arr.length(); + std::iota(out_begin, out_end, 0); - Result SelectKth(const ChunkedArray& chunked_array, - const SelectKOptions& options, ExecContext* ctx) const { - Datum output; - ChunkedArraySelecter selecter(ctx, options.k, chunked_array, sort_order, &output); - ARROW_RETURN_NOT_OK(selecter.Run()); - return output; - } - Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, - ExecContext* ctx) const { - Datum output; - RecordBatchSelecter selecter(ctx, options.k, record_batch, sort_order, &output); - ARROW_RETURN_NOT_OK(selecter.Run()); - return output; + ArraySorter sorter; + sorter.impl.Sort(out_begin, out_end, arr, 0, options); + + return Status::OK(); } }; -class TopKMetaFunction : public MetaFunction { - public: - TopKMetaFunction() - : MetaFunction("top_k", Arity::Binary(), &top_k_doc, &kDefaultTopKOptions) {} +// Sort indices kernels implemented for +// +// * Boolean type +// * Number types +// * Base binary types - Result ExecuteImpl(const std::vector& args, - const FunctionOptions* options, - ExecContext* ctx) const override { - SelectKthMetaFunction impl; - return impl.ExecuteImpl(args, options, ctx); - } -}; +template