Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion cpp/src/arrow/compute/kernels/vector_array_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,19 @@ class ArrayCompareSorter<DictionaryType> {
}
};

template <>
class ArrayCompareSorter<StructType> {
public:
Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
const Array& array, int64_t offset,
const ArraySortOptions& options,
ExecContext* ctx) {
const auto& struct_array = checked_cast<const StructArray&>(array);
return SortStructArray(ctx, indices_begin, indices_end, struct_array, options.order,
options.null_placement);
}
};

template <typename ArrowType>
class ArrayCountSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
Expand Down Expand Up @@ -497,7 +510,7 @@ template <typename Type>
struct ArraySorter<
Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
is_fixed_size_binary_type<Type>::value ||
is_dictionary_type<Type>::value>> {
is_dictionary_type<Type>::value || is_struct_type<Type>::value>> {
ArrayCompareSorter<Type> impl;
};

Expand Down Expand Up @@ -606,6 +619,13 @@ void AddDictArraySortingKernels(VectorKernel base, VectorFunction* func) {
DCHECK_OK(func->AddKernel(base));
}

template <template <typename...> class ExecTemplate>
void AddStructArraySortingKernels(VectorKernel base, VectorFunction* func) {
base.signature = KernelSignature::Make({Type::STRUCT}, uint64());
base.exec = ExecTemplate<UInt64Type, StructType>::Exec;
DCHECK_OK(func->AddKernel(base));
}

const ArraySortOptions* GetDefaultArraySortOptions() {
static const auto kDefaultArraySortOptions = ArraySortOptions::Defaults();
return &kDefaultArraySortOptions;
Expand Down Expand Up @@ -661,6 +681,7 @@ void RegisterVectorArraySort(FunctionRegistry* registry) {
base.exec_chunked = ArraySortIndicesChunked;
AddArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
AddDictArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
AddStructArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
DCHECK_OK(registry->AddFunction(std::move(array_sort_indices)));

// partition_nth_indices has a parameter so needs its init function
Expand Down
Loading