From d49de5c4efc68532a47d736ba3f2d93e6650eafb Mon Sep 17 00:00:00 2001 From: benibus Date: Tue, 23 May 2023 17:20:37 -0400 Subject: [PATCH 1/8] Initial sorting implementations --- cpp/src/arrow/compute/kernels/vector_sort.cc | 55 ++++++++++----- .../compute/kernels/vector_sort_internal.h | 57 ++++++++++++---- .../arrow/compute/kernels/vector_sort_test.cc | 67 +++++++++++++++++++ 3 files changed, 150 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 1de90cac35b..0103bf7d3ad 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -20,6 +20,9 @@ #include "arrow/compute/kernels/vector_sort_internal.h" #include "arrow/compute/registry.h" +template <> +struct std::hash : public arrow::FieldPath::Hash {}; + namespace arrow { using internal::checked_cast; @@ -850,12 +853,22 @@ class SortIndicesMetaFunction : public MetaFunction { ExecContext* ctx) const override { const SortOptions& sort_options = static_cast(*options); switch (args[0].kind()) { - case Datum::ARRAY: - return SortIndices(*args[0].make_array(), sort_options, ctx); - break; - case Datum::CHUNKED_ARRAY: - return SortIndices(*args[0].chunked_array(), sort_options, ctx); - break; + case Datum::ARRAY: { + auto array = args[0].make_array(); + if (array->type_id() == Type::STRUCT) { + ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array)) + return SortIndices(*batch, sort_options, ctx); + } + return SortIndices(*array, sort_options, ctx); + } break; + case Datum::CHUNKED_ARRAY: { + const auto& chunked_array = args[0].chunked_array(); + if (chunked_array->type()->id() == Type::STRUCT) { + ARROW_ASSIGN_OR_RAISE(auto table, ToTable(chunked_array)) + return SortIndices(*table, sort_options, ctx); + } + return SortIndices(*chunked_array, sort_options, ctx); + } break; case Datum::RECORD_BATCH: { return SortIndices(*args[0].record_batch(), sort_options, ctx); } break; @@ -872,6 +885,18 @@ class SortIndicesMetaFunction : public MetaFunction { } private: + static Result> ToTable( + const std::shared_ptr& chunked_array) { + if (chunked_array->null_count() == 0) { + return Table::FromChunkedStructArray(chunked_array); + } + // We avoid using `Table::FromChunkedStructArray` here since it doesn't take top-level + // validity into account for the columns. + ARROW_ASSIGN_OR_RAISE(auto columns, chunked_array->Flatten()); + return Table::Make(schema(chunked_array->type()->fields()), std::move(columns), + chunked_array->length()); + } + Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; @@ -913,8 +938,9 @@ class SortIndicesMetaFunction : public MetaFunction { return Status::Invalid("Must specify one or more sort keys"); } if (n_sort_keys == 1) { - ARROW_ASSIGN_OR_RAISE(auto array, PrependInvalidColumn(GetColumn( - batch, options.sort_keys[0].target))); + ARROW_ASSIGN_OR_RAISE( + auto array, + PrependInvalidColumn(options.sort_keys[0].target.GetOneFlattened(batch))); return SortIndices(*array, options, ctx); } @@ -950,8 +976,9 @@ class SortIndicesMetaFunction : public MetaFunction { return Status::Invalid("Must specify one or more sort keys"); } if (n_sort_keys == 1) { - ARROW_ASSIGN_OR_RAISE(auto chunked_array, PrependInvalidColumn(GetColumn( - table, options.sort_keys[0].target))); + ARROW_ASSIGN_OR_RAISE( + auto chunked_array, + PrependInvalidColumn(options.sort_keys[0].target.GetOneFlattened(table))); return SortIndices(*chunked_array, options, ctx); } @@ -979,17 +1006,15 @@ class SortIndicesMetaFunction : public MetaFunction { Result> FindSortKeys(const Schema& schema, const std::vector& sort_keys) { std::vector fields; - std::unordered_set seen; + std::unordered_set seen; fields.reserve(sort_keys.size()); seen.reserve(sort_keys.size()); for (const auto& sort_key : sort_keys) { - RETURN_NOT_OK(CheckNonNested(sort_key.target)); - ARROW_ASSIGN_OR_RAISE(auto match, PrependInvalidColumn(sort_key.target.FindOne(schema))); - if (seen.insert(match[0]).second) { - fields.push_back({match[0], sort_key.order}); + if (seen.insert(match).second) { + fields.push_back(SortField(std::move(match), sort_key.order)); } } return fields; diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index d78e5130617..b77226730dc 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -468,7 +468,13 @@ Result SortChunkedArray( // Helpers for Sort/SelectK/Rank implementations struct SortField { - int field_index; + SortField() = default; + SortField(FieldPath path, SortOrder order) : path(std::move(path)), order(order) {} + SortField(int index, SortOrder order) : SortField(FieldPath({index}), order) {} + + bool is_nested() const { return path.indices().size() > 1; } + + FieldPath path; SortOrder order; }; @@ -496,7 +502,10 @@ Result> ResolveSortKeys( ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys)); std::vector resolved; resolved.reserve(fields.size()); - std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), factory); + for (const auto& f : fields) { + ARROW_ASSIGN_OR_RAISE(auto resolved_key, factory(f)); + resolved.push_back(std::move(resolved_key)); + } return resolved; } @@ -504,8 +513,13 @@ template Result> ResolveSortKeys( const TableOrBatch& table_or_batch, const std::vector& sort_keys) { return ResolveSortKeys( - *table_or_batch.schema(), sort_keys, [&](const SortField& f) { - return ResolvedSortKey{table_or_batch.column(f.field_index), f.order}; + *table_or_batch.schema(), sort_keys, + [&](const SortField& f) -> Result { + if (f.is_nested()) { + ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch)); + return ResolvedSortKey{std::move(child), f.order}; + } + return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order}; }); } @@ -737,17 +751,32 @@ struct ResolvedTableSortKey { static Result> Make( const Table& table, const RecordBatchVector& batches, const std::vector& sort_keys) { - auto factory = [&](const SortField& f) { - const auto& type = table.schema()->field(f.field_index)->type(); + auto factory = [&](const SortField& f) -> Result { + std::shared_ptr type; + int64_t null_count = 0; + ArrayVector chunks; + chunks.reserve(batches.size()); + // We must expose a homogenous chunking for all ResolvedSortKey, - // so we can't simply pass `table.column(f.field_index)` - ArrayVector chunks(batches.size()); - std::transform(batches.begin(), batches.end(), chunks.begin(), - [&](const std::shared_ptr& batch) { - return batch->column(f.field_index); - }); - return ResolvedTableSortKey(type, std::move(chunks), f.order, - table.column(f.field_index)->null_count()); + // so we can't simply access the column from the table directly. + if (f.is_nested()) { + ARROW_ASSIGN_OR_RAISE(auto schema_field, f.path.Get(*table.schema())); + type = schema_field->type(); + for (const auto& batch : batches) { + ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(*batch)); + null_count += child->null_count(); + chunks.push_back(std::move(child)); + } + } else { + null_count = table.column(f.path[0])->null_count(); + type = table.schema()->field(f.path[0])->type(); + std::transform(batches.begin(), batches.end(), std::back_inserter(chunks), + [&](const std::shared_ptr& batch) { + return batch->column(f.path[0]); + }); + } + + return ResolvedTableSortKey(type, std::move(chunks), f.order, null_count); }; return ::arrow::compute::internal::ResolveSortKeys( diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 3429a5a8785..c79409c4683 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -2115,6 +2115,73 @@ INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom, testing::Combine(first_sort_keys, num_sort_keys, testing::Values(1.0))); +class TestNestedSortIndices : public ::testing::Test { + protected: + static std::shared_ptr GetArray() { + auto child_type = struct_({field("a", uint8()), field("b", uint32())}); + auto child_array = ArrayFromJSON(child_type, + R"([{"a": 5, "b": null}, + {"a": null, "b": 7 }, + {"a": null, "b": 9 }, + {"a": 2, "b": 4 }, + {"a": 5, "b": 1 }, + {"a": 3, "b": null}, + {"a": 2, "b": 3 } + ])"); + + // The top-level validity bitmap is created independently to test null inheritance for + // child fields. + std::shared_ptr parent_bitmap; + ARROW_CHECK_OK(GetBitmapFromVector({1, 1, 1, 1, 1, 0, 1}, &parent_bitmap)); + + auto array = + *StructArray::Make({child_array}, {field("a", child_type)}, parent_bitmap); + ARROW_CHECK_OK(array->ValidateFull()); + return array; + } + + static std::shared_ptr GetRecordBatch() { + auto batch = *RecordBatch::FromStructArray(GetArray()); + ARROW_CHECK_OK(batch->ValidateFull()); + return batch; + } + + static std::shared_ptr GetChunkedArray() { + auto array = GetArray(); + ArrayVector chunks(2); + chunks[0] = *array->SliceSafe(0, 3); + chunks[1] = *array->SliceSafe(3); + auto chunked = *ChunkedArray::Make(std::move(chunks)); + ARROW_CHECK_OK(chunked->ValidateFull()); + return chunked; + } + + static std::shared_ptr GetTable() { + auto chunked = GetChunkedArray(); + auto columns = *chunked->Flatten(); + auto table = + Table::Make(arrow::schema(chunked->type()->fields()), std::move(columns)); + ARROW_CHECK_OK(table->ValidateFull()); + return table; + } + + template + void DoTest(const std::shared_ptr& input) const { + std::vector sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending), + SortKey(FieldRef("a", "b"), SortOrder::Descending)}; + + SortOptions options(sort_keys, NullPlacement::AtEnd); + AssertSortIndices(input, options, "[3, 6, 4, 0, 2, 1, 5]"); + options.null_placement = NullPlacement::AtStart; + AssertSortIndices(input, options, "[5, 2, 1, 3, 6, 0, 4]"); + } +}; + +TEST_F(TestNestedSortIndices, Array) { DoTest(GetArray()); } +TEST_F(TestNestedSortIndices, ChunkedArray) { DoTest(GetChunkedArray()); } +TEST_F(TestNestedSortIndices, RecordBatch) { DoTest(GetRecordBatch()); } +TEST_F(TestNestedSortIndices, Table) { DoTest(GetTable()); } + // ---------------------------------------------------------------------- // Tests for Rank From 29ebb4096db22ed0e050975f6675afb1248513d5 Mon Sep 17 00:00:00 2001 From: benibus Date: Tue, 30 May 2023 16:33:28 -0400 Subject: [PATCH 2/8] Tweak tests --- .../arrow/compute/kernels/vector_sort_test.cc | 72 +++++++++++-------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index c79409c4683..08d0a5581ce 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -515,30 +515,28 @@ TEST(ArraySortIndicesFunction, ChunkedArray) { // ---------------------------------------------------------------------- // Tests for SortToIndices -template -void AssertSortIndices(const std::shared_ptr& input, SortOrder order, - NullPlacement null_placement, +void AssertSortIndices(const Datum& datum, const SortOptions& options, const std::shared_ptr& expected) { - ArraySortOptions options(order, null_placement); - ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options)); + ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(datum, options)); ValidateOutput(*actual); AssertArraysEqual(*expected, *actual, /*verbose=*/true); } +void AssertSortIndices(const Datum& datum, const SortOptions& options, + const std::string& expected) { + AssertSortIndices(datum, options, ArrayFromJSON(uint64(), expected)); +} + template -void AssertSortIndices(const std::shared_ptr& input, const SortOptions& options, +void AssertSortIndices(const std::shared_ptr& input, SortOrder order, + NullPlacement null_placement, const std::shared_ptr& expected) { - ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options)); + ArraySortOptions options(order, null_placement); + ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options)); ValidateOutput(*actual); AssertArraysEqual(*expected, *actual, /*verbose=*/true); } -template -void AssertSortIndices(const std::shared_ptr& input, const SortOptions& options, - const std::string& expected) { - AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected)); -} - template void AssertSortIndices(const std::shared_ptr& input, SortOrder order, NullPlacement null_placement, const std::string& expected) { @@ -2118,24 +2116,29 @@ INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom, class TestNestedSortIndices : public ::testing::Test { protected: static std::shared_ptr GetArray() { - auto child_type = struct_({field("a", uint8()), field("b", uint32())}); - auto child_array = ArrayFromJSON(child_type, - R"([{"a": 5, "b": null}, - {"a": null, "b": 7 }, - {"a": null, "b": 9 }, - {"a": 2, "b": 4 }, - {"a": 5, "b": 1 }, - {"a": 3, "b": null}, - {"a": 2, "b": 3 } - ])"); + auto struct_type = + struct_({field("a", struct_({field("a", uint8()), field("b", uint32())})), + field("b", int32())}); + auto struct_array = checked_pointer_cast( + ArrayFromJSON(struct_type, + R"([{"a": {"a": 5, "b": null}, "b": 8 }, + {"a": {"a": null, "b": 7 }, "b": 3 }, + {"a": {"a": null, "b": 9 }, "b": 3 }, + {"a": {"a": 2, "b": 4 }, "b": 6 }, + {"a": {"a": 5, "b": 1 }, "b": null}, + {"a": {"a": 3, "b": null}, "b": 2 }, + {"a": {"a": 2, "b": 3 }, "b": 0 }, + {"a": {"a": 2, "b": 4 }, "b": 1 }, + {"a": {"a": null, "b": 7 }, "b": null}])")); // The top-level validity bitmap is created independently to test null inheritance for // child fields. std::shared_ptr parent_bitmap; - ARROW_CHECK_OK(GetBitmapFromVector({1, 1, 1, 1, 1, 0, 1}, &parent_bitmap)); + ARROW_CHECK_OK( + GetBitmapFromVector({1, 1, 1, 1, 1, 0, 1, 1, 1}, &parent_bitmap)); auto array = - *StructArray::Make({child_array}, {field("a", child_type)}, parent_bitmap); + *StructArray::Make(struct_array->fields(), struct_type->fields(), parent_bitmap); ARROW_CHECK_OK(array->ValidateFull()); return array; } @@ -2165,15 +2168,24 @@ class TestNestedSortIndices : public ::testing::Test { return table; } - template - void DoTest(const std::shared_ptr& input) const { + void DoTest(const Datum& datum) const { std::vector sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending), - SortKey(FieldRef("a", "b"), SortOrder::Descending)}; + SortKey(FieldRef("a", "b"), SortOrder::Descending), + SortKey(FieldRef("b"), SortOrder::Ascending)}; SortOptions options(sort_keys, NullPlacement::AtEnd); - AssertSortIndices(input, options, "[3, 6, 4, 0, 2, 1, 5]"); + AssertSortIndices(datum, options, "[7, 3, 6, 4, 0, 2, 1, 8, 5]"); options.null_placement = NullPlacement::AtStart; - AssertSortIndices(input, options, "[5, 2, 1, 3, 6, 0, 4]"); + AssertSortIndices(datum, options, "[5, 2, 8, 1, 7, 3, 6, 0, 4]"); + + // Sort keys referencing a struct array are invalid + options.sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Descending), + SortKey(FieldRef("a"), SortOrder::Ascending)}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr( + "Unsupported type for RecordBatch sorting: struct"), + SortIndices(datum, options)); } }; From f062177a32399a9ffa1028229c6780f40c51cf82 Mon Sep 17 00:00:00 2001 From: benibus Date: Tue, 30 May 2023 16:34:46 -0400 Subject: [PATCH 3/8] Add note regarding column flattening --- cpp/src/arrow/compute/kernels/vector_sort.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 0103bf7d3ad..714cda935a2 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -892,6 +892,10 @@ class SortIndicesMetaFunction : public MetaFunction { } // We avoid using `Table::FromChunkedStructArray` here since it doesn't take top-level // validity into account for the columns. + // + // TODO: We could instead use the provided sort keys to only flatten the selected + // columns (via `GetFlattenedField`). Same for the Array -> RecordBatch conversion, + // since `RecordBatch::FromStructArray` flattens all columns as well. ARROW_ASSIGN_OR_RAISE(auto columns, chunked_array->Flatten()); return Table::Make(schema(chunked_array->type()->fields()), std::move(columns), chunked_array->length()); From 8ec36934f793c8f7607006c707e40a8db6724154 Mon Sep 17 00:00:00 2001 From: benibus Date: Wed, 7 Jun 2023 10:20:09 -0400 Subject: [PATCH 4/8] Implement sorting full struct fields --- .../compute/kernels/vector_array_sort.cc | 23 ++- cpp/src/arrow/compute/kernels/vector_sort.cc | 147 ++++++++++++++---- .../compute/kernels/vector_sort_internal.h | 6 + .../arrow/compute/kernels/vector_sort_test.cc | 39 +++-- 4 files changed, 170 insertions(+), 45 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc index 1499554a960..b9cfd38a02a 100644 --- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc @@ -262,6 +262,19 @@ class ArrayCompareSorter { } }; +template <> +class ArrayCompareSorter { + public: + Result operator()(uint64_t* indices_begin, uint64_t* indices_end, + const Array& array, int64_t offset, + const ArraySortOptions& options, + ExecContext* ctx) { + const auto& struct_array = checked_cast(array); + return SortStructArray(ctx, indices_begin, indices_end, struct_array, options.order, + options.null_placement); + } +}; + template class ArrayCountSorter { using ArrayType = typename TypeTraits::ArrayType; @@ -497,7 +510,7 @@ template struct ArraySorter< Type, enable_if_t::value || is_base_binary_type::value || is_fixed_size_binary_type::value || - is_dictionary_type::value>> { + is_dictionary_type::value || is_struct_type::value>> { ArrayCompareSorter impl; }; @@ -606,6 +619,13 @@ void AddDictArraySortingKernels(VectorKernel base, VectorFunction* func) { DCHECK_OK(func->AddKernel(base)); } +template