From ae9e603635ceda06f64cdb38e14c791236d15cb0 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Mon, 7 Nov 2022 13:57:17 +0100 Subject: [PATCH 1/5] [skip ci] Support FieldRef to work with ListElement --- .../arrow/compute/kernels/scalar_nested.cc | 21 +++++++++++++---- cpp/src/arrow/type.cc | 23 +++++++++++++++++-- python/pyarrow/tests/test_compute.py | 16 +++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc index 77fd67b3c77..c92b009c3fb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc @@ -447,6 +447,14 @@ struct StructFieldFunctor { union_array.GetFlattenedField(index, ctx->memory_pool())); break; } + case Type::LIST: { + const auto& list_array = checked_cast(*current); + Datum idx(index); + ARROW_ASSIGN_OR_RAISE(Datum result, + CallFunction("list_element", {list_array, idx})); + current = result.make_array(); + break; + } default: // Should have been checked in ResolveStructFieldType return Status::TypeError("struct_field: cannot reference child field of type ", @@ -460,7 +468,7 @@ struct StructFieldFunctor { static Status CheckIndex(int index, const DataType& type) { if (!ValidParentType(type)) { return Status::TypeError("struct_field: cannot subscript field of type ", type); - } else if (index < 0 || index >= type.num_fields()) { + } else if (type.id() != Type::LIST && (index < 0 || index >= type.num_fields())) { return Status::Invalid("struct_field: out-of-bounds field reference to field ", index, " in type ", type, " with ", type.num_fields(), " fields"); @@ -470,7 +478,7 @@ struct StructFieldFunctor { static bool ValidParentType(const DataType& type) { return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION || - type.id() == Type::SPARSE_UNION; + type.id() == Type::SPARSE_UNION || type.id() == Type::LIST; } }; @@ -487,8 +495,13 @@ Result ResolveStructFieldType(KernelContext* ctx, } for (const auto& index : field_path.indices()) { - RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); - type = type->field(index)->type().get(); + if (type->id() == Type::LIST) { + auto list_type = checked_cast(type); + type = list_type->value_type().get(); + } else { + RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); + type = type->field(index)->type().get(); + } } return type; } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ea9525404c8..d2943a785d1 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1067,12 +1067,32 @@ struct FieldPathGetImpl { } int depth = 0; - const T* out; + const T* out = nullptr; for (int index : path->indices()) { if (children == nullptr) { return Status::NotImplemented("Get child data of non-struct array"); } + if constexpr (std::is_same_v>) { + // For lists, we don't care about the index, jump right into the list value type. + // The index here is used in the kernel to grab the specific list element. + // Children then are fields from list type, and out will be the children from + // that field, thus index 0. + if (out != nullptr && (*out)->type()->id() == Type::LIST) { + children = get_children(*out); + index = 0; + } else if (out == nullptr && children->size() == 1 && index >= 0) { + // WIP: Perhaps a dangerous assumption? + // For list types when previous was not a FieldPath, but + // a string name referencing a list type. Then the first iteration + // with index (referencing the list item), might look like it's + // out of bounds when actually we don't care about this index for + // type checking, only the kernel does; we need the list value type + // which is now the only child in the children vector. + index = 0; + } + } + if (index < 0 || static_cast(index) >= children->size()) { *out_of_range_depth = depth; return nullptr; @@ -1082,7 +1102,6 @@ struct FieldPathGetImpl { children = get_children(*out); ++depth; } - return *out; } diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 68b3303fe78..f74458eaed6 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2721,6 +2721,22 @@ def test_struct_fields_options(): # assert pc.struct_field(arr) == arr +@pytest.mark.parametrize("path,expected", ( + ([0, 0, 0], [1]), + ('.a[0].b', [1]), + ([0, 0, 1], [None]), + ('.a[0].c', [None]), + ([0, 1, 0], [None]), + ('.a[1].b', [None]), + ([0, 1, 1], ["hi"]), + ('.a[1].c', ["hi"]) +)) +def test_struct_field_list_path(path, expected): + arr = pa.array([{'a': [{'b': 1}, {'c': 'hi'}]}]) + out = pc.struct_field(arr, path) + assert out == pa.array(expected).cast(out.type) + + def test_case_when(): assert pc.case_when(pc.make_struct([True, False, None], [False, True, None]), From b06f715e8dba4a34c0063baaf5800757c8fff498 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 23 Nov 2022 12:34:24 +0100 Subject: [PATCH 2/5] [skip ci] Add C++ test for FieldRef/FieldPath --- cpp/src/arrow/type_test.cc | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 954ad63c8aa..d4b9cd23610 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -434,6 +434,43 @@ TEST(TestFieldRef, DotPathRoundTrip) { check_roundtrip(FieldRef("foo", 1, FieldRef("bar", 2, 3), FieldRef())); } +TEST(TestFieldRef, NestedWithList) { + // Single list type + auto type = + struct_({field("c", list(struct_({field("d", int32()), field("e", int8())})))}); + // Note numeric values here go unused, outside of indicating a further step, during + // FindAll, only used by kernels for which list element to take. When FindAll encounters + // a list, it assumes the list value type. + // Not b/c they are numeric, but they are the position referencing a specific list + // element. + EXPECT_THAT(FieldRef("c", 0, "d").FindAll(*type), ElementsAre(FieldPath{0, 0, 0})); + EXPECT_THAT(FieldRef("c", 0, "e").FindAll(*type), ElementsAre(FieldPath{0, 0, 1})); + EXPECT_THAT(FieldRef("c", 1, "d").FindAll(*type), ElementsAre(FieldPath{0, 1, 0})); + EXPECT_THAT(FieldRef("c", 1, "e").FindAll(*type), ElementsAre(FieldPath{0, 1, 1})); + + // Double list and nested + type = struct_({field("a", list(type)), field("b", list(type))}); + EXPECT_THAT(FieldRef("a", 0, "c", 0, "d").FindAll(*type), + ElementsAre(FieldPath{0, 0, 0, 0, 0})); + EXPECT_THAT(FieldRef("b", 1, "c", 1, "e").FindAll(*type), + ElementsAre(FieldPath{1, 1, 0, 1, 1})); + + // Again, noting the 1 and 3 indexes refer to the specific list element + // and are not used in getting the type, only that its presence indicates + // further drilling into the list value type is needed. + // The values are used however in the kernel implementations for selecting + // from the specific list element. + ASSERT_OK_AND_ASSIGN(auto field, FieldPath({0, 0, 0, 0, 0}).Get(*type)) + ASSERT_EQ(field->type(), int32()); + ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 1, 0}).Get(*type)); + ASSERT_EQ(field->type(), int32()); + + ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 0, 0, 0, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); + ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 0, 1, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); +} + TEST(TestFieldPath, Nested) { auto f0 = field("alpha", int32()); auto f1_0 = field("alpha", int32()); From 5a082ae4e6453964df94007de6be15e28385bd1f Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 30 Nov 2022 09:48:33 +0100 Subject: [PATCH 3/5] [skip ci] Pass parent DataType when available to FieldPathGetImpl --- cpp/src/arrow/type.cc | 115 ++++++++++++++++++++++-------------------- cpp/src/arrow/type.h | 6 ++- 2 files changed, 63 insertions(+), 58 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index d2943a785d1..5583565d826 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1060,8 +1060,9 @@ struct FieldPathGetImpl { } template - static Result Get(const FieldPath* path, const std::vector* children, - GetChildren&& get_children, int* out_of_range_depth) { + static Result Get(const FieldPath* path, const DataType* parent, + const std::vector* children, GetChildren&& get_children, + int* out_of_range_depth) { if (path->indices().empty()) { return Status::Invalid("empty indices cannot be traversed"); } @@ -1081,14 +1082,10 @@ struct FieldPathGetImpl { if (out != nullptr && (*out)->type()->id() == Type::LIST) { children = get_children(*out); index = 0; - } else if (out == nullptr && children->size() == 1 && index >= 0) { - // WIP: Perhaps a dangerous assumption? - // For list types when previous was not a FieldPath, but - // a string name referencing a list type. Then the first iteration - // with index (referencing the list item), might look like it's - // out of bounds when actually we don't care about this index for - // type checking, only the kernel does; we need the list value type - // which is now the only child in the children vector. + } else if (out == nullptr && parent != nullptr && parent->id() == Type::LIST) { + // For the first iteration (out == nullptr), if the parent is a list type + // then we certainly want to get the type of the list item. The specific + // kernel implemention will use the actual index to grab a specific list item. index = 0; } } @@ -1106,29 +1103,30 @@ struct FieldPathGetImpl { } template - static Result Get(const FieldPath* path, const std::vector* children, - GetChildren&& get_children) { + static Result Get(const FieldPath* path, const DataType* parent, + const std::vector* children, GetChildren&& get_children) { int out_of_range_depth = -1; - ARROW_ASSIGN_OR_RAISE(auto child, - Get(path, children, std::forward(get_children), - &out_of_range_depth)); + ARROW_ASSIGN_OR_RAISE( + auto child, Get(path, parent, children, std::forward(get_children), + &out_of_range_depth)); if (child != nullptr) { return std::move(child); } return IndexError(path, out_of_range_depth, *children); } - static Result> Get(const FieldPath* path, + static Result> Get(const FieldPath* path, const DataType* parent, const FieldVector& fields) { - return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr& field) { - return &field->type()->fields(); - }); + return FieldPathGetImpl::Get( + path, parent, &fields, + [](const std::shared_ptr& field) { return &field->type()->fields(); }); } static Result> Get(const FieldPath* path, + const DataType* parent, const ArrayDataVector& child_data) { return FieldPathGetImpl::Get( - path, &child_data, + path, parent, &child_data, [](const std::shared_ptr& data) -> const ArrayDataVector* { if (data->type->id() != Type::STRUCT) { return nullptr; @@ -1139,19 +1137,20 @@ struct FieldPathGetImpl { }; Result> FieldPath::Get(const Schema& schema) const { - return FieldPathGetImpl::Get(this, schema.fields()); + return FieldPathGetImpl::Get(this, nullptr, schema.fields()); } Result> FieldPath::Get(const Field& field) const { - return FieldPathGetImpl::Get(this, field.type()->fields()); + return FieldPathGetImpl::Get(this, &*field.type(), field.type()->fields()); } Result> FieldPath::Get(const DataType& type) const { - return FieldPathGetImpl::Get(this, type.fields()); + return FieldPathGetImpl::Get(this, &type, type.fields()); } -Result> FieldPath::Get(const FieldVector& fields) const { - return FieldPathGetImpl::Get(this, fields); +Result> FieldPath::Get(const FieldVector& fields, + const DataType* parent) const { + return FieldPathGetImpl::Get(this, parent, fields); } Result> FieldPath::GetAll(const Schema& schm, @@ -1166,7 +1165,8 @@ Result> FieldPath::GetAll(const Schema& schm, } Result> FieldPath::Get(const RecordBatch& batch) const { - ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data())); + ARROW_ASSIGN_OR_RAISE(auto data, + FieldPathGetImpl::Get(this, nullptr, batch.column_data())); return MakeArray(std::move(data)); } @@ -1179,7 +1179,7 @@ Result> FieldPath::Get(const ArrayData& data) const { if (data.type->id() != Type::STRUCT) { return Status::NotImplemented("Get child data of non-struct array"); } - return FieldPathGetImpl::Get(this, data.child_data); + return FieldPathGetImpl::Get(this, &*data.type, data.child_data); } FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {} @@ -1395,29 +1395,29 @@ std::vector FieldRef::FindAll(const Schema& schema) const { return internal::MapVector([](int i) { return FieldPath{i}; }, schema.GetAllFieldIndices(*name)); } - return FindAll(schema.fields()); + return FindAll(schema.fields(), nullptr); } std::vector FieldRef::FindAll(const Field& field) const { - return FindAll(field.type()->fields()); + return FindAll(field.type()->fields(), &*field.type()); } std::vector FieldRef::FindAll(const DataType& type) const { - return FindAll(type.fields()); + return FindAll(type.fields(), &type); } -std::vector FieldRef::FindAll(const FieldVector& fields) const { +std::vector FieldRef::FindAll(const FieldVector& fields, + const DataType* parent) const { struct Visitor { std::vector operator()(const FieldPath& path) { // skip long IndexError construction if path is out of range int out_of_range_depth; auto maybe_field = FieldPathGetImpl::Get( - &path, &fields_, + &path, parent_, &fields_, [](const std::shared_ptr& field) { return &field->type()->fields(); }, &out_of_range_depth); DCHECK_OK(maybe_field.status()); - if (maybe_field.ValueOrDie() != nullptr) { return {path}; } @@ -1436,6 +1436,26 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { return out; } + std::vector operator()(const std::vector& refs) { + DCHECK_GE(refs.size(), 1); + Matches matches(refs.front().FindAll(fields_, parent_), fields_); + + for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { + Matches next_matches; + for (size_t i = 0; i < matches.size(); ++i) { + const auto& referent = *matches.referents[i]; + + for (const FieldPath& match : ref_it->FindAll(referent)) { + next_matches.Add(matches.prefixes[i], match, referent.type()->fields(), + &*referent.type()); + } + } + matches = std::move(next_matches); + } + + return matches.prefixes; + } + struct Matches { // referents[i] is referenced by prefixes[i] std::vector prefixes; @@ -1452,10 +1472,11 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { size_t size() const { return referents.size(); } void Add(const FieldPath& prefix, const FieldPath& suffix, - const FieldVector& fields) { - auto maybe_field = suffix.Get(fields); + const FieldVector& fields, const DataType* parent = NULLPTR) { + auto maybe_field = suffix.Get(fields, parent); DCHECK_OK(maybe_field.status()); - referents.push_back(std::move(maybe_field).ValueOrDie()); + + referents.push_back(std::move(maybe_field.ValueOrDie())); std::vector concatenated_indices(prefix.indices().size() + suffix.indices().size()); @@ -1467,29 +1488,11 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { } }; - std::vector operator()(const std::vector& refs) { - DCHECK_GE(refs.size(), 1); - Matches matches(refs.front().FindAll(fields_), fields_); - - for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { - Matches next_matches; - for (size_t i = 0; i < matches.size(); ++i) { - const auto& referent = *matches.referents[i]; - - for (const FieldPath& match : ref_it->FindAll(referent)) { - next_matches.Add(matches.prefixes[i], match, referent.type()->fields()); - } - } - matches = std::move(next_matches); - } - - return matches.prefixes; - } - + const DataType* parent_; const FieldVector& fields_; }; - return std::visit(Visitor{fields}, impl_); + return std::visit(Visitor{parent, fields}, impl_); } std::vector FieldRef::FindAll(const ArrayData& array) const { diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 415aaacf1c9..0726a44b407 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1624,7 +1624,8 @@ class ARROW_EXPORT FieldPath { Result> Get(const Schema& schema) const; Result> Get(const Field& field) const; Result> Get(const DataType& type) const; - Result> Get(const FieldVector& fields) const; + Result> Get(const FieldVector& fields, + const DataType* parent = NULLPTR) const; static Result> GetAll(const Schema& schema, const std::vector& paths); @@ -1762,7 +1763,8 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable { std::vector FindAll(const Schema& schema) const; std::vector FindAll(const Field& field) const; std::vector FindAll(const DataType& type) const; - std::vector FindAll(const FieldVector& fields) const; + std::vector FindAll(const FieldVector& fields, + const DataType* parent = NULLPTR) const; /// \brief Convenience function which applies FindAll to arg's type or schema. std::vector FindAll(const ArrayData& array) const; From a12ecb1718cb345c0e6b96dc6b75c2fca598a978 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 2 Dec 2022 09:28:02 +0100 Subject: [PATCH 4/5] Support other list types --- .../arrow/compute/kernels/scalar_nested.cc | 19 +++-- cpp/src/arrow/type.cc | 8 +- cpp/src/arrow/type_test.cc | 76 ++++++++++--------- python/pyarrow/tests/test_compute.py | 12 ++- 4 files changed, 70 insertions(+), 45 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc index c92b009c3fb..117f30ac3cc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc @@ -447,11 +447,12 @@ struct StructFieldFunctor { union_array.GetFlattenedField(index, ctx->memory_pool())); break; } - case Type::LIST: { - const auto& list_array = checked_cast(*current); + case Type::LIST: + case Type::LARGE_LIST: + case Type::FIXED_SIZE_LIST: { Datum idx(index); ARROW_ASSIGN_OR_RAISE(Datum result, - CallFunction("list_element", {list_array, idx})); + CallFunction("list_element", {*current, idx})); current = result.make_array(); break; } @@ -468,7 +469,7 @@ struct StructFieldFunctor { static Status CheckIndex(int index, const DataType& type) { if (!ValidParentType(type)) { return Status::TypeError("struct_field: cannot subscript field of type ", type); - } else if (type.id() != Type::LIST && (index < 0 || index >= type.num_fields())) { + } else if (!IsBaseListType(type) && (index < 0 || index >= type.num_fields())) { return Status::Invalid("struct_field: out-of-bounds field reference to field ", index, " in type ", type, " with ", type.num_fields(), " fields"); @@ -476,9 +477,13 @@ struct StructFieldFunctor { return Status::OK(); } + static bool IsBaseListType(const DataType& type) { + return dynamic_cast(&type) != nullptr; + } + static bool ValidParentType(const DataType& type) { return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION || - type.id() == Type::SPARSE_UNION || type.id() == Type::LIST; + type.id() == Type::SPARSE_UNION || IsBaseListType(type); } }; @@ -495,8 +500,8 @@ Result ResolveStructFieldType(KernelContext* ctx, } for (const auto& index : field_path.indices()) { - if (type->id() == Type::LIST) { - auto list_type = checked_cast(type); + if (StructFieldFunctor::IsBaseListType(*type->GetSharedPtr())) { + auto list_type = checked_cast(type); type = list_type->value_type().get(); } else { RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5583565d826..7683e3bb22d 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1067,6 +1067,10 @@ struct FieldPathGetImpl { return Status::Invalid("empty indices cannot be traversed"); } + auto IsBaseListType = [](const DataType& type) { + return dynamic_cast(&type) != nullptr; + }; + int depth = 0; const T* out = nullptr; for (int index : path->indices()) { @@ -1079,10 +1083,10 @@ struct FieldPathGetImpl { // The index here is used in the kernel to grab the specific list element. // Children then are fields from list type, and out will be the children from // that field, thus index 0. - if (out != nullptr && (*out)->type()->id() == Type::LIST) { + if (out != nullptr && IsBaseListType(*(*out)->type())) { children = get_children(*out); index = 0; - } else if (out == nullptr && parent != nullptr && parent->id() == Type::LIST) { + } else if (out == nullptr && parent != nullptr && IsBaseListType(*parent)) { // For the first iteration (out == nullptr), if the parent is a list type // then we certainly want to get the type of the list item. The specific // kernel implemention will use the actual index to grab a specific list item. diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index d4b9cd23610..3b08951503e 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -435,40 +435,48 @@ TEST(TestFieldRef, DotPathRoundTrip) { } TEST(TestFieldRef, NestedWithList) { - // Single list type - auto type = - struct_({field("c", list(struct_({field("d", int32()), field("e", int8())})))}); - // Note numeric values here go unused, outside of indicating a further step, during - // FindAll, only used by kernels for which list element to take. When FindAll encounters - // a list, it assumes the list value type. - // Not b/c they are numeric, but they are the position referencing a specific list - // element. - EXPECT_THAT(FieldRef("c", 0, "d").FindAll(*type), ElementsAre(FieldPath{0, 0, 0})); - EXPECT_THAT(FieldRef("c", 0, "e").FindAll(*type), ElementsAre(FieldPath{0, 0, 1})); - EXPECT_THAT(FieldRef("c", 1, "d").FindAll(*type), ElementsAre(FieldPath{0, 1, 0})); - EXPECT_THAT(FieldRef("c", 1, "e").FindAll(*type), ElementsAre(FieldPath{0, 1, 1})); - - // Double list and nested - type = struct_({field("a", list(type)), field("b", list(type))}); - EXPECT_THAT(FieldRef("a", 0, "c", 0, "d").FindAll(*type), - ElementsAre(FieldPath{0, 0, 0, 0, 0})); - EXPECT_THAT(FieldRef("b", 1, "c", 1, "e").FindAll(*type), - ElementsAre(FieldPath{1, 1, 0, 1, 1})); - - // Again, noting the 1 and 3 indexes refer to the specific list element - // and are not used in getting the type, only that its presence indicates - // further drilling into the list value type is needed. - // The values are used however in the kernel implementations for selecting - // from the specific list element. - ASSERT_OK_AND_ASSIGN(auto field, FieldPath({0, 0, 0, 0, 0}).Get(*type)) - ASSERT_EQ(field->type(), int32()); - ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 1, 0}).Get(*type)); - ASSERT_EQ(field->type(), int32()); - - ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 0, 0, 0, 1}).Get(*type)); - ASSERT_EQ(field->type(), int8()); - ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 0, 1, 1}).Get(*type)); - ASSERT_EQ(field->type(), int8()); + // bit of indirection to satisfy list / large_list variants + auto gen_type = [](std::shared_ptr (*f)(const std::shared_ptr&)) { + return f; + }; + + for (auto list_type : {gen_type(list), gen_type(large_list)}) { + // Single list type + auto type = struct_( + {field("c", list_type(struct_({field("d", int32()), field("e", int8())})))}); + + // Note numeric values here go unused, outside of indicating a further step, during + // FindAll, only used by kernels for which list element to take. When FindAll + // encounters a list, it assumes the list value type. Not b/c they are numeric, but + // they are the position referencing a specific list element. + EXPECT_THAT(FieldRef("c", 0, "d").FindAll(*type), ElementsAre(FieldPath{0, 0, 0})); + EXPECT_THAT(FieldRef("c", 0, "e").FindAll(*type), ElementsAre(FieldPath{0, 0, 1})); + EXPECT_THAT(FieldRef("c", 1, "d").FindAll(*type), ElementsAre(FieldPath{0, 1, 0})); + EXPECT_THAT(FieldRef("c", 1, "e").FindAll(*type), ElementsAre(FieldPath{0, 1, 1})); + + // Double list, variable and fixed + type = struct_({field("a", list_type(type)), field("b", fixed_size_list(type, 2))}); + + EXPECT_THAT(FieldRef("a", 0, "c", 0, "d").FindAll(*type), + ElementsAre(FieldPath{0, 0, 0, 0, 0})); + EXPECT_THAT(FieldRef("b", 1, "c", 1, "e").FindAll(*type), + ElementsAre(FieldPath{1, 1, 0, 1, 1})); + + // Again, noting the 1 and 3 indexes refer to the specific list element + // and are not used in getting the type, only that its presence indicates + // further drilling into the list value type is needed. + // The values are used however in the kernel implementations for selecting + // from the specific list element. + ASSERT_OK_AND_ASSIGN(auto field, FieldPath({0, 0, 0, 0, 0}).Get(*type)) + ASSERT_EQ(field->type(), int32()); + ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 1, 0}).Get(*type)); + ASSERT_EQ(field->type(), int32()); + + ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 0, 0, 0, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); + ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 0, 1, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); + } } TEST(TestFieldPath, Nested) { diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index f74458eaed6..a4ac1d2b90c 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2731,8 +2731,16 @@ def test_struct_fields_options(): ([0, 1, 1], ["hi"]), ('.a[1].c', ["hi"]) )) -def test_struct_field_list_path(path, expected): - arr = pa.array([{'a': [{'b': 1}, {'c': 'hi'}]}]) +@pytest.mark.parametrize("list_type", ( + lambda v: pa.list_(v), + lambda v: pa.list_(v, 2), + lambda v: pa.large_list(v) +)) +def test_struct_field_list_path(path, expected, list_type): + type = pa.struct([pa.field('a', list_type( + pa.struct([pa.field('b', pa.int8()), + pa.field('c', pa.string())])))]) + arr = pa.array([{'a': [{'b': 1}, {'c': 'hi'}]}], type) out = pc.struct_field(arr, path) assert out == pa.array(expected).cast(out.type) From 99f829c7dcb22a525317d5bf52436264c0ea58ed Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 2 Dec 2022 16:55:39 +0100 Subject: [PATCH 5/5] Expand C++ tests and fix for leading list type parent was not passed in FieldPath first iteration --- cpp/src/arrow/type.cc | 27 ++++++++++++++++----------- cpp/src/arrow/type_test.cc | 13 +++++++++++++ 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 7683e3bb22d..ec8b802d42f 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1067,10 +1067,6 @@ struct FieldPathGetImpl { return Status::Invalid("empty indices cannot be traversed"); } - auto IsBaseListType = [](const DataType& type) { - return dynamic_cast(&type) != nullptr; - }; - int depth = 0; const T* out = nullptr; for (int index : path->indices()) { @@ -1079,6 +1075,10 @@ struct FieldPathGetImpl { } if constexpr (std::is_same_v>) { + auto IsBaseListType = [](const DataType& type) { + return dynamic_cast(&type) != nullptr; + }; + // For lists, we don't care about the index, jump right into the list value type. // The index here is used in the kernel to grab the specific list element. // Children then are fields from list type, and out will be the children from @@ -1442,7 +1442,7 @@ std::vector FieldRef::FindAll(const FieldVector& fields, std::vector operator()(const std::vector& refs) { DCHECK_GE(refs.size(), 1); - Matches matches(refs.front().FindAll(fields_, parent_), fields_); + Matches matches(refs.front().FindAll(fields_, parent_), fields_, parent_); for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { Matches next_matches; @@ -1465,9 +1465,11 @@ std::vector FieldRef::FindAll(const FieldVector& fields, std::vector prefixes; FieldVector referents; - Matches(std::vector matches, const FieldVector& fields) { + Matches(std::vector matches, const FieldVector& fields, + const DataType* parent) { + auto current_parent = parent; for (auto& match : matches) { - Add({}, std::move(match), fields); + current_parent = Add({}, std::move(match), fields, current_parent); } } @@ -1475,12 +1477,14 @@ std::vector FieldRef::FindAll(const FieldVector& fields, size_t size() const { return referents.size(); } - void Add(const FieldPath& prefix, const FieldPath& suffix, - const FieldVector& fields, const DataType* parent = NULLPTR) { + const DataType* Add(const FieldPath& prefix, const FieldPath& suffix, + const FieldVector& fields, const DataType* parent = NULLPTR) { auto maybe_field = suffix.Get(fields, parent); - DCHECK_OK(maybe_field.status()); - referents.push_back(std::move(maybe_field.ValueOrDie())); + DCHECK_OK(maybe_field.status()); + auto field = maybe_field.ValueOrDie(); + auto field_type = field->type(); + referents.push_back(std::move(field)); std::vector concatenated_indices(prefix.indices().size() + suffix.indices().size()); @@ -1489,6 +1493,7 @@ std::vector FieldRef::FindAll(const FieldVector& fields, it = std::copy(path->indices().begin(), path->indices().end(), it); } prefixes.emplace_back(std::move(concatenated_indices)); + return &*field_type; } }; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 3b08951503e..07d22c6a107 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -453,6 +453,7 @@ TEST(TestFieldRef, NestedWithList) { EXPECT_THAT(FieldRef("c", 0, "e").FindAll(*type), ElementsAre(FieldPath{0, 0, 1})); EXPECT_THAT(FieldRef("c", 1, "d").FindAll(*type), ElementsAre(FieldPath{0, 1, 0})); EXPECT_THAT(FieldRef("c", 1, "e").FindAll(*type), ElementsAre(FieldPath{0, 1, 1})); + ASSERT_TRUE(FieldRef("c", "non-integer", "e").FindAll(*type).empty()); // Double list, variable and fixed type = struct_({field("a", list_type(type)), field("b", fixed_size_list(type, 2))}); @@ -476,6 +477,18 @@ TEST(TestFieldRef, NestedWithList) { ASSERT_EQ(field->type(), int8()); ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 0, 1, 1}).Get(*type)); ASSERT_EQ(field->type(), int8()); + + // leading list type + type = list_type(type); + EXPECT_THAT(FieldRef(1, "a", 0, "c", 0, "d").FindAll(*type), + ElementsAre(FieldPath{1, 0, 0, 0, 0, 0})); + EXPECT_THAT(FieldRef(0, "b", 1, "c", 1, "e").FindAll(*type), + ElementsAre(FieldPath{0, 1, 1, 0, 1, 1})); + + ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 0, 0, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); + ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 1, 0, 1, 1}).Get(*type)); + ASSERT_EQ(field->type(), int8()); } }