diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.h b/cpp/src/arrow/compute/kernels/chunked_internal.h
index b007d6cbfb8..7e66e9b6403 100644
--- a/cpp/src/arrow/compute/kernels/chunked_internal.h
+++ b/cpp/src/arrow/compute/kernels/chunked_internal.h
@@ -48,6 +48,22 @@ struct ResolvedChunk {
   LogicalValueType Value() const { return V::LogicalValue(array->GetView(index)); }
 };
 
+// ResolvedChunk specialization for StructArray.
+// Only null lookup is provided; no Value() accessor is defined.
+template <>
+struct ResolvedChunk {
+  // The target struct array (a chunk of the chunked array).
+  const StructArray* array;
+  // The index of the slot within `array` (a row index, not a field index).
+  const int64_t index;
+
+  ResolvedChunk(const StructArray* array, int64_t index) : array(array), index(index) {}
+
+  // Nullness comes from the struct array itself. Checking only field(0) would
+  // overlook null structs and would be invalid for a struct without fields.
+  bool IsNull() const { return array->IsNull(index); }
+};
+
 // ResolvedChunk specialization for untyped arrays when all is needed is null lookup
 template <>
 struct ResolvedChunk {
diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
index 1335882a252..844b7ad0c2a 100644
--- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
@@ -172,6 +172,57 @@ class ArrayCompareSorter {
   }
 };
 
+// Sorts array indices by comparing rows field by field, in field order,
+// through `nested_value_comparator_`.
+template
+class StructArrayCompareSorter {
+  using ArrayType = typename TypeTraits::ArrayType;
+
+ public:
+  // Partitions the nulls of `array`, then stably sorts the non-null indices.
+  //
+  // `offset` is used when this is called on a chunk of a chunked array.
+  //
+  // If preparing the comparator fails, a default-constructed NullPartitionResult
+  // is returned and the indices are left untouched.
+  NullPartitionResult operator()(uint64_t* indices_begin, uint64_t* indices_end,
+                                 const Array& array, int64_t offset,
+                                 const ArraySortOptions& options) {
+    const auto& values = checked_cast(array);
+    nested_value_comparator_ = std::make_shared();
+
+    if (!nested_value_comparator_->Prepare(values).ok()) {
+      // TODO: Propagate the error; this signature cannot return a Status.
+      return NullPartitionResult();
+    }
+
+    const auto p = PartitionNulls(
+        indices_begin, indices_end, values, offset, options.null_placement);
+
+    const bool asc_order = options.order == SortOrder::Ascending;
+    // Hoisted out of the comparator, which runs once per comparison.
+    const ArrayVector::size_type num_fields = values.fields().size();
+    std::stable_sort(
+        p.non_nulls_begin, p.non_nulls_end,
+        [&offset, &values, asc_order, num_fields, this](uint64_t left, uint64_t right) {
+          // Descending order is obtained by swapping the two operands.
+          const uint64_t lhs = asc_order ? left : right;
+          const uint64_t rhs = asc_order ? right : left;
+          for (ArrayVector::size_type fieldidx = 0; fieldidx < num_fields; ++fieldidx) {
+            const int result =
+                nested_value_comparator_->Compare(values, fieldidx, offset, lhs, rhs);
+            // The first field that differs decides the order.
+            if (result != 0) return result < 0;
+          }
+          // All fields compare equal: neither row is ordered before the other.
+          return false;
+        });
+    return p;
+  }
+
+  std::shared_ptr nested_value_comparator_;
+};
+
 template
 class ArrayCountSorter {
   using ArrayType = typename TypeTraits::ArrayType;
@@ -410,6 +461,13 @@ struct ArraySorter<
   ArrayCompareSorter impl;
 };
 
+// ArraySorter specialization whose implementation is the field-by-field
+// StructArrayCompareSorter.
+template
+struct ArraySorter::value>> {
+  StructArrayCompareSorter impl;
+};
+
 struct ArraySorterFactory {
   ArraySortFunc sorter;
 
@@ -511,6 +569,13 @@ const ArraySortOptions* GetDefaultArraySortOptions() {
   return &kDefaultArraySortOptions;
 }
 
+template