-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-18265: [C++][Python] Support FieldRef to work with ListElement #14697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ae9e603
b06f715
5a082ae
a12ecb1
99f829c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -447,6 +447,15 @@ struct StructFieldFunctor { | |
| union_array.GetFlattenedField(index, ctx->memory_pool())); | ||
| break; | ||
| } | ||
| case Type::LIST: | ||
| case Type::LARGE_LIST: | ||
| case Type::FIXED_SIZE_LIST: { | ||
| Datum idx(index); | ||
| ARROW_ASSIGN_OR_RAISE(Datum result, | ||
| CallFunction("list_element", {*current, idx})); | ||
| current = result.make_array(); | ||
| break; | ||
| } | ||
| default: | ||
| // Should have been checked in ResolveStructFieldType | ||
| return Status::TypeError("struct_field: cannot reference child field of type ", | ||
|
|
@@ -460,17 +469,21 @@ struct StructFieldFunctor { | |
| static Status CheckIndex(int index, const DataType& type) { | ||
| if (!ValidParentType(type)) { | ||
| return Status::TypeError("struct_field: cannot subscript field of type ", type); | ||
| } else if (index < 0 || index >= type.num_fields()) { | ||
| } else if (!IsBaseListType(type) && (index < 0 || index >= type.num_fields())) { | ||
| return Status::Invalid("struct_field: out-of-bounds field reference to field ", | ||
| index, " in type ", type, " with ", type.num_fields(), | ||
| " fields"); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| static bool IsBaseListType(const DataType& type) { | ||
| return dynamic_cast<const BaseListType*>(&type) != nullptr; | ||
| } | ||
|
Comment on lines
+480
to
+482
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, if we keep this, using a combination of a few |
||
|
|
||
| static bool ValidParentType(const DataType& type) { | ||
| return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION || | ||
| type.id() == Type::SPARSE_UNION; | ||
| type.id() == Type::SPARSE_UNION || IsBaseListType(type); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -487,8 +500,13 @@ Result<TypeHolder> ResolveStructFieldType(KernelContext* ctx, | |
| } | ||
|
|
||
| for (const auto& index : field_path.indices()) { | ||
| RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); | ||
| type = type->field(index)->type().get(); | ||
| if (StructFieldFunctor::IsBaseListType(*type->GetSharedPtr())) { | ||
| auto list_type = checked_cast<const BaseListType*>(type); | ||
| type = list_type->value_type().get(); | ||
| } else { | ||
| RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type)); | ||
| type = type->field(index)->type().get(); | ||
| } | ||
| } | ||
| return type; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1060,19 +1060,40 @@ struct FieldPathGetImpl { | |
| } | ||
|
|
||
| template <typename T, typename GetChildren> | ||
| static Result<T> Get(const FieldPath* path, const std::vector<T>* children, | ||
| GetChildren&& get_children, int* out_of_range_depth) { | ||
| static Result<T> Get(const FieldPath* path, const DataType* parent, | ||
| const std::vector<T>* children, GetChildren&& get_children, | ||
| int* out_of_range_depth) { | ||
| if (path->indices().empty()) { | ||
| return Status::Invalid("empty indices cannot be traversed"); | ||
| } | ||
|
|
||
| int depth = 0; | ||
| const T* out; | ||
| const T* out = nullptr; | ||
| for (int index : path->indices()) { | ||
| if (children == nullptr) { | ||
| return Status::NotImplemented("Get child data of non-struct array"); | ||
| } | ||
|
|
||
| if constexpr (std::is_same_v<T, std::shared_ptr<arrow::Field>>) { | ||
| auto IsBaseListType = [](const DataType& type) { | ||
| return dynamic_cast<const BaseListType*>(&type) != nullptr; | ||
| }; | ||
|
|
||
| // For lists, we don't care about the index, jump right into the list value type. | ||
| // The index here is used in the kernel to grab the specific list element. | ||
| // Children then are fields from list type, and out will be the children from | ||
| // that field, thus index 0. | ||
| if (out != nullptr && IsBaseListType(*(*out)->type())) { | ||
| children = get_children(*out); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't fully understand why we need the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (or to put it differently: why isn't this initial get_children needed when parent is a struct type? What's the difference between both?) |
||
| index = 0; | ||
| } else if (out == nullptr && parent != nullptr && IsBaseListType(*parent)) { | ||
| // For the first iteration (out == nullptr), if the parent is a list type | ||
| // then we certainly want to get the type of the list item. The specific | ||
| // kernel implemention will use the actual index to grab a specific list item. | ||
| index = 0; | ||
| } | ||
| } | ||
|
|
||
| if (index < 0 || static_cast<size_t>(index) >= children->size()) { | ||
| *out_of_range_depth = depth; | ||
| return nullptr; | ||
|
|
@@ -1082,34 +1103,34 @@ struct FieldPathGetImpl { | |
| children = get_children(*out); | ||
| ++depth; | ||
| } | ||
|
|
||
| return *out; | ||
| } | ||
|
|
||
| template <typename T, typename GetChildren> | ||
| static Result<T> Get(const FieldPath* path, const std::vector<T>* children, | ||
| GetChildren&& get_children) { | ||
| static Result<T> Get(const FieldPath* path, const DataType* parent, | ||
| const std::vector<T>* children, GetChildren&& get_children) { | ||
| int out_of_range_depth = -1; | ||
| ARROW_ASSIGN_OR_RAISE(auto child, | ||
| Get(path, children, std::forward<GetChildren>(get_children), | ||
| &out_of_range_depth)); | ||
| ARROW_ASSIGN_OR_RAISE( | ||
| auto child, Get(path, parent, children, std::forward<GetChildren>(get_children), | ||
| &out_of_range_depth)); | ||
| if (child != nullptr) { | ||
| return std::move(child); | ||
| } | ||
| return IndexError(path, out_of_range_depth, *children); | ||
| } | ||
|
|
||
| static Result<std::shared_ptr<Field>> Get(const FieldPath* path, | ||
| static Result<std::shared_ptr<Field>> Get(const FieldPath* path, const DataType* parent, | ||
| const FieldVector& fields) { | ||
| return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr<Field>& field) { | ||
| return &field->type()->fields(); | ||
| }); | ||
| return FieldPathGetImpl::Get( | ||
| path, parent, &fields, | ||
| [](const std::shared_ptr<Field>& field) { return &field->type()->fields(); }); | ||
| } | ||
|
|
||
| static Result<std::shared_ptr<ArrayData>> Get(const FieldPath* path, | ||
| const DataType* parent, | ||
| const ArrayDataVector& child_data) { | ||
| return FieldPathGetImpl::Get( | ||
| path, &child_data, | ||
| path, parent, &child_data, | ||
| [](const std::shared_ptr<ArrayData>& data) -> const ArrayDataVector* { | ||
| if (data->type->id() != Type::STRUCT) { | ||
| return nullptr; | ||
|
|
@@ -1120,19 +1141,20 @@ struct FieldPathGetImpl { | |
| }; | ||
|
|
||
| Result<std::shared_ptr<Field>> FieldPath::Get(const Schema& schema) const { | ||
| return FieldPathGetImpl::Get(this, schema.fields()); | ||
| return FieldPathGetImpl::Get(this, nullptr, schema.fields()); | ||
| } | ||
|
|
||
| Result<std::shared_ptr<Field>> FieldPath::Get(const Field& field) const { | ||
| return FieldPathGetImpl::Get(this, field.type()->fields()); | ||
| return FieldPathGetImpl::Get(this, &*field.type(), field.type()->fields()); | ||
| } | ||
|
|
||
| Result<std::shared_ptr<Field>> FieldPath::Get(const DataType& type) const { | ||
| return FieldPathGetImpl::Get(this, type.fields()); | ||
| return FieldPathGetImpl::Get(this, &type, type.fields()); | ||
| } | ||
|
|
||
| Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields) const { | ||
| return FieldPathGetImpl::Get(this, fields); | ||
| Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields, | ||
| const DataType* parent) const { | ||
| return FieldPathGetImpl::Get(this, parent, fields); | ||
| } | ||
|
|
||
| Result<std::shared_ptr<Schema>> FieldPath::GetAll(const Schema& schm, | ||
|
|
@@ -1147,7 +1169,8 @@ Result<std::shared_ptr<Schema>> FieldPath::GetAll(const Schema& schm, | |
| } | ||
|
|
||
| Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const { | ||
| ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data())); | ||
| ARROW_ASSIGN_OR_RAISE(auto data, | ||
| FieldPathGetImpl::Get(this, nullptr, batch.column_data())); | ||
| return MakeArray(std::move(data)); | ||
| } | ||
|
|
||
|
|
@@ -1160,7 +1183,7 @@ Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const { | |
| if (data.type->id() != Type::STRUCT) { | ||
| return Status::NotImplemented("Get child data of non-struct array"); | ||
| } | ||
| return FieldPathGetImpl::Get(this, data.child_data); | ||
| return FieldPathGetImpl::Get(this, &*data.type, data.child_data); | ||
| } | ||
|
|
||
| FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {} | ||
|
|
@@ -1376,29 +1399,29 @@ std::vector<FieldPath> FieldRef::FindAll(const Schema& schema) const { | |
| return internal::MapVector([](int i) { return FieldPath{i}; }, | ||
| schema.GetAllFieldIndices(*name)); | ||
| } | ||
| return FindAll(schema.fields()); | ||
| return FindAll(schema.fields(), nullptr); | ||
| } | ||
|
|
||
| std::vector<FieldPath> FieldRef::FindAll(const Field& field) const { | ||
| return FindAll(field.type()->fields()); | ||
| return FindAll(field.type()->fields(), &*field.type()); | ||
| } | ||
|
|
||
| std::vector<FieldPath> FieldRef::FindAll(const DataType& type) const { | ||
| return FindAll(type.fields()); | ||
| return FindAll(type.fields(), &type); | ||
| } | ||
|
|
||
| std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const { | ||
| std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields, | ||
| const DataType* parent) const { | ||
| struct Visitor { | ||
| std::vector<FieldPath> operator()(const FieldPath& path) { | ||
| // skip long IndexError construction if path is out of range | ||
| int out_of_range_depth; | ||
| auto maybe_field = FieldPathGetImpl::Get( | ||
| &path, &fields_, | ||
| &path, parent_, &fields_, | ||
| [](const std::shared_ptr<Field>& field) { return &field->type()->fields(); }, | ||
| &out_of_range_depth); | ||
|
|
||
| DCHECK_OK(maybe_field.status()); | ||
|
|
||
| if (maybe_field.ValueOrDie() != nullptr) { | ||
| return {path}; | ||
| } | ||
|
|
@@ -1417,26 +1440,51 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const { | |
| return out; | ||
| } | ||
|
|
||
| std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) { | ||
| DCHECK_GE(refs.size(), 1); | ||
| Matches matches(refs.front().FindAll(fields_, parent_), fields_, parent_); | ||
|
|
||
| for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { | ||
| Matches next_matches; | ||
| for (size_t i = 0; i < matches.size(); ++i) { | ||
| const auto& referent = *matches.referents[i]; | ||
|
|
||
| for (const FieldPath& match : ref_it->FindAll(referent)) { | ||
| next_matches.Add(matches.prefixes[i], match, referent.type()->fields(), | ||
| &*referent.type()); | ||
| } | ||
| } | ||
| matches = std::move(next_matches); | ||
| } | ||
|
|
||
| return matches.prefixes; | ||
| } | ||
|
|
||
| struct Matches { | ||
| // referents[i] is referenced by prefixes[i] | ||
| std::vector<FieldPath> prefixes; | ||
| FieldVector referents; | ||
|
|
||
| Matches(std::vector<FieldPath> matches, const FieldVector& fields) { | ||
| Matches(std::vector<FieldPath> matches, const FieldVector& fields, | ||
| const DataType* parent) { | ||
| auto current_parent = parent; | ||
| for (auto& match : matches) { | ||
| Add({}, std::move(match), fields); | ||
| current_parent = Add({}, std::move(match), fields, current_parent); | ||
| } | ||
| } | ||
|
|
||
| Matches() = default; | ||
|
|
||
| size_t size() const { return referents.size(); } | ||
|
|
||
| void Add(const FieldPath& prefix, const FieldPath& suffix, | ||
| const FieldVector& fields) { | ||
| auto maybe_field = suffix.Get(fields); | ||
| const DataType* Add(const FieldPath& prefix, const FieldPath& suffix, | ||
| const FieldVector& fields, const DataType* parent = NULLPTR) { | ||
| auto maybe_field = suffix.Get(fields, parent); | ||
|
|
||
| DCHECK_OK(maybe_field.status()); | ||
| referents.push_back(std::move(maybe_field).ValueOrDie()); | ||
| auto field = maybe_field.ValueOrDie(); | ||
| auto field_type = field->type(); | ||
| referents.push_back(std::move(field)); | ||
|
|
||
| std::vector<int> concatenated_indices(prefix.indices().size() + | ||
| suffix.indices().size()); | ||
|
|
@@ -1445,32 +1493,15 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const { | |
| it = std::copy(path->indices().begin(), path->indices().end(), it); | ||
| } | ||
| prefixes.emplace_back(std::move(concatenated_indices)); | ||
| return &*field_type; | ||
| } | ||
| }; | ||
|
|
||
| std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) { | ||
| DCHECK_GE(refs.size(), 1); | ||
| Matches matches(refs.front().FindAll(fields_), fields_); | ||
|
|
||
| for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { | ||
| Matches next_matches; | ||
| for (size_t i = 0; i < matches.size(); ++i) { | ||
| const auto& referent = *matches.referents[i]; | ||
|
|
||
| for (const FieldPath& match : ref_it->FindAll(referent)) { | ||
| next_matches.Add(matches.prefixes[i], match, referent.type()->fields()); | ||
| } | ||
| } | ||
| matches = std::move(next_matches); | ||
| } | ||
|
|
||
| return matches.prefixes; | ||
| } | ||
|
|
||
| const DataType* parent_; | ||
| const FieldVector& fields_; | ||
| }; | ||
|
|
||
| return std::visit(Visitor{fields}, impl_); | ||
| return std::visit(Visitor{parent, fields}, impl_); | ||
| } | ||
|
|
||
| std::vector<FieldPath> FieldRef::FindAll(const ArrayData& array) const { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In theory, we could check the index for fixed sized list arrays as well (not for variable size list arrays, though, so not sure that is worth it)