From 339ba4ddc5e0a71b36bb22bda1e7d1313745d01c Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Tue, 9 Apr 2024 14:57:43 +0800 Subject: [PATCH 1/6] support flatten for combine nested list array --- cpp/src/arrow/array/array_list_test.cc | 67 +++++++++++++++++++++++++- cpp/src/arrow/array/array_nested.cc | 55 +++++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index b08fa991686..ad6fda1870d 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -735,7 +735,7 @@ class TestListArray : public ::testing::Test { ArrayFromJSON(type, "[[1, 2], [3], [4], null, [5], [], [6]]")); auto sliced_list_array = std::dynamic_pointer_cast(list_array->Slice(3, 4)); - ASSERT_OK_AND_ASSIGN(auto flattened, list_array->Flatten()); + ASSERT_OK_AND_ASSIGN(auto flattened, sliced_list_array->Flatten()); ASSERT_OK(flattened->ValidateFull()); // Note the difference between values() and Flatten(). EXPECT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[5, 6]"))); @@ -763,6 +763,50 @@ class TestListArray : public ::testing::Test { << flattened->ToString(); } + void TestFlattenNested() { + auto inner_type = std::make_shared(int32()); + auto type = std::make_shared(inner_type); + + // List type with two nested level: list(list(int32)) + auto nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([ + [[0, 1, 2], null, [3]], + [null], + [[2, 9], [4], [], [6, 5]] + ])")); + ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->Flatten()); + ASSERT_OK(flattened->ValidateFull()); + ASSERT_EQ(9, flattened->length()); + ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); + + // Empty nested list should flatten until reach it's non-list type + nested_list_array = + std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten()); + ASSERT_TRUE(flattened->type()->Equals(int32())); + + // List type with three nested level: list(list(list(int32))) + type = std::make_shared(std::make_shared(std::make_shared(int32()))); + nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([ + [ + [[0],[null]], + [[2,3], null] + ], + [ + [[null], [5]], + [[8]], + null + ], + [ + null + ] + ])")); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten()); + ASSERT_OK(flattened->ValidateFull()); + ASSERT_EQ(7, flattened->length()); + ASSERT_EQ(2, flattened->null_count()); + ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, null, 2, 3, null, 5, 8]"))); + } + Status ValidateOffsetsAndSizes(int64_t length, std::vector offsets, std::vector sizes, std::shared_ptr values, int64_t offset = 0) { @@ -925,10 +969,12 @@ TYPED_TEST(TestListArray, BuilderPreserveFieldName) { TYPED_TEST(TestListArray, FlattenSimple) { this->TestFlattenSimple(); } TYPED_TEST(TestListArray, FlattenNulls) { this->TestFlattenNulls(); } TYPED_TEST(TestListArray, FlattenAllEmpty) { this->TestFlattenAllEmpty(); } +TYPED_TEST(TestListArray, FlattenSliced) { this->TestFlattenSliced(); } TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); } TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) { this->TestFlattenNonEmptyBackingNulls(); } +TYPED_TEST(TestListArray, FlattenNested) { this->TestFlattenNested(); } TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); } @@ -1714,4 +1760,23 @@ TEST_F(TestFixedSizeListArray, Flatten) { } } +TEST_F(TestFixedSizeListArray, FlattenNested) { + // Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2) + auto inner_type = fixed_size_list(value_type_, 2); + type_ = fixed_size_list(inner_type, 2); + + auto values = std::dynamic_pointer_cast(ArrayFromJSON(type_, R"([ + [[0, 1], [null, 3]], + [[7, null], [2, 5]], + [null, null] + ])")); + ASSERT_OK(values->ValidateFull()); + ASSERT_OK_AND_ASSIGN(auto flattened, values->Flatten()); + ASSERT_OK(flattened->ValidateFull()); + ASSERT_EQ(8, flattened->length()); + ASSERT_EQ(2, flattened->null_count()); + AssertArraysEqual(*flattened, + *ArrayFromJSON(value_type_, "[0, 1, null, 3, 7, null, 2, 5]")); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index 958c2e25380..ddb789b6fc3 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -222,6 +222,28 @@ Result> FlattenListArray(const ListArrayT& list_array, const int64_t list_array_length = list_array.length(); std::shared_ptr value_array = list_array.values(); + // If the list array is nested list-array like 'list(list(int32))', then just + // flatten recursively. + if (is_list_like(value_array->type_id())) { + auto flatten_nested_list = + [&](const std::shared_ptr& varr) -> Result> { + switch (varr->type_id()) { + case Type::LIST: + return FlattenListArray(checked_cast(*varr), memory_pool); + case Type::LARGE_LIST: + return FlattenListArray(checked_cast(*varr), + memory_pool); + case Type::FIXED_SIZE_LIST: + return FlattenListArray(checked_cast(*varr), + memory_pool); + default: + return Status::Invalid("Unknown or unsupported arrow nested type: ", + varr->type()->ToString()); + } + }; + return flatten_nested_list(value_array); + } + // Shortcut: if a ListArray does not contain nulls, then simply slice its // value array with the first and the last offsets. if (list_array.null_count() == 0) { @@ -271,6 +293,39 @@ Result> FlattenListViewArray(const ListViewArrayT& list_v const int64_t list_view_array_length = list_view_array.length(); std::shared_ptr value_array = list_view_array.values(); + // If it's a nested list-view, flatten recursively. + if (is_list_view(value_array->type()->id())) { + auto flatten_nested_list_view = + [&](const std::shared_ptr& varr) -> Result> { + const bool has_nulls = varr->null_count() > 0; + + switch (varr->type_id()) { + case Type::LIST_VIEW: { + if (has_nulls) { + return FlattenListViewArray( + checked_cast(*varr), memory_pool); + } else { + return FlattenListViewArray( + checked_cast(*varr), memory_pool); + } + } + case Type::LARGE_LIST_VIEW: { + if (has_nulls) { + return FlattenListViewArray( + checked_cast(*varr), memory_pool); + } else { + return FlattenListViewArray( + checked_cast(*varr), memory_pool); + } + } + default: + return Status::Invalid("Unknown or unsupported arrow nested type: ", + varr->type()->ToString()); + } + }; + return flatten_nested_list_view(value_array); + } + if (list_view_array_length == 0) { return SliceArrayWithOffsets(*value_array, 0, 0); } From 26b7a6f9f8500d74f8c017e569ee9ca4b9c4a0e4 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Tue, 9 Apr 2024 17:43:19 +0800 Subject: [PATCH 2/6] refactor underlying impl with a new interface and support option for function --- cpp/src/arrow/array/array_list_test.cc | 9 +- cpp/src/arrow/array/array_nested.cc | 132 +++++++++++++------------ cpp/src/arrow/array/array_nested.h | 20 +++- 3 files changed, 90 insertions(+), 71 deletions(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index ad6fda1870d..65832ac6d88 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -773,7 +773,8 @@ class TestListArray : public ::testing::Test { [null], [[2, 9], [4], [], [6, 5]] ])")); - ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->Flatten()); + ASSERT_OK_AND_ASSIGN(auto flattened, + nested_list_array->Flatten(/*with_recursion=*/true)); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(9, flattened->length()); ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); @@ -781,7 +782,7 @@ class TestListArray : public ::testing::Test { // Empty nested list should flatten until reach it's non-list type nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true)); ASSERT_TRUE(flattened->type()->Equals(int32())); // List type with three nested level: list(list(list(int32))) @@ -800,7 +801,7 @@ class TestListArray : public ::testing::Test { null ] ])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true)); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(7, flattened->length()); ASSERT_EQ(2, flattened->null_count()); @@ -1771,7 +1772,7 @@ TEST_F(TestFixedSizeListArray, FlattenNested) { [null, null] ])")); ASSERT_OK(values->ValidateFull()); - ASSERT_OK_AND_ASSIGN(auto flattened, values->Flatten()); + ASSERT_OK_AND_ASSIGN(auto flattened, values->Flatten(/*with_recursion=*/true)); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(8, flattened->length()); ASSERT_EQ(2, flattened->null_count()); diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index ddb789b6fc3..063c820c09d 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -216,32 +216,24 @@ static std::shared_ptr SliceArrayWithOffsets(const Array& array, int64_t return array.Slice(begin, end - begin); } +namespace { +struct FlattenWithRecursion { + // Flatten all list-like types array recursively + static Result> Flatten(const Array& array, bool with_recursion, + MemoryPool* memory_pool); +}; +} // namespace + template Result> FlattenListArray(const ListArrayT& list_array, + bool with_recursion, MemoryPool* memory_pool) { const int64_t list_array_length = list_array.length(); std::shared_ptr value_array = list_array.values(); - // If the list array is nested list-array like 'list(list(int32))', then just - // flatten recursively. - if (is_list_like(value_array->type_id())) { - auto flatten_nested_list = - [&](const std::shared_ptr& varr) -> Result> { - switch (varr->type_id()) { - case Type::LIST: - return FlattenListArray(checked_cast(*varr), memory_pool); - case Type::LARGE_LIST: - return FlattenListArray(checked_cast(*varr), - memory_pool); - case Type::FIXED_SIZE_LIST: - return FlattenListArray(checked_cast(*varr), - memory_pool); - default: - return Status::Invalid("Unknown or unsupported arrow nested type: ", - varr->type()->ToString()); - } - }; - return flatten_nested_list(value_array); + // If it's a nested-list related array, flatten recursively. + if (is_list_like(value_array->type_id()) && with_recursion) { + return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); } // Shortcut: if a ListArray does not contain nulls, then simply slice its @@ -287,6 +279,7 @@ Result> FlattenListArray(const ListArrayT& list_array, template Result> FlattenListViewArray(const ListViewArrayT& list_view_array, + bool with_recursion, MemoryPool* memory_pool) { using offset_type = typename ListViewArrayT::offset_type; const int64_t list_view_array_offset = list_view_array.offset(); @@ -294,36 +287,8 @@ Result> FlattenListViewArray(const ListViewArrayT& list_v std::shared_ptr value_array = list_view_array.values(); // If it's a nested list-view, flatten recursively. - if (is_list_view(value_array->type()->id())) { - auto flatten_nested_list_view = - [&](const std::shared_ptr& varr) -> Result> { - const bool has_nulls = varr->null_count() > 0; - - switch (varr->type_id()) { - case Type::LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(*varr), memory_pool); - } else { - return FlattenListViewArray( - checked_cast(*varr), memory_pool); - } - } - case Type::LARGE_LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(*varr), memory_pool); - } else { - return FlattenListViewArray( - checked_cast(*varr), memory_pool); - } - } - default: - return Status::Invalid("Unknown or unsupported arrow nested type: ", - varr->type()->ToString()); - } - }; - return flatten_nested_list_view(value_array); + if (is_list_view(value_array->type()->id()) && with_recursion) { + return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); } if (list_view_array_length == 0) { @@ -406,6 +371,44 @@ Result> FlattenListViewArray(const ListViewArrayT& list_v return Concatenate(slices, memory_pool); } +Result> FlattenWithRecursion::Flatten(const Array& array, + bool with_recursion, + MemoryPool* memory_pool) { + const bool has_nulls = array.null_count() > 0; + switch (array.type_id()) { + case Type::LIST: + return FlattenListArray(checked_cast(array), with_recursion, + memory_pool); + case Type::LARGE_LIST: + return FlattenListArray(checked_cast(array), with_recursion, + memory_pool); + case Type::FIXED_SIZE_LIST: + return FlattenListArray(checked_cast(array), + with_recursion, memory_pool); + case Type::LIST_VIEW: { + if (has_nulls) { + return FlattenListViewArray( + checked_cast(array), with_recursion, memory_pool); + } else { + return FlattenListViewArray( + checked_cast(array), with_recursion, memory_pool); + } + } + case Type::LARGE_LIST_VIEW: { + if (has_nulls) { + return FlattenListViewArray( + checked_cast(array), with_recursion, memory_pool); + } else { + return FlattenListViewArray( + checked_cast(array), with_recursion, memory_pool); + } + } + default: + return Status::Invalid("Unknown or unsupported arrow nested type: ", + array.type()->ToString()); + } +} + std::shared_ptr BoxOffsets(const std::shared_ptr& boxed_type, const ArrayData& data) { const int64_t num_offsets = @@ -577,8 +580,9 @@ Result> ListArray::FromArrays( null_bitmap, null_count); } -Result> ListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, memory_pool); +Result> ListArray::Flatten(bool with_recursion, + MemoryPool* memory_pool) const { + return FlattenListArray(*this, with_recursion, memory_pool); } std::shared_ptr ListArray::offsets() const { return BoxOffsets(int32(), *data_); } @@ -636,8 +640,9 @@ Result> LargeListArray::FromArrays( null_bitmap, null_count); } -Result> LargeListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, memory_pool); +Result> LargeListArray::Flatten(bool with_recursion, + MemoryPool* memory_pool) const { + return FlattenListArray(*this, with_recursion, memory_pool); } std::shared_ptr LargeListArray::offsets() const { @@ -706,11 +711,12 @@ Result> LargeListViewArray::FromList( return std::make_shared(std::move(data)); } -Result> ListViewArray::Flatten(MemoryPool* memory_pool) const { +Result> ListViewArray::Flatten(bool with_recursion, + MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, memory_pool); + return FlattenListViewArray(*this, with_recursion, memory_pool); } - return FlattenListViewArray(*this, memory_pool); + return FlattenListViewArray(*this, with_recursion, memory_pool); } std::shared_ptr ListViewArray::offsets() const { @@ -767,11 +773,13 @@ Result> LargeListViewArray::FromArrays( } Result> LargeListViewArray::Flatten( - MemoryPool* memory_pool) const { + bool with_recursion, MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, memory_pool); + return FlattenListViewArray(*this, with_recursion, + memory_pool); } - return FlattenListViewArray(*this, memory_pool); + return FlattenListViewArray(*this, with_recursion, + memory_pool); } std::shared_ptr LargeListViewArray::offsets() const { @@ -988,8 +996,8 @@ Result> FixedSizeListArray::FromArrays( } Result> FixedSizeListArray::Flatten( - MemoryPool* memory_pool) const { - return FlattenListArray(*this, memory_pool); + bool with_recursion, MemoryPool* memory_pool) const { + return FlattenListArray(*this, with_recursion, memory_pool); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index 768a630e0af..0ebf92d9bfa 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -183,11 +183,13 @@ class ARROW_EXPORT ListArray : public BaseListArray { /// \brief Return an Array that is a concatenation of the lists in this array. /// + /// \param[in] with_recursion Flatten recursively until reach non-list type + /// /// Note that it's different from `values()` in that it takes into /// consideration of this array's offsets as well as null elements backed /// by non-empty lists (they are skipped, thus copying may be needed). Result> Flatten( - MemoryPool* memory_pool = default_memory_pool()) const; + bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list offsets as an Int32Array /// @@ -251,11 +253,13 @@ class ARROW_EXPORT LargeListArray : public BaseListArray { /// \brief Return an Array that is a concatenation of the lists in this array. /// + /// \param[in] with_recursion Flatten recursively until reach non-list type + /// /// Note that it's different from `values()` in that it takes into /// consideration of this array's offsets as well as null elements backed /// by non-empty lists (they are skipped, thus copying may be needed). Result> Flatten( - MemoryPool* memory_pool = default_memory_pool()) const; + bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list offsets as an Int64Array std::shared_ptr offsets() const; @@ -353,6 +357,8 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { /// \brief Return an Array that is a concatenation of the list-views in this array. /// + /// \param[in] with_recursion Flatten recursively until reach non-list type + /// /// Note that it's different from `values()` in that it takes into /// consideration this array's offsets (which can be in any order) /// and sizes. Nulls are skipped. @@ -362,7 +368,7 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { /// maximizing the size of each slice (containing as many contiguous /// list-views as possible). Result> Flatten( - MemoryPool* memory_pool = default_memory_pool()) const; + bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list-view offsets as an Int32Array /// @@ -442,11 +448,13 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray> Flatten( - MemoryPool* memory_pool = default_memory_pool()) const; + bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list-view offsets as an Int64Array /// @@ -590,10 +598,12 @@ class ARROW_EXPORT FixedSizeListArray : public Array { /// \brief Return an Array that is a concatenation of the lists in this array. /// + /// \param[in] with_recursion Flatten recursively until reach non-list type + /// /// Note that it's different from `values()` in that it takes into /// consideration null elements (they are skipped, thus copying may be needed). Result> Flatten( - MemoryPool* memory_pool = default_memory_pool()) const; + bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Construct FixedSizeListArray from child value array and value_length /// From 3d5a2ece4b7e5d7d3ed7a460acc2084bb7336860 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Tue, 9 Apr 2024 22:09:43 +0800 Subject: [PATCH 3/6] add a new flatten api with a recursion argument --- cpp/src/arrow/array/array_list_test.cc | 15 +++--- cpp/src/arrow/array/array_nested.cc | 64 ++++++++++++++++++++------ cpp/src/arrow/array/array_nested.h | 45 ++++++++++++------ 3 files changed, 86 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index 65832ac6d88..5ef018c6997 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -763,7 +763,7 @@ class TestListArray : public ::testing::Test { << flattened->ToString(); } - void TestFlattenNested() { + void TestFlattenRecursion() { auto inner_type = std::make_shared(int32()); auto type = std::make_shared(inner_type); @@ -773,8 +773,7 @@ class TestListArray : public ::testing::Test { [null], [[2, 9], [4], [], [6, 5]] ])")); - ASSERT_OK_AND_ASSIGN(auto flattened, - nested_list_array->Flatten(/*with_recursion=*/true)); + ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursion()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(9, flattened->length()); ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); @@ -782,7 +781,7 @@ class TestListArray : public ::testing::Test { // Empty nested list should flatten until reach it's non-list type nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true)); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); ASSERT_TRUE(flattened->type()->Equals(int32())); // List type with three nested level: list(list(list(int32))) @@ -801,7 +800,7 @@ class TestListArray : public ::testing::Test { null ] ])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true)); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(7, flattened->length()); ASSERT_EQ(2, flattened->null_count()); @@ -975,7 +974,7 @@ TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); } TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) { this->TestFlattenNonEmptyBackingNulls(); } -TYPED_TEST(TestListArray, FlattenNested) { this->TestFlattenNested(); } +TYPED_TEST(TestListArray, FlattenRecursion) { this->TestFlattenRecursion(); } TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); } @@ -1761,7 +1760,7 @@ TEST_F(TestFixedSizeListArray, Flatten) { } } -TEST_F(TestFixedSizeListArray, FlattenNested) { +TEST_F(TestFixedSizeListArray, FlattenRecursion) { // Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2) auto inner_type = fixed_size_list(value_type_, 2); type_ = fixed_size_list(inner_type, 2); @@ -1772,7 +1771,7 @@ TEST_F(TestFixedSizeListArray, FlattenNested) { [null, null] ])")); ASSERT_OK(values->ValidateFull()); - ASSERT_OK_AND_ASSIGN(auto flattened, values->Flatten(/*with_recursion=*/true)); + ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursion()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(8, flattened->length()); ASSERT_EQ(2, flattened->null_count()); diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index 063c820c09d..f64e83ad499 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -580,9 +580,13 @@ Result> ListArray::FromArrays( null_bitmap, null_count); } -Result> ListArray::Flatten(bool with_recursion, - MemoryPool* memory_pool) const { - return FlattenListArray(*this, with_recursion, memory_pool); +Result> ListArray::Flatten(MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); +} + +Result> ListArray::FlattenRecursion( + MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); } std::shared_ptr ListArray::offsets() const { return BoxOffsets(int32(), *data_); } @@ -640,9 +644,13 @@ Result> LargeListArray::FromArrays( null_bitmap, null_count); } -Result> LargeListArray::Flatten(bool with_recursion, - MemoryPool* memory_pool) const { - return FlattenListArray(*this, with_recursion, memory_pool); +Result> LargeListArray::Flatten(MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); +} + +Result> LargeListArray::FlattenRecursion( + MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); } std::shared_ptr LargeListArray::offsets() const { @@ -711,12 +719,23 @@ Result> LargeListViewArray::FromList( return std::make_shared(std::move(data)); } -Result> ListViewArray::Flatten(bool with_recursion, - MemoryPool* memory_pool) const { +Result> ListViewArray::Flatten(MemoryPool* memory_pool) const { + if (null_count() > 0) { + return FlattenListViewArray(*this, /*with_recursion=*/false, + memory_pool); + } + return FlattenListViewArray(*this, /*with_recursion=*/false, + memory_pool); +} + +Result> ListViewArray::FlattenRecursion( + MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, with_recursion, memory_pool); + return FlattenListViewArray(*this, /*with_recursion=*/true, + memory_pool); } - return FlattenListViewArray(*this, with_recursion, memory_pool); + return FlattenListViewArray(*this, /*with_recursion=*/true, + memory_pool); } std::shared_ptr ListViewArray::offsets() const { @@ -773,12 +792,22 @@ Result> LargeListViewArray::FromArrays( } Result> LargeListViewArray::Flatten( - bool with_recursion, MemoryPool* memory_pool) const { + MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, with_recursion, + return FlattenListViewArray(*this, /*with_recursion=*/false, memory_pool); } - return FlattenListViewArray(*this, with_recursion, + return FlattenListViewArray(*this, /*with_recursion=*/false, + memory_pool); +} + +Result> LargeListViewArray::FlattenRecursion( + MemoryPool* memory_pool) const { + if (null_count() > 0) { + return FlattenListViewArray(*this, /*with_recursion=*/true, + memory_pool); + } + return FlattenListViewArray(*this, /*with_recursion=*/true, memory_pool); } @@ -996,8 +1025,13 @@ Result> FixedSizeListArray::FromArrays( } Result> FixedSizeListArray::Flatten( - bool with_recursion, MemoryPool* memory_pool) const { - return FlattenListArray(*this, with_recursion, memory_pool); + MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); +} + +Result> FixedSizeListArray::FlattenRecursion( + MemoryPool* memory_pool) const { + return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index 0ebf92d9bfa..94ac8ea68bc 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -183,13 +183,16 @@ class ARROW_EXPORT ListArray : public BaseListArray { /// \brief Return an Array that is a concatenation of the lists in this array. /// - /// \param[in] with_recursion Flatten recursively until reach non-list type - /// /// Note that it's different from `values()` in that it takes into /// consideration of this array's offsets as well as null elements backed /// by non-empty lists (they are skipped, thus copying may be needed). Result> Flatten( - bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursion( + MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list offsets as an Int32Array /// @@ -253,13 +256,16 @@ class ARROW_EXPORT LargeListArray : public BaseListArray { /// \brief Return an Array that is a concatenation of the lists in this array. /// - /// \param[in] with_recursion Flatten recursively until reach non-list type - /// /// Note that it's different from `values()` in that it takes into /// consideration of this array's offsets as well as null elements backed /// by non-empty lists (they are skipped, thus copying may be needed). Result> Flatten( - bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursion( + MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list offsets as an Int64Array std::shared_ptr offsets() const; @@ -357,8 +363,6 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { /// \brief Return an Array that is a concatenation of the list-views in this array. /// - /// \param[in] with_recursion Flatten recursively until reach non-list type - /// /// Note that it's different from `values()` in that it takes into /// consideration this array's offsets (which can be in any order) /// and sizes. Nulls are skipped. @@ -368,7 +372,12 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { /// maximizing the size of each slice (containing as many contiguous /// list-views as possible). Result> Flatten( - bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursion( + MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list-view offsets as an Int32Array /// @@ -448,13 +457,16 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray> Flatten( - bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursion( + MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Return list-view offsets as an Int64Array /// @@ -598,12 +610,15 @@ class ARROW_EXPORT FixedSizeListArray : public Array { /// \brief Return an Array that is a concatenation of the lists in this array. /// - /// \param[in] with_recursion Flatten recursively until reach non-list type - /// /// Note that it's different from `values()` in that it takes into /// consideration null elements (they are skipped, thus copying may be needed). Result> Flatten( - bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursion( + MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Construct FixedSizeListArray from child value array and value_length /// From 2450c60f378a3ba8430365d465a0ccb6f472ecc5 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Wed, 10 Apr 2024 22:35:04 +0800 Subject: [PATCH 4/6] Refactor codes --- cpp/src/arrow/array/array_list_test.cc | 14 +- cpp/src/arrow/array/array_nested.cc | 173 ++++++++++--------------- cpp/src/arrow/array/array_nested.h | 33 ++--- 3 files changed, 91 insertions(+), 129 deletions(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index 5ef018c6997..b4953416af4 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -763,7 +763,7 @@ class TestListArray : public ::testing::Test { << flattened->ToString(); } - void TestFlattenRecursion() { + void TestFlattenRecursively() { auto inner_type = std::make_shared(int32()); auto type = std::make_shared(inner_type); @@ -773,7 +773,7 @@ class TestListArray : public ::testing::Test { [null], [[2, 9], [4], [], [6, 5]] ])")); - ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(9, flattened->length()); ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); @@ -781,7 +781,7 @@ class TestListArray : public ::testing::Test { // Empty nested list should flatten until reach it's non-list type nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_TRUE(flattened->type()->Equals(int32())); // List type with three nested level: list(list(list(int32))) @@ -800,7 +800,7 @@ class TestListArray : public ::testing::Test { null ] ])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(7, flattened->length()); ASSERT_EQ(2, flattened->null_count()); @@ -974,7 +974,7 @@ TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); } TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) { this->TestFlattenNonEmptyBackingNulls(); } -TYPED_TEST(TestListArray, FlattenRecursion) { this->TestFlattenRecursion(); } +TYPED_TEST(TestListArray, FlattenRecursively) { this->TestFlattenRecursively(); } TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); } @@ -1760,7 +1760,7 @@ TEST_F(TestFixedSizeListArray, Flatten) { } } -TEST_F(TestFixedSizeListArray, FlattenRecursion) { +TEST_F(TestFixedSizeListArray, FlattenRecursively) { // Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2) auto inner_type = fixed_size_list(value_type_, 2); type_ = fixed_size_list(inner_type, 2); @@ -1771,7 +1771,7 @@ TEST_F(TestFixedSizeListArray, FlattenRecursion) { [null, null] ])")); ASSERT_OK(values->ValidateFull()); - ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(8, flattened->length()); ASSERT_EQ(2, flattened->null_count()); diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index f64e83ad499..71b196b7155 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -216,26 +216,12 @@ static std::shared_ptr SliceArrayWithOffsets(const Array& array, int64_t return array.Slice(begin, end - begin); } -namespace { -struct FlattenWithRecursion { - // Flatten all list-like types array recursively - static Result> Flatten(const Array& array, bool with_recursion, - MemoryPool* memory_pool); -}; -} // namespace - template Result> FlattenListArray(const ListArrayT& list_array, - bool with_recursion, MemoryPool* memory_pool) { const int64_t list_array_length = list_array.length(); std::shared_ptr value_array = list_array.values(); - // If it's a nested-list related array, flatten recursively. - if (is_list_like(value_array->type_id()) && with_recursion) { - return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); - } - // Shortcut: if a ListArray does not contain nulls, then simply slice its // value array with the first and the last offsets. if (list_array.null_count() == 0) { @@ -279,18 +265,12 @@ Result> FlattenListArray(const ListArrayT& list_array, template Result> FlattenListViewArray(const ListViewArrayT& list_view_array, - bool with_recursion, MemoryPool* memory_pool) { using offset_type = typename ListViewArrayT::offset_type; const int64_t list_view_array_offset = list_view_array.offset(); const int64_t list_view_array_length = list_view_array.length(); std::shared_ptr value_array = list_view_array.values(); - // If it's a nested list-view, flatten recursively. - if (is_list_view(value_array->type()->id()) && with_recursion) { - return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); - } - if (list_view_array_length == 0) { return SliceArrayWithOffsets(*value_array, 0, 0); } @@ -371,44 +351,6 @@ Result> FlattenListViewArray(const ListViewArrayT& list_v return Concatenate(slices, memory_pool); } -Result> FlattenWithRecursion::Flatten(const Array& array, - bool with_recursion, - MemoryPool* memory_pool) { - const bool has_nulls = array.null_count() > 0; - switch (array.type_id()) { - case Type::LIST: - return FlattenListArray(checked_cast(array), with_recursion, - memory_pool); - case Type::LARGE_LIST: - return FlattenListArray(checked_cast(array), with_recursion, - memory_pool); - case Type::FIXED_SIZE_LIST: - return FlattenListArray(checked_cast(array), - with_recursion, memory_pool); - case Type::LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } else { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } - } - case Type::LARGE_LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } else { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } - } - default: - return Status::Invalid("Unknown or unsupported arrow nested type: ", - array.type()->ToString()); - } -} - std::shared_ptr BoxOffsets(const std::shared_ptr& boxed_type, const ArrayData& data) { const int64_t num_offsets = @@ -527,6 +469,69 @@ inline void SetListData(VarLengthListLikeArray* self, self->values_ = MakeArray(self->data_->child_data[0]); } +Result> FlattenLogicalListRecursively(const Array& array, + MemoryPool* memory_pool) { + Type::type kind = array.type_id(); + std::shared_ptr in_array = array.Slice(0, array.length()); + while (is_list_like(kind) || is_list_view(kind)) { + const bool has_nulls = array.null_count() > 0; + std::shared_ptr out; + switch (kind) { + case Type::LIST: { + ARROW_ASSIGN_OR_RAISE( + out, + FlattenListArray(checked_cast(*in_array), memory_pool)); + break; + } + case Type::LARGE_LIST: { + ARROW_ASSIGN_OR_RAISE( + out, FlattenListArray(checked_cast(*in_array), + memory_pool)); + break; + } + case Type::FIXED_SIZE_LIST: { + ARROW_ASSIGN_OR_RAISE( + out, FlattenListArray(checked_cast(*in_array), + memory_pool)); + break; + } + case Type::LIST_VIEW: { + if (has_nulls) { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } else { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } + } + case Type::LARGE_LIST_VIEW: { + if (has_nulls) { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } else { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } + } + default: + return Status::Invalid("Unknown or unsupported arrow nested type: ", + in_array->type()->ToString()); + } + + in_array = out; + kind = in_array->type_id(); + } + return std::move(in_array); +} + } // namespace internal // ---------------------------------------------------------------------- @@ -581,12 +586,7 @@ Result> ListArray::FromArrays( } Result> ListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); -} - -Result> ListArray::FlattenRecursion( - MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return FlattenListArray(*this, memory_pool); } std::shared_ptr ListArray::offsets() const { return BoxOffsets(int32(), *data_); } @@ -645,12 +645,7 @@ Result> LargeListArray::FromArrays( } Result> LargeListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); -} - -Result> LargeListArray::FlattenRecursion( - MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return FlattenListArray(*this, memory_pool); } std::shared_ptr LargeListArray::offsets() const { @@ -721,21 +716,9 @@ Result> LargeListViewArray::FromList( Result> ListViewArray::Flatten(MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); - } - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); -} - -Result> ListViewArray::FlattenRecursion( - MemoryPool* memory_pool) const { - if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } std::shared_ptr ListViewArray::offsets() const { @@ -794,21 +777,9 @@ Result> LargeListViewArray::FromArrays( Result> LargeListViewArray::Flatten( MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); - } - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); -} - -Result> LargeListViewArray::FlattenRecursion( - MemoryPool* memory_pool) const { - if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } std::shared_ptr LargeListViewArray::offsets() const { @@ -1026,12 +997,12 @@ Result> FixedSizeListArray::FromArrays( Result> FixedSizeListArray::Flatten( MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); + return FlattenListArray(*this, memory_pool); } -Result> FixedSizeListArray::FlattenRecursion( +Result> FixedSizeListArray::FlattenRecursively( MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return internal::FlattenLogicalListRecursively(*this, memory_pool); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index 94ac8ea68bc..e39ce77d0df 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -58,6 +58,10 @@ void SetListData(VarLengthListLikeArray* self, const std::shared_ptr& data, Type::type expected_type_id = TYPE::type_id); +// Private flatten helper for logical lists: [Large]List[View]Array, FixedSizeListArray +// and MapArray +ARROW_EXPORT Result> FlattenLogicalListRecursively( + const Array& array, MemoryPool* memory_pool); } // namespace internal /// Base class for variable-sized list and list-view arrays, regardless of offset size. @@ -103,6 +107,13 @@ class VarLengthListLikeArray : public Array { return values_->Slice(value_offset(i), value_length(i)); } + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursively( + MemoryPool* memory_pool = default_memory_pool()) const { + return internal::FlattenLogicalListRecursively(*this, memory_pool); + } + protected: friend void internal::SetListData(VarLengthListLikeArray* self, const std::shared_ptr& data, @@ -189,11 +200,6 @@ class ARROW_EXPORT ListArray : public BaseListArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list offsets as an Int32Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -262,11 +268,6 @@ class ARROW_EXPORT LargeListArray : public BaseListArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list offsets as an Int64Array std::shared_ptr offsets() const; @@ -374,11 +375,6 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list-view offsets as an Int32Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -463,11 +459,6 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list-view offsets as an Int64Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -617,7 +608,7 @@ class ARROW_EXPORT FixedSizeListArray : public Array { /// \brief Flatten all level recursively until reach a non-list type, and return a /// non-list type Array. - Result> FlattenRecursion( + Result> FlattenRecursively( MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Construct FixedSizeListArray from child value array and value_length From 11ebc54ee9dc0142334d9d7a9b686493b66a7684 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Thu, 11 Apr 2024 11:44:19 +0800 Subject: [PATCH 5/6] mainly refactor FlattenLogicalListRecursively function --- cpp/src/arrow/array/array_nested.cc | 72 ++++++++++------------------- cpp/src/arrow/array/array_nested.h | 32 +++++++++---- 2 files changed, 48 insertions(+), 56 deletions(-) diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index 71b196b7155..24e0dfb7081 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -42,6 +42,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/list_util.h" #include "arrow/util/logging.h" +#include "arrow/util/unreachable.h" namespace arrow { @@ -469,67 +470,47 @@ inline void SetListData(VarLengthListLikeArray* self, self->values_ = MakeArray(self->data_->child_data[0]); } -Result> FlattenLogicalListRecursively(const Array& array, +Result> FlattenLogicalListRecursively(const Array& in_array, MemoryPool* memory_pool) { - Type::type kind = array.type_id(); - std::shared_ptr in_array = array.Slice(0, array.length()); - while (is_list_like(kind) || is_list_view(kind)) { - const bool has_nulls = array.null_count() > 0; - std::shared_ptr out; + std::shared_ptr array = in_array.Slice(0, in_array.length()); + for (auto kind = array->type_id(); is_list(kind) || is_list_view(kind); + kind = array->type_id()) { switch (kind) { case Type::LIST: { ARROW_ASSIGN_OR_RAISE( - out, - FlattenListArray(checked_cast(*in_array), memory_pool)); + array, (checked_cast(array.get())->Flatten(memory_pool))); break; } case Type::LARGE_LIST: { ARROW_ASSIGN_OR_RAISE( - out, FlattenListArray(checked_cast(*in_array), - memory_pool)); + array, + (checked_cast(array.get())->Flatten(memory_pool))); break; } - case Type::FIXED_SIZE_LIST: { + case Type::LIST_VIEW: { ARROW_ASSIGN_OR_RAISE( - out, FlattenListArray(checked_cast(*in_array), - memory_pool)); + array, + (checked_cast(array.get())->Flatten(memory_pool))); break; } - case Type::LIST_VIEW: { - if (has_nulls) { - ARROW_ASSIGN_OR_RAISE( - out, (FlattenListViewArray( - checked_cast(*in_array), memory_pool))); - break; - } else { - ARROW_ASSIGN_OR_RAISE( - out, (FlattenListViewArray( - checked_cast(*in_array), memory_pool))); - break; - } - } case Type::LARGE_LIST_VIEW: { - if (has_nulls) { - ARROW_ASSIGN_OR_RAISE( - out, (FlattenListViewArray( - checked_cast(*in_array), memory_pool))); - break; - } else { - ARROW_ASSIGN_OR_RAISE( - out, (FlattenListViewArray( - checked_cast(*in_array), memory_pool))); - break; - } + ARROW_ASSIGN_OR_RAISE( + array, + (checked_cast(array.get())->Flatten(memory_pool))); + break; + } + case Type::FIXED_SIZE_LIST: { + ARROW_ASSIGN_OR_RAISE( + array, + (checked_cast(array.get())->Flatten(memory_pool))); + break; } default: - return Status::Invalid("Unknown or unsupported arrow nested type: ", - in_array->type()->ToString()); + Unreachable("unexpected non-list type"); + break; } - - in_array = out; - kind = in_array->type_id(); } - return std::move(in_array); + return array; } } // namespace internal @@ -1000,11 +981,6 @@ Result> FixedSizeListArray::Flatten( return FlattenListArray(*this, memory_pool); } -Result> FixedSizeListArray::FlattenRecursively( - MemoryPool* memory_pool) const { - return internal::FlattenLogicalListRecursively(*this, memory_pool); -} - // ---------------------------------------------------------------------- // Struct diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index e39ce77d0df..5744f5fcadf 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -58,10 +58,20 @@ void SetListData(VarLengthListLikeArray* self, const std::shared_ptr& data, Type::type expected_type_id = TYPE::type_id); -// Private flatten helper for logical lists: [Large]List[View]Array, FixedSizeListArray -// and MapArray +/// \brief A version of Flatten that keeps recursively flattening until an array of +/// non-list values is reached. +/// +/// Array types considered to be lists by this function: +/// - list +/// - large_list +/// - list_view +/// - large_list_view +/// - fixed_size_list +/// +/// \see ListArray::Flatten ARROW_EXPORT Result> FlattenLogicalListRecursively( - const Array& array, MemoryPool* memory_pool); + const Array& in_array, MemoryPool* memory_pool); + } // namespace internal /// Base class for variable-sized list and list-view arrays, regardless of offset size. @@ -107,8 +117,10 @@ class VarLengthListLikeArray : public Array { return values_->Slice(value_offset(i), value_length(i)); } - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. + /// \brief Flatten all level recursively until reach a non-list type, and return + /// a non-list type Array. + /// + /// \see internal::FlattenLogicalListRecursively Result> FlattenRecursively( MemoryPool* memory_pool = default_memory_pool()) const { return internal::FlattenLogicalListRecursively(*this, memory_pool); @@ -606,10 +618,14 @@ class ARROW_EXPORT FixedSizeListArray : public Array { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. + /// \brief Flatten all level recursively until reach a non-list type, and return + /// a non-list type Array. + /// + /// \see internal::FlattenLogicalListRecursively Result> FlattenRecursively( - MemoryPool* memory_pool = default_memory_pool()) const; + MemoryPool* memory_pool = default_memory_pool()) const { + return internal::FlattenLogicalListRecursively(*this, memory_pool); + } /// \brief Construct FixedSizeListArray from child value array and value_length /// From 9f1067db93fc57603599e0d2ef11b9b683d4d1c6 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Fri, 12 Apr 2024 00:26:08 +0800 Subject: [PATCH 6/6] fix ut --- cpp/src/arrow/array/array_list_test.cc | 36 ++++++++++++++------------ 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index b4953416af4..18afcc90d71 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -767,33 +767,34 @@ class TestListArray : public ::testing::Test { auto inner_type = std::make_shared(int32()); auto type = std::make_shared(inner_type); - // List type with two nested level: list(list(int32)) + // List types with two nested level: list> auto nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([ - [[0, 1, 2], null, [3]], - [null], - [[2, 9], [4], [], [6, 5]] - ])")); + [[0, 1, 2], null, [3, null]], + [null], + [[2, 9], [4], [], [6, 5]] + ])")); ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); - ASSERT_EQ(9, flattened->length()); - ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); + ASSERT_EQ(10, flattened->length()); + ASSERT_TRUE( + flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, null, 2, 9, 4, 6, 5]"))); - // Empty nested list should flatten until reach it's non-list type + // Empty nested list should flatten until non-list type is reached nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_TRUE(flattened->type()->Equals(int32())); - // List type with three nested level: list(list(list(int32))) - type = std::make_shared(std::make_shared(std::make_shared(int32()))); + // List types with three nested level: list>> + type = std::make_shared(std::make_shared(fixed_size_list(int32(), 2))); nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([ [ - [[0],[null]], - [[2,3], null] + [[null, 0]], + [[3, 7], null] ], [ - [[null], [5]], - [[8]], + [[4, null], [5, 8]], + [[8, null]], null ], [ @@ -802,9 +803,10 @@ class TestListArray : public ::testing::Test { ])")); ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); - ASSERT_EQ(7, flattened->length()); - ASSERT_EQ(2, flattened->null_count()); - ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, null, 2, 3, null, 5, 8]"))); + ASSERT_EQ(10, flattened->length()); + ASSERT_EQ(3, flattened->null_count()); + ASSERT_TRUE(flattened->Equals( + ArrayFromJSON(int32(), "[null, 0, 3, 7, 4, null, 5, 8, 8, null]"))); } Status ValidateOffsetsAndSizes(int64_t length, std::vector offsets,