diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 6c425d65550..7fa43e715d8 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -29,6 +29,8 @@ #include "arrow/table.h" #include "arrow/type_traits.h" #include "arrow/util/bit_block_counter.h" +#include "arrow/util/bitmap.h" +#include "arrow/util/bitmap_ops.h" #include "arrow/util/checked_cast.h" #include "arrow/util/optional.h" #include "arrow/visitor_inline.h" @@ -42,6 +44,7 @@ namespace internal { // Visit all physical types for which sorting is implemented. #define VISIT_PHYSICAL_TYPES(VISIT) \ + VISIT(BooleanType) \ VISIT(Int8Type) \ VISIT(Int16Type) \ VISIT(Int32Type) \ @@ -370,6 +373,24 @@ inline void VisitRawValuesInline(const ArrayType& values, [&](int64_t i) { visitor_not_null(data[i]); }, [&]() { visitor_null(); }); } +template +inline void VisitRawValuesInline(const BooleanArray& values, + VisitorNotNull&& visitor_not_null, + VisitorNull&& visitor_null) { + if (values.null_count() != 0) { + const uint8_t* data = values.data()->GetValues(1, 0); + VisitBitBlocksVoid( + values.null_bitmap(), values.offset(), values.length(), + [&](int64_t i) { visitor_not_null(BitUtil::GetBit(data, values.offset() + i)); }, + [&]() { visitor_null(); }); + } else { + // Can avoid GetBit() overhead in the no-nulls case + VisitBitBlocksVoid( + values.data()->buffers[1], values.offset(), values.length(), + [&](int64_t i) { visitor_not_null(true); }, [&]() { visitor_not_null(false); }); + } +} + template class ArrayCompareSorter { using ArrayType = typename TypeTraits::ArrayType; @@ -477,6 +498,42 @@ class ArrayCountSorter { } }; +using ::arrow::internal::Bitmap; + +template <> +class ArrayCountSorter { + public: + ArrayCountSorter() = default; + + // Returns where null starts. + // `offset` is used when this is called on a chunk of a chunked array + uint64_t* Sort(uint64_t* indices_begin, uint64_t* indices_end, + const BooleanArray& values, int64_t offset, + const ArraySortOptions& options) { + std::array counts{0, 0}; + + const int64_t nulls = values.null_count(); + const int64_t ones = values.true_count(); + const int64_t zeros = values.length() - ones - nulls; + + int64_t null_position = values.length() - nulls; + int64_t index = offset; + const auto nulls_begin = indices_begin + null_position; + + if (options.order == SortOrder::Ascending) { + // ones start after zeros + counts[1] = zeros; + } else { + // zeros start after ones + counts[0] = ones; + } + VisitRawValuesInline( + values, [&](bool v) { indices_begin[counts[v]++] = index++; }, + [&]() { indices_begin[null_position++] = index++; }); + return nulls_begin; + } +}; + // Sort integers with counting sort or comparison based sorting algorithm // - Use O(n) counting sort if values are in a small range // - Use O(nlogn) std::stable_sort otherwise @@ -527,6 +584,11 @@ class ArrayCountOrCompareSorter { template struct ArraySorter; +template <> +struct ArraySorter { + ArrayCountSorter impl; +}; + template <> struct ArraySorter { ArrayCountSorter impl; @@ -576,11 +638,17 @@ struct ArraySortIndices { // Sort indices kernels implemented for // +// * Boolean type // * Number types // * Base binary types template