diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.h b/cpp/src/arrow/compute/kernels/chunked_internal.h index 69f439fccf0..89ddcbcab01 100644 --- a/cpp/src/arrow/compute/kernels/chunked_internal.h +++ b/cpp/src/arrow/compute/kernels/chunked_internal.h @@ -61,6 +61,19 @@ struct ResolvedChunk { bool IsNull() const { return array->IsNull(index); } }; +// ResolvedChunk specialization for StructArray +template <> +struct ResolvedChunk { + // The target structarray in chunked array. + const StructArray* array; + // The index in the target array. + const int64_t index; + + ResolvedChunk(const StructArray* array, int64_t index) : array(array), index(index) {} + + bool IsNull() const { return array->field(static_cast(index)) == nullptr; } +}; + struct ChunkedArrayResolver : protected ::arrow::internal::ChunkResolver { ChunkedArrayResolver(const ChunkedArrayResolver& other) : ::arrow::internal::ChunkResolver(other.chunks_), chunks_(other.chunks_) {} diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc index 324a435441f..3d37b8f4942 100644 --- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc @@ -173,6 +173,49 @@ class ArrayCompareSorter { } }; +template +class StructArrayCompareSorter { + using ArrayType = typename TypeTraits::ArrayType; + + public: + // `offset` is used when this is called on a chunk of a chunked array + NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options) { + const auto& values = checked_cast(array); + nested_value_comparator_ = std::make_shared(); + + if (nested_value_comparator_->Prepare(values) != Status::OK()) { + // TODO: Improve error handling + return NullPartitionResult(); + } + + const auto p = PartitionNulls( + indices_begin, indices_end, values, offset, options.null_placement); + + bool asc_order = options.order == SortOrder::Ascending; + std::stable_sort(p.non_nulls_begin, p.non_nulls_end, + [&offset, &values, asc_order, this](uint64_t left, uint64_t right) { + // is better to do values.fields.size() or + // values.schema().num_fields() ? + for (ArrayVector::size_type fieldidx = 0; + fieldidx < values.fields().size(); ++fieldidx) { + int result = nested_value_comparator_->Compare( + values, fieldidx, offset, asc_order ? left : right, + asc_order ? right : left); + if (result == -1) + return true; + else if (result == 1) + return false; + } + return false; + }); + return p; + } + + std::shared_ptr nested_value_comparator_; +}; + template class ArrayCountSorter { using ArrayType = typename TypeTraits::ArrayType; @@ -409,6 +452,11 @@ struct ArraySorter< ArrayCompareSorter impl; }; +template +struct ArraySorter::value>> { + StructArrayCompareSorter impl; +}; + struct ArraySorterFactory { ArraySortFunc sorter; @@ -507,6 +555,13 @@ void AddArraySortingKernels(VectorKernel base, VectorFunction* func) { DCHECK_OK(func->AddKernel(base)); } +template