diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc
index 77fd67b3c77..117f30ac3cc 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -447,6 +447,15 @@ struct StructFieldFunctor {
                               union_array.GetFlattenedField(index, ctx->memory_pool()));
         break;
       }
+      case Type::LIST:
+      case Type::LARGE_LIST:
+      case Type::FIXED_SIZE_LIST: {
+        Datum idx(index);
+        ARROW_ASSIGN_OR_RAISE(Datum result,
+                              CallFunction("list_element", {*current, idx}));
+        current = result.make_array();
+        break;
+      }
       default:
         // Should have been checked in ResolveStructFieldType
         return Status::TypeError("struct_field: cannot reference child field of type ",
@@ -460,7 +469,7 @@ struct StructFieldFunctor {
   static Status CheckIndex(int index, const DataType& type) {
     if (!ValidParentType(type)) {
       return Status::TypeError("struct_field: cannot subscript field of type ", type);
-    } else if (index < 0 || index >= type.num_fields()) {
+    } else if (!IsBaseListType(type) && (index < 0 || index >= type.num_fields())) {
       return Status::Invalid("struct_field: out-of-bounds field reference to field ",
                              index, " in type ", type, " with ", type.num_fields(),
                              " fields");
@@ -468,9 +477,13 @@ struct StructFieldFunctor {
     return Status::OK();
   }

+  static bool IsBaseListType(const DataType& type) {
+    return dynamic_cast<const BaseListType*>(&type) != nullptr;
+  }
+
   static bool ValidParentType(const DataType& type) {
     return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION ||
-           type.id() == Type::SPARSE_UNION;
+           type.id() == Type::SPARSE_UNION || IsBaseListType(type);
   }
 };

@@ -487,8 +500,13 @@ Result<const DataType*> ResolveStructFieldType(KernelContext* ctx,
   }

   for (const auto& index : field_path.indices()) {
-    RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
-    type = type->field(index)->type().get();
+    if (StructFieldFunctor::IsBaseListType(*type->GetSharedPtr())) {
+      auto list_type = checked_cast<const BaseListType*>(type);
+      type = list_type->value_type().get();
+    } else {
+      RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
+      type = type->field(index)->type().get();
+    }
   }
   return type;
 }
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index ea9525404c8..ec8b802d42f 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -1060,19 +1060,40 @@ struct FieldPathGetImpl {
   }

   template <typename T, typename GetChildren>
-  static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
-                       GetChildren&& get_children, int* out_of_range_depth) {
+  static Result<T> Get(const FieldPath* path, const DataType* parent,
+                       const std::vector<T>* children, GetChildren&& get_children,
+                       int* out_of_range_depth) {
     if (path->indices().empty()) {
       return Status::Invalid("empty indices cannot be traversed");
     }

     int depth = 0;
-    const T* out;
+    const T* out = nullptr;
     for (int index : path->indices()) {
       if (children == nullptr) {
         return Status::NotImplemented("Get child data of non-struct array");
       }

+      if constexpr (std::is_same_v<T, std::shared_ptr<Field>>) {
+        auto IsBaseListType = [](const DataType& type) {
+          return dynamic_cast<const BaseListType*>(&type) != nullptr;
+        };
+
+        // For lists, we don't care about the index, jump right into the list value type.
+        // The index here is used in the kernel to grab the specific list element.
+        // Children then are fields from list type, and out will be the children from
+        // that field, thus index 0.
+        if (out != nullptr && IsBaseListType(*(*out)->type())) {
+          children = get_children(*out);
+          index = 0;
+        } else if (out == nullptr && parent != nullptr && IsBaseListType(*parent)) {
+          // For the first iteration (out == nullptr), if the parent is a list type
+          // then we certainly want to get the type of the list item. The specific
+          // kernel implementation will use the actual index to grab a specific list item.
+          index = 0;
+        }
+      }
+
       if (index < 0 || static_cast<size_t>(index) >= children->size()) {
         *out_of_range_depth = depth;
         return nullptr;
@@ -1082,34 +1103,34 @@ struct FieldPathGetImpl {
       children = get_children(*out);
       ++depth;
     }
-
     return *out;
   }

   template <typename T, typename GetChildren>
-  static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
-                       GetChildren&& get_children) {
+  static Result<T> Get(const FieldPath* path, const DataType* parent,
+                       const std::vector<T>* children, GetChildren&& get_children) {
     int out_of_range_depth = -1;
-    ARROW_ASSIGN_OR_RAISE(auto child,
-                          Get(path, children, std::forward<GetChildren>(get_children),
-                              &out_of_range_depth));
+    ARROW_ASSIGN_OR_RAISE(
+        auto child, Get(path, parent, children, std::forward<GetChildren>(get_children),
+                        &out_of_range_depth));
     if (child != nullptr) {
       return std::move(child);
     }
     return IndexError(path, out_of_range_depth, *children);
   }

-  static Result<std::shared_ptr<Field>> Get(const FieldPath* path,
+  static Result<std::shared_ptr<Field>> Get(const FieldPath* path, const DataType* parent,
                                             const FieldVector& fields) {
-    return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr<Field>& field) {
-      return &field->type()->fields();
-    });
+    return FieldPathGetImpl::Get(
+        path, parent, &fields,
+        [](const std::shared_ptr<Field>& field) { return &field->type()->fields(); });
   }

   static Result<std::shared_ptr<ArrayData>> Get(const FieldPath* path,
+                                                const DataType* parent,
                                                 const ArrayDataVector& child_data) {
     return FieldPathGetImpl::Get(
-        path, &child_data,
+        path, parent, &child_data,
         [](const std::shared_ptr<ArrayData>& data) -> const ArrayDataVector* {
           if (data->type->id() != Type::STRUCT) {
             return nullptr;
@@ -1120,19 +1141,20 @@ struct FieldPathGetImpl {
 };

 Result<std::shared_ptr<Field>> FieldPath::Get(const Schema& schema) const {
-  return FieldPathGetImpl::Get(this, schema.fields());
+  return FieldPathGetImpl::Get(this, nullptr, schema.fields());
 }

 Result<std::shared_ptr<Field>> FieldPath::Get(const Field& field) const {
-  return FieldPathGetImpl::Get(this, field.type()->fields());
+  return FieldPathGetImpl::Get(this, &*field.type(), field.type()->fields());
 }

 Result<std::shared_ptr<Field>> FieldPath::Get(const DataType& type) const {
-  return FieldPathGetImpl::Get(this, type.fields());
+  return FieldPathGetImpl::Get(this, &type, type.fields());
 }

-Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields) const {
-  return FieldPathGetImpl::Get(this, fields);
+Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields,
+                                              const DataType* parent) const {
+  return FieldPathGetImpl::Get(this, parent, fields);
 }

 Result<FieldVector> FieldPath::GetAll(const Schema& schm,
@@ -1147,7 +1169,8 @@ Result<FieldVector> FieldPath::GetAll(const Schema& schm,
 }

 Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const {
-  ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data()));
+  ARROW_ASSIGN_OR_RAISE(auto data,
+                        FieldPathGetImpl::Get(this, nullptr, batch.column_data()));
   return MakeArray(std::move(data));
 }

@@ -1160,7 +1183,7 @@ Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
   if (data.type->id() != Type::STRUCT) {
     return Status::NotImplemented("Get child data of non-struct array");
   }
-  return FieldPathGetImpl::Get(this, data.child_data);
+  return FieldPathGetImpl::Get(this, &*data.type, data.child_data);
 }

 FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {}
@@ -1376,29 +1399,29 @@ std::vector<FieldPath> FieldRef::FindAll(const Schema& schema) const {
     return internal::MapVector([](int i) { return FieldPath{i}; },
                                schema.GetAllFieldIndices(*name));
   }
-  return FindAll(schema.fields());
+  return FindAll(schema.fields(), nullptr);
 }

 std::vector<FieldPath> FieldRef::FindAll(const Field& field) const {
-  return FindAll(field.type()->fields());
+  return FindAll(field.type()->fields(), &*field.type());
 }

 std::vector<FieldPath> FieldRef::FindAll(const DataType& type) const {
-  return FindAll(type.fields());
+  return FindAll(type.fields(), &type);
 }

-std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
+std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields,
+                                         const DataType* parent) const {
   struct Visitor {
     std::vector<FieldPath> operator()(const FieldPath& path) {
       // skip long IndexError construction if path is out of range
       int out_of_range_depth;
       auto maybe_field = FieldPathGetImpl::Get(
-          &path, &fields_,
+          &path, parent_, &fields_,
           [](const std::shared_ptr<Field>& field) { return &field->type()->fields(); },
           &out_of_range_depth);

       DCHECK_OK(maybe_field.status());
-
       if (maybe_field.ValueOrDie() != nullptr) {
         return {path};
       }
@@ -1417,14 +1440,36 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
       return out;
     }

+    std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) {
+      DCHECK_GE(refs.size(), 1);
+      Matches matches(refs.front().FindAll(fields_, parent_), fields_, parent_);
+
+      for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) {
+        Matches next_matches;
+        for (size_t i = 0; i < matches.size(); ++i) {
+          const auto& referent = *matches.referents[i];
+
+          for (const FieldPath& match : ref_it->FindAll(referent)) {
+            next_matches.Add(matches.prefixes[i], match, referent.type()->fields(),
+                             &*referent.type());
+          }
+        }
+        matches = std::move(next_matches);
+      }
+
+      return matches.prefixes;
+    }
+
     struct Matches {
       // referents[i] is referenced by prefixes[i]
       std::vector<FieldPath> prefixes;
       FieldVector referents;

-      Matches(std::vector<FieldPath> matches, const FieldVector& fields) {
+      Matches(std::vector<FieldPath> matches, const FieldVector& fields,
+              const DataType* parent) {
+        auto current_parent = parent;
         for (auto& match : matches) {
-          Add({}, std::move(match), fields);
+          current_parent = Add({}, std::move(match), fields, current_parent);
         }
       }

@@ -1432,11 +1477,14 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
       size_t size() const { return referents.size(); }

-      void Add(const FieldPath& prefix, const FieldPath& suffix,
-               const FieldVector& fields) {
-        auto maybe_field = suffix.Get(fields);
+      const DataType* Add(const FieldPath& prefix, const FieldPath& suffix,
+                          const FieldVector& fields, const DataType* parent = NULLPTR) {
+        auto maybe_field = suffix.Get(fields, parent);
+
         DCHECK_OK(maybe_field.status());
-        referents.push_back(std::move(maybe_field).ValueOrDie());
+        auto field = maybe_field.ValueOrDie();
+        auto field_type = field->type();
+        referents.push_back(std::move(field));

         std::vector<int> concatenated_indices(prefix.indices().size() +
                                               suffix.indices().size());
@@ -1445,32 +1493,15 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
         it = std::copy(path->indices().begin(), path->indices().end(), it);
       }
       prefixes.emplace_back(std::move(concatenated_indices));
+      return &*field_type;
     }
   };

-    std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) {
-      DCHECK_GE(refs.size(), 1);
-      Matches matches(refs.front().FindAll(fields_), fields_);
-
-      for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) {
-        Matches next_matches;
-        for (size_t i = 0; i < matches.size(); ++i) {
-          const auto& referent = *matches.referents[i];
-
-          for (const FieldPath& match : ref_it->FindAll(referent)) {
-            next_matches.Add(matches.prefixes[i], match, referent.type()->fields());
-          }
-        }
-        matches = std::move(next_matches);
-      }
-
-      return matches.prefixes;
-    }
-
+    const DataType* parent_;
     const FieldVector& fields_;
   };

-  return std::visit(Visitor{fields}, impl_);
+  return std::visit(Visitor{parent, fields}, impl_);
 }

 std::vector<FieldPath> FieldRef::FindAll(const ArrayData& array) const {
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 415aaacf1c9..0726a44b407 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1624,7 +1624,8 @@ class ARROW_EXPORT FieldPath {
   Result<std::shared_ptr<Field>> Get(const Schema& schema) const;
   Result<std::shared_ptr<Field>> Get(const Field& field) const;
   Result<std::shared_ptr<Field>> Get(const DataType& type) const;
-  Result<std::shared_ptr<Field>> Get(const FieldVector& fields) const;
+  Result<std::shared_ptr<Field>> Get(const FieldVector& fields,
+                                     const DataType* parent = NULLPTR) const;

   static Result<FieldVector> GetAll(const Schema& schema,
                                     const std::vector<FieldPath>& paths);
@@ -1762,7 +1763,8 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable<FieldRef> {
   std::vector<FieldPath> FindAll(const Schema& schema) const;
   std::vector<FieldPath> FindAll(const Field& field) const;
   std::vector<FieldPath> FindAll(const DataType& type) const;
-  std::vector<FieldPath> FindAll(const FieldVector& fields) const;
+  std::vector<FieldPath> FindAll(const FieldVector& fields,
+                                 const DataType* parent = NULLPTR) const;

   /// \brief Convenience function which applies FindAll to arg's type or schema.
   std::vector<FieldPath> FindAll(const ArrayData& array) const;
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc
index 954ad63c8aa..07d22c6a107 100644
--- a/cpp/src/arrow/type_test.cc
+++ b/cpp/src/arrow/type_test.cc
@@ -434,6 +434,64 @@ TEST(TestFieldRef, DotPathRoundTrip) {
   check_roundtrip(FieldRef("foo", 1, FieldRef("bar", 2, 3), FieldRef()));
 }

+TEST(TestFieldRef, NestedWithList) {
+  // bit of indirection to satisfy list / large_list variants
+  auto gen_type = [](std::shared_ptr<DataType> (*f)(const std::shared_ptr<DataType>&)) {
+    return f;
+  };
+
+  for (auto list_type : {gen_type(list), gen_type(large_list)}) {
+    // Single list type
+    auto type = struct_(
+        {field("c", list_type(struct_({field("d", int32()), field("e", int8())})))});
+
+    // Note numeric values here go unused, outside of indicating a further step, during
+    // FindAll, only used by kernels for which list element to take. When FindAll
+    // encounters a list, it assumes the list value type. Not b/c they are numeric, but
+    // they are the position referencing a specific list element.
+    EXPECT_THAT(FieldRef("c", 0, "d").FindAll(*type), ElementsAre(FieldPath{0, 0, 0}));
+    EXPECT_THAT(FieldRef("c", 0, "e").FindAll(*type), ElementsAre(FieldPath{0, 0, 1}));
+    EXPECT_THAT(FieldRef("c", 1, "d").FindAll(*type), ElementsAre(FieldPath{0, 1, 0}));
+    EXPECT_THAT(FieldRef("c", 1, "e").FindAll(*type), ElementsAre(FieldPath{0, 1, 1}));
+    ASSERT_TRUE(FieldRef("c", "non-integer", "e").FindAll(*type).empty());
+
+    // Double list, variable and fixed
+    type = struct_({field("a", list_type(type)), field("b", fixed_size_list(type, 2))});
+
+    EXPECT_THAT(FieldRef("a", 0, "c", 0, "d").FindAll(*type),
+                ElementsAre(FieldPath{0, 0, 0, 0, 0}));
+    EXPECT_THAT(FieldRef("b", 1, "c", 1, "e").FindAll(*type),
+                ElementsAre(FieldPath{1, 1, 0, 1, 1}));
+
+    // Again, noting the 1 and 3 indexes refer to the specific list element
+    // and are not used in getting the type, only that its presence indicates
+    // further drilling into the list value type is needed.
+    // The values are used however in the kernel implementations for selecting
+    // from the specific list element.
+    ASSERT_OK_AND_ASSIGN(auto field, FieldPath({0, 0, 0, 0, 0}).Get(*type));
+    ASSERT_EQ(field->type(), int32());
+    ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 1, 0}).Get(*type));
+    ASSERT_EQ(field->type(), int32());
+
+    ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 0, 0, 0, 1}).Get(*type));
+    ASSERT_EQ(field->type(), int8());
+    ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 0, 1, 1}).Get(*type));
+    ASSERT_EQ(field->type(), int8());
+
+    // leading list type
+    type = list_type(type);
+    EXPECT_THAT(FieldRef(1, "a", 0, "c", 0, "d").FindAll(*type),
+                ElementsAre(FieldPath{1, 0, 0, 0, 0, 0}));
+    EXPECT_THAT(FieldRef(0, "b", 1, "c", 1, "e").FindAll(*type),
+                ElementsAre(FieldPath{0, 1, 1, 0, 1, 1}));
+
+    ASSERT_OK_AND_ASSIGN(field, FieldPath({0, 1, 0, 0, 0, 1}).Get(*type));
+    ASSERT_EQ(field->type(), int8());
+    ASSERT_OK_AND_ASSIGN(field, FieldPath({1, 1, 1, 0, 1, 1}).Get(*type));
+    ASSERT_EQ(field->type(), int8());
+  }
+}
+
 TEST(TestFieldPath, Nested) {
   auto f0 = field("alpha", int32());
   auto f1_0 = field("alpha", int32());
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 68b3303fe78..a4ac1d2b90c 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -2721,6 +2721,30 @@ def test_struct_fields_options():
     # assert pc.struct_field(arr) == arr


+@pytest.mark.parametrize("path,expected", (
+    ([0, 0, 0], [1]),
+    ('.a[0].b', [1]),
+    ([0, 0, 1], [None]),
+    ('.a[0].c', [None]),
+    ([0, 1, 0], [None]),
+    ('.a[1].b', [None]),
+    ([0, 1, 1], ["hi"]),
+    ('.a[1].c', ["hi"])
+))
+@pytest.mark.parametrize("list_type", (
+    lambda v: pa.list_(v),
+    lambda v: pa.list_(v, 2),
+    lambda v: pa.large_list(v)
+))
+def test_struct_field_list_path(path, expected, list_type):
+    type = pa.struct([pa.field('a', list_type(
+        pa.struct([pa.field('b', pa.int8()),
+                   pa.field('c', pa.string())])))])
+    arr = pa.array([{'a': [{'b': 1}, {'c': 'hi'}]}], type)
+    out = pc.struct_field(arr, path)
+    assert out == pa.array(expected).cast(out.type)
+
+
 def test_case_when():
     assert pc.case_when(pc.make_struct([True, False, None],
                                        [False, True, None]),