diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc index ccf691669b8..1499554a960 100644 --- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc @@ -143,9 +143,9 @@ class ArrayCompareSorter { using GetView = GetViewType; public: - NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, - const Array& array, int64_t offset, - const ArraySortOptions& options) { + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, ExecContext*) { const auto& values = checked_cast(array); const auto p = PartitionNulls( @@ -173,6 +173,95 @@ class ArrayCompareSorter { } }; +template <> +class ArrayCompareSorter { + public: + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, + ExecContext* ctx) { + const auto& dict_array = checked_cast(array); + // TODO: These methods should probably return a const&? They seem capable. + // https://github.com/apache/arrow/issues/35437 + auto dict_values = dict_array.dictionary(); + auto dict_indices = dict_array.indices(); + + // Algorithm: + // 1) Use the Rank function to get an exactly-equivalent-order array + // of the dictionary values, but with a datatype that's friendlier to + // sorting (uint64). + // 2) Act as if we were sorting a dictionary array with the same indices, + // but with the ranks as dictionary values. + // 2a) Dictionary-decode the ranks by calling Take. + // 2b) Sort the decoded ranks. Not only those are uint64, they are dense + // in a [0, k) range where k is the number of unique dictionary values. + // Therefore, unless the dictionary is very large, a fast counting sort + // will be used. + // + // The bottom line is that performance will usually be much better + // (potentially an order of magnitude faster) than by naively decoding + // the original dictionary and sorting the decoded version. + + std::shared_ptr decoded_ranks; + // Skip the rank/take steps for cases with only nulls or no indices + if (dict_indices->length() == 0 || + dict_indices->null_count() == dict_indices->length() || + dict_values->null_count() == dict_values->length()) { + ARROW_ASSIGN_OR_RAISE(decoded_ranks, MakeArrayOfNull(uint64(), dict_array.length(), + ctx->memory_pool())); + } else { + ARROW_ASSIGN_OR_RAISE(auto ranks, RanksWithNulls(dict_values, ctx)); + + ARROW_ASSIGN_OR_RAISE(decoded_ranks, + Take(*ranks, *dict_indices, TakeOptions::Defaults(), ctx)); + } + + DCHECK_EQ(decoded_ranks->type_id(), Type::UINT64); + DCHECK_EQ(decoded_ranks->length(), dict_array.length()); + ARROW_ASSIGN_OR_RAISE(auto rank_sorter, GetArraySorter(*decoded_ranks->type())); + + return rank_sorter(indices_begin, indices_end, *decoded_ranks, offset, options, ctx); + } + + private: + static Result> RanksWithNulls( + const std::shared_ptr& array, ExecContext* ctx) { + // Notes: + // * The order is always ascending here, since the goal is to produce + // an exactly-equivalent-order of the dictionary values. + // * We're going to re-emit nulls in the output, so we can just always consider + // them "at the end". Note that choosing AtStart would merely shift other + // ranks by 1 if there are any nulls... + RankOptions rank_options(SortOrder::Ascending, NullPlacement::AtEnd, + RankOptions::Dense); + + // XXX Should this support Type::NA? + auto data = array->data(); + std::shared_ptr null_bitmap; + if (array->null_count() > 0) { + null_bitmap = array->null_bitmap(); + data = array->data()->Copy(); + if (data->offset > 0) { + ARROW_ASSIGN_OR_RAISE(null_bitmap, arrow::internal::CopyBitmap( + ctx->memory_pool(), null_bitmap->data(), + data->offset, data->length)); + } + data->buffers[0] = nullptr; + data->null_count = 0; + } + ARROW_ASSIGN_OR_RAISE(auto rank_datum, + CallFunction("rank", {std::move(data)}, &rank_options, ctx)); + auto rank_data = rank_datum.array(); + DCHECK_EQ(rank_data->GetNullCount(), 0); + // If there were nulls in the input, paste them in the output + if (null_bitmap) { + rank_data->buffers[0] = std::move(null_bitmap); + rank_data->null_count = array->null_count(); + } + return MakeArray(rank_data); + } +}; + template class ArrayCountSorter { using ArrayType = typename TypeTraits::ArrayType; @@ -189,9 +278,10 @@ class ArrayCountSorter { value_range_ = static_cast(max - min) + 1; } - NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, - const Array& array, int64_t offset, - const ArraySortOptions& options) const { + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, + ExecContext*) const { const auto& values = checked_cast(array); // 32bit counter performs much better than 64bit one @@ -273,9 +363,9 @@ class ArrayCountSorter { public: ArrayCountSorter() = default; - NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, - const Array& array, int64_t offset, - const ArraySortOptions& options) { + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, ExecContext*) { const auto& values = checked_cast(array); std::array counts{0, 0, 0}; // false, true, null @@ -318,9 +408,10 @@ class ArrayCountOrCompareSorter { using c_type = typename ArrowType::c_type; public: - NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, - const Array& array, int64_t offset, - const ArraySortOptions& options) { + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, + ExecContext* ctx) { const auto& values = checked_cast(array); if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) { @@ -332,11 +423,11 @@ class ArrayCountOrCompareSorter { if (static_cast(max) - static_cast(min) <= countsort_max_range_) { count_sorter_.SetMinMax(min, max); - return count_sorter_(indices_begin, indices_end, values, offset, options); + return count_sorter_(indices_begin, indices_end, values, offset, options, ctx); } } - return compare_sorter_(indices_begin, indices_end, values, offset, options); + return compare_sorter_(indices_begin, indices_end, values, offset, options, ctx); } private: @@ -358,9 +449,9 @@ class ArrayCountOrCompareSorter { class ArrayNullSorter { public: - NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, - const Array& values, int64_t offset, - const ArraySortOptions& options) { + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& values, int64_t offset, + const ArraySortOptions& options, ExecContext*) { return NullPartitionResult::NullsOnly(indices_begin, indices_end, options.null_placement); } @@ -405,7 +496,8 @@ struct ArraySorter::value && template struct ArraySorter< Type, enable_if_t::value || is_base_binary_type::value || - is_fixed_size_binary_type::value>> { + is_fixed_size_binary_type::value || + is_dictionary_type::value>> { ArrayCompareSorter impl; }; @@ -445,8 +537,7 @@ struct ArraySortIndices { ArrayType arr(batch[0].array.ToArrayData()); ARROW_ASSIGN_OR_RAISE(auto sorter, GetArraySorter(*GetPhysicalType(arr.type()))); - sorter(out_begin, out_end, arr, 0, options); - return Status::OK(); + return sorter(out_begin, out_end, arr, 0, options, ctx->exec_context()).status(); } }; @@ -508,6 +599,13 @@ void AddArraySortingKernels(VectorKernel base, VectorFunction* func) { DCHECK_OK(func->AddKernel(base)); } +template