From 3c4f86fcf75eee63fbdd8175e4c7ac204165ffd1 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Fri, 17 Jan 2025 18:00:20 +0800 Subject: [PATCH 1/3] Refine the structure WIP --- cpp/src/arrow/compute/kernels/vector_rank.cc | 200 ++++++++++++++++++- 1 file changed, 194 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 50af9c6d599..6e79da6423b 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -299,6 +299,161 @@ class Ranker : public RankerMixin +Result SortAndMarkDup(const Array& input, uint64_t* indices_begin, + uint64_t* indices_end, SortOrder order, + NullPlacement null_placement, + bool needs_duplicates, ExecContext* ctx) { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + + ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*input.type())); + + ArrayType array(input.data()); + ARROW_ASSIGN_OR_RAISE(auto sorted, + array_sorter(indices_begin, indices_end, array, 0, + ArraySortOptions(order, null_placement), ctx)); + + if (needs_duplicates) { + auto value_selector = [&array](int64_t index) { + return GetView::LogicalValue(array.GetView(index)); + }; + MarkDuplicates(sorted, value_selector); + } + return sorted; +} + +template +Result SortAndMarkDup(const ChunkedArray& input, + uint64_t* indices_begin, uint64_t* indices_end, + SortOrder order, NullPlacement null_placement, + bool needs_duplicates, ExecContext* ctx) { + auto physical_type = GetPhysicalType(input.type()); + auto physical_chunks = GetPhysicalChunks(input, physical_type); + if (physical_chunks.empty()) { + return NullPartitionResult{}; + } + ARROW_ASSIGN_OR_RAISE(auto sorted, + SortChunkedArray(ctx, indices_begin, indices_end, physical_type, + physical_chunks, order, null_placement)); + if (needs_duplicates) { + const auto arrays = GetArrayPointers(physical_chunks); + auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) { + return resolver.Resolve(index).Value(); + }; + MarkDuplicates(sorted, value_selector); + } + return sorted; +} + +struct PercentileRanker { + explicit PercentileRanker(double factor) : factor_(factor) {} + + Result CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) { + const int64_t length = sorted.overall_end() - sorted.overall_begin(); + ARROW_ASSIGN_OR_RAISE(auto rankings, + MakeMutableFloat64Array(length, ctx->memory_pool())); + auto out_begin = rankings->GetMutableValues(1); + + auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; }; + auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; }; + + // The count of values strictly less than the value being considered + int64_t cum_freq = 0; + auto it = sorted.overall_begin(); + + while (it < sorted.overall_end()) { + // Look for a run of duplicate values + DCHECK(!is_duplicate(*it)); + auto run_end = it; + while (++run_end < sorted.overall_end() && is_duplicate(*run_end)) { + } + // The run length, i.e. the frequency of the current value + int64_t freq = run_end - it; + double percentile = (cum_freq + 0.5 * freq) * factor_ / static_cast(length); + // Output percentile rank values + for (; it < run_end; ++it) { + out_begin[original_index(*it)] = percentile; + } + cum_freq += freq; + } + DCHECK_EQ(cum_freq, length); + return Datum(rankings); + } + + private: + const double factor_; +}; + +// A helper class that emits rankings for the "rank" function +struct OrdinalRanker { + explicit OrdinalRanker(RankOptions::Tiebreaker tiebreaker) : tiebreaker_(tiebreaker) {} + + Result CreateRankings(ExecContext* ctx, const NullPartitionResult& sorted) { + const int64_t length = sorted.overall_end() - sorted.overall_begin(); + ARROW_ASSIGN_OR_RAISE(auto rankings, + MakeMutableUInt64Array(length, ctx->memory_pool())); + auto out_begin = rankings->GetMutableValues(1); + uint64_t rank; + + auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; }; + auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; }; + + switch (tiebreaker_) { + case RankOptions::Dense: { + rank = 0; + for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) { + if (!is_duplicate(*it)) { + ++rank; + } + out_begin[original_index(*it)] = rank; + } + break; + } + + case RankOptions::First: { + rank = 0; + for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) { + // No duplicate marks expected for RankOptions::First + DCHECK(!is_duplicate(*it)); + out_begin[*it] = ++rank; + } + break; + } + + case RankOptions::Min: { + rank = 0; + for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) { + if (!is_duplicate(*it)) { + rank = (it - sorted.overall_begin()) + 1; + } + out_begin[original_index(*it)] = rank; + } + break; + } + + case RankOptions::Max: { + rank = length; + for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) { + out_begin[original_index(*it)] = rank; + // If the current index isn't marked as duplicate, then it's the last + // tie in a row (since we iterate in reverse order), so update rank + // for the next row of ties. + if (!is_duplicate(*it)) { + rank = it - sorted.overall_begin(); + } + } + break; + } + } + + return Datum(rankings); + } + + private: + const RankOptions::Tiebreaker tiebreaker_; +}; + const FunctionDoc rank_doc( "Compute ordinal ranks of an array (1-based)", ("This function computes a rank of the input array.\n" @@ -324,6 +479,7 @@ const FunctionDoc rank_percentile_doc( "in RankPercentileOptions."), {"input"}, "RankPercentileOptions"); +template class RankMetaFunctionBase : public MetaFunction { public: using MetaFunction::MetaFunction; @@ -359,7 +515,13 @@ class RankMetaFunctionBase : public MetaFunction { template Result Rank(const T& input, const FunctionOptions& function_options, ExecContext* ctx) const { - auto options = UnpackOptions(function_options); + const auto& options = + checked_cast(function_options); + + // SortOrder order = SortOrder::Ascending; + // if (!options.sort_keys.empty()) { + // order = options.sort_keys[0].order; + // } int64_t length = input.length(); ARROW_ASSIGN_OR_RAISE(auto indices, @@ -368,17 +530,34 @@ class RankMetaFunctionBase : public MetaFunction { auto* indices_end = indices_begin + length; std::iota(indices_begin, indices_end, 0); - Ranker ranker(ctx, indices_begin, indices_end, input, options.order, - options.null_placement, options.emitter.get()); - return ranker.Run(); + // auto needs_duplicates = static_cast(this)->NeedsDuplicates(options); + // ARROW_ASSIGN_OR_RAISE(auto sorted, + // SortAndMarkDup(input, indices_begin, indices_end, order, + // options.null_placement, needs_duplicates, + // ctx)); + NullPartitionResult sorted; + auto ranker = static_cast(this)->GetRanker(options); + + return ranker.CreateRankings(ctx, sorted); } }; -class RankMetaFunction : public RankMetaFunctionBase { +class RankMetaFunction : public RankMetaFunctionBase { public: + using FunctionOptionsType = RankOptions; + using RankerType = OrdinalRanker; + RankMetaFunction() : RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {} + bool NeedsDuplicates(const RankOptions& options) const { + return options.tiebreaker != RankOptions::First; + } + + RankerType GetRanker(const RankOptions& options) const { + return RankerType(options.tiebreaker); + } + protected: UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { const auto& options = checked_cast(function_options); @@ -392,12 +571,21 @@ class RankMetaFunction : public RankMetaFunctionBase { } }; -class RankPercentileMetaFunction : public RankMetaFunctionBase { +class RankPercentileMetaFunction : public RankMetaFunctionBase { public: + using FunctionOptionsType = RankPercentileOptions; + using RankerType = PercentileRanker; + RankPercentileMetaFunction() : RankMetaFunctionBase("rank_percentile", Arity::Unary(), rank_percentile_doc, GetDefaultPercentileRankOptions()) {} + bool NeedsDuplicates(const RankPercentileOptions&) const { return true; } + + RankerType GetRanker(const RankPercentileOptions& options) const { + return RankerType(options.factor); + } + protected: UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { const auto& options = checked_cast(function_options); From 9f6ad8dd49f982d298f5186321846146056d6609 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Fri, 17 Jan 2025 21:36:50 +0800 Subject: [PATCH 2/3] Refinement done --- cpp/src/arrow/compute/kernels/vector_rank.cc | 375 ++++--------------- 1 file changed, 81 insertions(+), 294 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 6e79da6423b..6af7420626c 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -63,129 +63,6 @@ void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& value_sel } } -struct RankingsEmitter { - virtual ~RankingsEmitter() = default; - virtual bool NeedsDuplicates() = 0; - virtual Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) = 0; -}; - -// A helper class that emits rankings for the "rank_percentile" function -struct PercentileRankingsEmitter : public RankingsEmitter { - explicit PercentileRankingsEmitter(double factor) : factor_(factor) {} - - bool NeedsDuplicates() override { return true; } - - Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) override { - const int64_t length = sorted.overall_end() - sorted.overall_begin(); - ARROW_ASSIGN_OR_RAISE(auto rankings, - MakeMutableFloat64Array(length, ctx->memory_pool())); - auto out_begin = rankings->GetMutableValues(1); - - auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; }; - auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; }; - - // The count of values strictly less than the value being considered - int64_t cum_freq = 0; - auto it = sorted.overall_begin(); - - while (it < sorted.overall_end()) { - // Look for a run of duplicate values - DCHECK(!is_duplicate(*it)); - auto run_end = it; - while (++run_end < sorted.overall_end() && is_duplicate(*run_end)) { - } - // The run length, i.e. the frequency of the current value - int64_t freq = run_end - it; - double percentile = (cum_freq + 0.5 * freq) * factor_ / static_cast(length); - // Output percentile rank values - for (; it < run_end; ++it) { - out_begin[original_index(*it)] = percentile; - } - cum_freq += freq; - } - DCHECK_EQ(cum_freq, length); - return Datum(rankings); - } - - private: - const double factor_; -}; - -// A helper class that emits rankings for the "rank" function -struct OrdinalRankingsEmitter : public RankingsEmitter { - explicit OrdinalRankingsEmitter(RankOptions::Tiebreaker tiebreaker) - : tiebreaker_(tiebreaker) {} - - bool NeedsDuplicates() override { return tiebreaker_ != RankOptions::First; } - - Result CreateRankings(ExecContext* ctx, - const NullPartitionResult& sorted) override { - const int64_t length = sorted.overall_end() - sorted.overall_begin(); - ARROW_ASSIGN_OR_RAISE(auto rankings, - MakeMutableUInt64Array(length, ctx->memory_pool())); - auto out_begin = rankings->GetMutableValues(1); - uint64_t rank; - - auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) != 0; }; - auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; }; - - switch (tiebreaker_) { - case RankOptions::Dense: { - rank = 0; - for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) { - if (!is_duplicate(*it)) { - ++rank; - } - out_begin[original_index(*it)] = rank; - } - break; - } - - case RankOptions::First: { - rank = 0; - for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) { - // No duplicate marks expected for RankOptions::First - DCHECK(!is_duplicate(*it)); - out_begin[*it] = ++rank; - } - break; - } - - case RankOptions::Min: { - rank = 0; - for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) { - if (!is_duplicate(*it)) { - rank = (it - sorted.overall_begin()) + 1; - } - out_begin[original_index(*it)] = rank; - } - break; - } - - case RankOptions::Max: { - rank = length; - for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin(); --it) { - out_begin[original_index(*it)] = rank; - // If the current index isn't marked as duplicate, then it's the last - // tie in a row (since we iterate in reverse order), so update rank - // for the next row of ties. - if (!is_duplicate(*it)) { - rank = it - sorted.overall_begin(); - } - } - break; - } - } - - return Datum(rankings); - } - - private: - const RankOptions::Tiebreaker tiebreaker_; -}; - const RankOptions* GetDefaultRankOptions() { static const auto kDefaultRankOptions = RankOptions::Defaults(); return &kDefaultRankOptions; @@ -196,118 +73,15 @@ const RankPercentileOptions* GetDefaultPercentileRankOptions() { return &kDefaultPercentileRankOptions; } -template -class RankerMixin : public TypeVisitor { - public: - RankerMixin(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, - const InputType& input, const SortOrder order, - const NullPlacement null_placement, RankingsEmitter* emitter) - : TypeVisitor(), - ctx_(ctx), - indices_begin_(indices_begin), - indices_end_(indices_end), - input_(input), - order_(order), - null_placement_(null_placement), - physical_type_(GetPhysicalType(input.type())), - emitter_(emitter) {} - - Result Run() { - RETURN_NOT_OK(physical_type_->Accept(this)); - return emitter_->CreateRankings(ctx_, sorted_); - } - -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { \ - return static_cast(this)->template SortAndMarkDuplicates(); \ - } - - VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) - -#undef VISIT - - protected: - ExecContext* ctx_; - uint64_t* indices_begin_; - uint64_t* indices_end_; - const InputType& input_; - const SortOrder order_; - const NullPlacement null_placement_; - const std::shared_ptr physical_type_; - RankingsEmitter* emitter_; - NullPartitionResult sorted_{}; -}; - -template -class Ranker; - -template <> -class Ranker : public RankerMixin> { - public: - using RankerMixin::RankerMixin; - - template - Status SortAndMarkDuplicates() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - - ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*physical_type_)); - - ArrayType array(input_.data()); - ARROW_ASSIGN_OR_RAISE(sorted_, - array_sorter(indices_begin_, indices_end_, array, 0, - ArraySortOptions(order_, null_placement_), ctx_)); - - if (emitter_->NeedsDuplicates()) { - auto value_selector = [&array](int64_t index) { - return GetView::LogicalValue(array.GetView(index)); - }; - MarkDuplicates(sorted_, value_selector); - } - return Status::OK(); - } -}; - -template <> -class Ranker : public RankerMixin> { - public: - template - explicit Ranker(Args&&... args) - : RankerMixin(std::forward(args)...), - physical_chunks_(GetPhysicalChunks(input_, physical_type_)) {} - - template - Status SortAndMarkDuplicates() { - if (physical_chunks_.empty()) { - return Status::OK(); - } - ARROW_ASSIGN_OR_RAISE( - sorted_, SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_, - physical_chunks_, order_, null_placement_)); - if (emitter_->NeedsDuplicates()) { - const auto arrays = GetArrayPointers(physical_chunks_); - auto value_selector = [resolver = - ChunkedArrayResolver(span(arrays))](int64_t index) { - return resolver.Resolve(index).Value(); - }; - MarkDuplicates(sorted_, value_selector); - } - return Status::OK(); - } +template +Result DoSortAndMarkDuplicate( + ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, const Array& input, + const std::shared_ptr& physical_type, const SortOrder order, + const NullPlacement null_placement, bool needs_duplicates) { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; - private: - const ArrayVector physical_chunks_; -}; - -template -Result SortAndMarkDup(const Array& input, uint64_t* indices_begin, - uint64_t* indices_end, SortOrder order, - NullPlacement null_placement, - bool needs_duplicates, ExecContext* ctx) { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - - ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*input.type())); + ARROW_ASSIGN_OR_RAISE(auto array_sorter, GetArraySorter(*physical_type)); ArrayType array(input.data()); ARROW_ASSIGN_OR_RAISE(auto sorted, @@ -323,12 +97,11 @@ Result SortAndMarkDup(const Array& input, uint64_t* indices return sorted; } -template -Result SortAndMarkDup(const ChunkedArray& input, - uint64_t* indices_begin, uint64_t* indices_end, - SortOrder order, NullPlacement null_placement, - bool needs_duplicates, ExecContext* ctx) { - auto physical_type = GetPhysicalType(input.type()); +template +Result DoSortAndMarkDuplicate( + ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + const ChunkedArray& input, const std::shared_ptr& physical_type, + const SortOrder order, const NullPlacement null_placement, bool needs_duplicates) { auto physical_chunks = GetPhysicalChunks(input, physical_type); if (physical_chunks.empty()) { return NullPartitionResult{}; @@ -339,13 +112,59 @@ Result SortAndMarkDup(const ChunkedArray& input, if (needs_duplicates) { const auto arrays = GetArrayPointers(physical_chunks); auto value_selector = [resolver = ChunkedArrayResolver(span(arrays))](int64_t index) { - return resolver.Resolve(index).Value(); + return resolver.Resolve(index).Value(); }; MarkDuplicates(sorted, value_selector); } return sorted; } +template +class SortAndMarkDuplicate : public TypeVisitor { + public: + SortAndMarkDuplicate(ExecContext* ctx, uint64_t* indices_begin, uint64_t* indices_end, + const InputType& input, const SortOrder order, + const NullPlacement null_placement, const bool needs_duplicate) + : TypeVisitor(), + ctx_(ctx), + indices_begin_(indices_begin), + indices_end_(indices_end), + input_(input), + order_(order), + null_placement_(null_placement), + needs_duplicates_(needs_duplicate), + physical_type_(GetPhysicalType(input.type())) {} + + Result Run() { + RETURN_NOT_OK(physical_type_->Accept(this)); + return std::move(sorted_); + } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { \ + ARROW_ASSIGN_OR_RAISE( \ + sorted_, DoSortAndMarkDuplicate(ctx_, indices_begin_, indices_end_, \ + input_, physical_type_, order_, \ + null_placement_, needs_duplicates_)); \ + return Status::OK(); \ + } + + VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + private: + ExecContext* ctx_; + uint64_t* indices_begin_; + uint64_t* indices_end_; + const InputType& input_; + const SortOrder order_; + const NullPlacement null_placement_; + const bool needs_duplicates_; + const std::shared_ptr physical_type_; + NullPartitionResult sorted_{}; +}; + struct PercentileRanker { explicit PercentileRanker(double factor) : factor_(factor) {} @@ -504,24 +323,16 @@ class RankMetaFunctionBase : public MetaFunction { } protected: - struct UnpackedOptions { - SortOrder order{SortOrder::Ascending}; - NullPlacement null_placement; - std::unique_ptr emitter; - }; - - virtual UnpackedOptions UnpackOptions(const FunctionOptions&) const = 0; - template Result Rank(const T& input, const FunctionOptions& function_options, ExecContext* ctx) const { const auto& options = checked_cast(function_options); - // SortOrder order = SortOrder::Ascending; - // if (!options.sort_keys.empty()) { - // order = options.sort_keys[0].order; - // } + SortOrder order = SortOrder::Ascending; + if (!options.sort_keys.empty()) { + order = options.sort_keys[0].order; + } int64_t length = input.length(); ARROW_ASSIGN_OR_RAISE(auto indices, @@ -529,15 +340,13 @@ class RankMetaFunctionBase : public MetaFunction { auto* indices_begin = indices->GetMutableValues(1); auto* indices_end = indices_begin + length; std::iota(indices_begin, indices_end, 0); + auto needs_duplicates = Impl::NeedsDuplicates(options); + ARROW_ASSIGN_OR_RAISE( + auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order, + options.null_placement, needs_duplicates) + .Run()); - // auto needs_duplicates = static_cast(this)->NeedsDuplicates(options); - // ARROW_ASSIGN_OR_RAISE(auto sorted, - // SortAndMarkDup(input, indices_begin, indices_end, order, - // options.null_placement, needs_duplicates, - // ctx)); - NullPartitionResult sorted; - auto ranker = static_cast(this)->GetRanker(options); - + auto ranker = Impl::GetRanker(options); return ranker.CreateRankings(ctx, sorted); } }; @@ -547,55 +356,33 @@ class RankMetaFunction : public RankMetaFunctionBase { using FunctionOptionsType = RankOptions; using RankerType = OrdinalRanker; - RankMetaFunction() - : RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {} - - bool NeedsDuplicates(const RankOptions& options) const { + static bool NeedsDuplicates(const RankOptions& options) { return options.tiebreaker != RankOptions::First; } - RankerType GetRanker(const RankOptions& options) const { + static RankerType GetRanker(const RankOptions& options) { return RankerType(options.tiebreaker); } - protected: - UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { - const auto& options = checked_cast(function_options); - UnpackedOptions unpacked{ - SortOrder::Ascending, options.null_placement, - std::make_unique(options.tiebreaker)}; - if (!options.sort_keys.empty()) { - unpacked.order = options.sort_keys[0].order; - } - return unpacked; - } + RankMetaFunction() + : RankMetaFunctionBase("rank", Arity::Unary(), rank_doc, GetDefaultRankOptions()) {} }; -class RankPercentileMetaFunction : public RankMetaFunctionBase { +class RankPercentileMetaFunction + : public RankMetaFunctionBase { public: using FunctionOptionsType = RankPercentileOptions; using RankerType = PercentileRanker; - RankPercentileMetaFunction() - : RankMetaFunctionBase("rank_percentile", Arity::Unary(), rank_percentile_doc, - GetDefaultPercentileRankOptions()) {} - - bool NeedsDuplicates(const RankPercentileOptions&) const { return true; } + static bool NeedsDuplicates(const RankPercentileOptions&) { return true; } - RankerType GetRanker(const RankPercentileOptions& options) const { + static RankerType GetRanker(const RankPercentileOptions& options) { return RankerType(options.factor); } - protected: - UnpackedOptions UnpackOptions(const FunctionOptions& function_options) const override { - const auto& options = checked_cast(function_options); - UnpackedOptions unpacked{SortOrder::Ascending, options.null_placement, - std::make_unique(options.factor)}; - if (!options.sort_keys.empty()) { - unpacked.order = options.sort_keys[0].order; - } - return unpacked; - } + RankPercentileMetaFunction() + : RankMetaFunctionBase("rank_percentile", Arity::Unary(), rank_percentile_doc, + GetDefaultPercentileRankOptions()) {} }; } // namespace From e675eb2bf71304d4b43df9cb771954ce653565af Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 20 Jan 2025 17:30:03 +0800 Subject: [PATCH 3/3] Address review comments --- cpp/src/arrow/compute/kernels/vector_rank.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 6af7420626c..e0069a1f2c4 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -137,7 +137,7 @@ class SortAndMarkDuplicate : public TypeVisitor { Result Run() { RETURN_NOT_OK(physical_type_->Accept(this)); - return std::move(sorted_); + return sorted_; } #define VISIT(TYPE) \ @@ -165,6 +165,7 @@ class SortAndMarkDuplicate : public TypeVisitor { NullPartitionResult sorted_{}; }; +// A helper class that emits rankings for the "rank_percentile" function struct PercentileRanker { explicit PercentileRanker(double factor) : factor_(factor) {} @@ -298,7 +299,7 @@ const FunctionDoc rank_percentile_doc( "in RankPercentileOptions."), {"input"}, "RankPercentileOptions"); -template +template class RankMetaFunctionBase : public MetaFunction { public: using MetaFunction::MetaFunction; @@ -327,7 +328,7 @@ class RankMetaFunctionBase : public MetaFunction { Result Rank(const T& input, const FunctionOptions& function_options, ExecContext* ctx) const { const auto& options = - checked_cast(function_options); + checked_cast(function_options); SortOrder order = SortOrder::Ascending; if (!options.sort_keys.empty()) { @@ -340,13 +341,13 @@ class RankMetaFunctionBase : public MetaFunction { auto* indices_begin = indices->GetMutableValues(1); auto* indices_end = indices_begin + length; std::iota(indices_begin, indices_end, 0); - auto needs_duplicates = Impl::NeedsDuplicates(options); + auto needs_duplicates = Derived::NeedsDuplicates(options); ARROW_ASSIGN_OR_RAISE( auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order, options.null_placement, needs_duplicates) .Run()); - auto ranker = Impl::GetRanker(options); + auto ranker = Derived::GetRanker(options); return ranker.CreateRankings(ctx, sorted); } };