26 changes: 22 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -447,6 +447,15 @@ struct StructFieldFunctor {
union_array.GetFlattenedField(index, ctx->memory_pool()));
break;
}
case Type::LIST:
case Type::LARGE_LIST:
case Type::FIXED_SIZE_LIST: {
Datum idx(index);
ARROW_ASSIGN_OR_RAISE(Datum result,
CallFunction("list_element", {*current, idx}));
current = result.make_array();
break;
}
default:
// Should have been checked in ResolveStructFieldType
return Status::TypeError("struct_field: cannot reference child field of type ",
@@ -460,17 +469,21 @@
static Status CheckIndex(int index, const DataType& type) {
if (!ValidParentType(type)) {
return Status::TypeError("struct_field: cannot subscript field of type ", type);
} else if (index < 0 || index >= type.num_fields()) {
} else if (!IsBaseListType(type) && (index < 0 || index >= type.num_fields())) {
Reviewer comment (Member): In theory, we could check the index for fixed sized list arrays as well (not for variable size list arrays, though, so not sure that is worth it)

return Status::Invalid("struct_field: out-of-bounds field reference to field ",
index, " in type ", type, " with ", type.num_fields(),
" fields");
}
return Status::OK();
}
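Editor's note: a minimal sketch of the reviewer's suggestion above, assuming we only bounds-check fixed-size lists (whose length is known from the type) inside CheckIndex. This is not part of the PR; it assumes checked_cast is in scope as elsewhere in this file and uses FixedSizeListType::list_size() from the existing Arrow API:

```cpp
// Hypothetical variant of CheckIndex: bounds-check FIXED_SIZE_LIST against its
// list_size(), but leave variable-size LIST/LARGE_LIST unchecked as in the PR.
static Status CheckIndex(int index, const DataType& type) {
  if (!ValidParentType(type)) {
    return Status::TypeError("struct_field: cannot subscript field of type ", type);
  }
  if (type.id() == Type::FIXED_SIZE_LIST) {
    const auto& fsl = checked_cast<const FixedSizeListType&>(type);
    if (index < 0 || index >= fsl.list_size()) {
      return Status::Invalid("struct_field: out-of-bounds list element reference ",
                             index, " in type ", type, " with list size ",
                             fsl.list_size());
    }
  } else if (!IsBaseListType(type) && (index < 0 || index >= type.num_fields())) {
    return Status::Invalid("struct_field: out-of-bounds field reference to field ",
                           index, " in type ", type, " with ", type.num_fields(),
                           " fields");
  }
  return Status::OK();
}
```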

static bool IsBaseListType(const DataType& type) {
return dynamic_cast<const BaseListType*>(&type) != nullptr;
}
Reviewer comment (Member) on lines +480 to +482: We have is_list_like helper defined (although that also includes Map type, not sure if the list_element kernel supports that)

Reviewer comment (Member): Also, if we keep this, using a combination of a few type.id() == .. checks (like ValidParentType just below) might be more readable to know which types are included)
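Editor's note: hedged sketches of the two alternatives the reviewers mention; the helper names (IsListLikeViaTrait, IsListLikeViaIds) are illustrative, and only one of them would replace IsBaseListType:

```cpp
// (a) Reuse the is_list_like type trait from arrow/type_traits.h. Note that it
//     also returns true for Type::MAP, which is only safe if the list_element
//     kernel accepts map arrays.
static bool IsListLikeViaTrait(const DataType& type) { return is_list_like(type.id()); }

// (b) Spell out the type ids, mirroring the style of ValidParentType below.
static bool IsListLikeViaIds(const DataType& type) {
  return type.id() == Type::LIST || type.id() == Type::LARGE_LIST ||
         type.id() == Type::FIXED_SIZE_LIST;
}
```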


static bool ValidParentType(const DataType& type) {
return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION ||
type.id() == Type::SPARSE_UNION;
type.id() == Type::SPARSE_UNION || IsBaseListType(type);
}
};

@@ -487,8 +500,13 @@ Result<TypeHolder> ResolveStructFieldType(KernelContext* ctx,
}

for (const auto& index : field_path.indices()) {
RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
type = type->field(index)->type().get();
if (StructFieldFunctor::IsBaseListType(*type->GetSharedPtr())) {
auto list_type = checked_cast<const BaseListType*>(type);
type = list_type->value_type().get();
} else {
RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
type = type->field(index)->type().get();
}
}
return type;
}
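Editor's note: to make the intended end-to-end behaviour concrete, here is a hedged usage sketch (not from this PR; it assumes the public builder and compute APIs as documented, and the expected output assumes the list handling added above). It builds a list<struct<a: int32>> array and calls struct_field with indices {1, 0}, which with this change should take list element 1 of each list and then field "a" of the resulting struct:

```cpp
#include <arrow/api.h>
#include <arrow/compute/api.h>

#include <iostream>
#include <memory>
#include <vector>

arrow::Status RunExample() {
  auto pool = arrow::default_memory_pool();
  auto struct_type = arrow::struct_({arrow::field("a", arrow::int32())});

  // Builders for list<struct<a: int32>>.
  auto int_builder = std::make_shared<arrow::Int32Builder>(pool);
  auto struct_builder = std::make_shared<arrow::StructBuilder>(
      struct_type, pool,
      std::vector<std::shared_ptr<arrow::ArrayBuilder>>{int_builder});
  arrow::ListBuilder list_builder(pool, struct_builder);

  // Build [[{a: 1}, {a: 2}], [{a: 3}, {a: 4}]].
  for (int base : {1, 3}) {
    ARROW_RETURN_NOT_OK(list_builder.Append());
    for (int v : {base, base + 1}) {
      ARROW_RETURN_NOT_OK(struct_builder->Append());
      ARROW_RETURN_NOT_OK(int_builder->Append(v));
    }
  }
  std::shared_ptr<arrow::Array> array;
  ARROW_RETURN_NOT_OK(list_builder.Finish(&array));

  // Index 1 is consumed by the list level (via list_element); index 0 then
  // selects field "a" of the resulting struct array.
  arrow::compute::StructFieldOptions options(std::vector<int>{1, 0});
  ARROW_ASSIGN_OR_RAISE(
      arrow::Datum out,
      arrow::compute::CallFunction("struct_field", {array}, &options));
  std::cout << out.make_array()->ToString() << std::endl;  // expected: [2, 4]
  return arrow::Status::OK();
}

int main() {
  auto st = RunExample();
  if (!st.ok()) {
    std::cerr << st.ToString() << std::endl;
    return 1;
  }
  return 0;
}
```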
137 changes: 84 additions & 53 deletions cpp/src/arrow/type.cc
@@ -1060,19 +1060,40 @@ struct FieldPathGetImpl {
}

template <typename T, typename GetChildren>
static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
GetChildren&& get_children, int* out_of_range_depth) {
static Result<T> Get(const FieldPath* path, const DataType* parent,
const std::vector<T>* children, GetChildren&& get_children,
int* out_of_range_depth) {
if (path->indices().empty()) {
return Status::Invalid("empty indices cannot be traversed");
}

int depth = 0;
const T* out;
const T* out = nullptr;
for (int index : path->indices()) {
if (children == nullptr) {
return Status::NotImplemented("Get child data of non-struct array");
}

if constexpr (std::is_same_v<T, std::shared_ptr<arrow::Field>>) {
auto IsBaseListType = [](const DataType& type) {
return dynamic_cast<const BaseListType*>(&type) != nullptr;
};

// For lists, we don't care about the index; jump right into the list value type.
// The index here is used in the kernel to grab the specific list element.
// Children then are the fields of the list type, and out will be the children of
// that field, hence index 0.
if (out != nullptr && IsBaseListType(*(*out)->type())) {
children = get_children(*out);
Reviewer comment (Member): I don't fully understand why we need the get_children here (because then we call get_children twice in a single iteration? since there is another call a bit below)
Can you add a clarifying comment about that?

Reviewer comment (Member): (or to put it differently: why isn't this initial get_children needed when parent is a struct type? What's the difference between both?)

index = 0;
} else if (out == nullptr && parent != nullptr && IsBaseListType(*parent)) {
// For the first iteration (out == nullptr), if the parent is a list type
// then we certainly want to get the type of the list item. The specific
// kernel implementation will use the actual index to grab a specific list item.
index = 0;
}
}

if (index < 0 || static_cast<size_t>(index) >= children->size()) {
*out_of_range_depth = depth;
return nullptr;
@@ -1082,34 +1103,34 @@ struct FieldPathGetImpl {
children = get_children(*out);
++depth;
}

return *out;
}

template <typename T, typename GetChildren>
static Result<T> Get(const FieldPath* path, const std::vector<T>* children,
GetChildren&& get_children) {
static Result<T> Get(const FieldPath* path, const DataType* parent,
const std::vector<T>* children, GetChildren&& get_children) {
int out_of_range_depth = -1;
ARROW_ASSIGN_OR_RAISE(auto child,
Get(path, children, std::forward<GetChildren>(get_children),
&out_of_range_depth));
ARROW_ASSIGN_OR_RAISE(
auto child, Get(path, parent, children, std::forward<GetChildren>(get_children),
&out_of_range_depth));
if (child != nullptr) {
return std::move(child);
}
return IndexError(path, out_of_range_depth, *children);
}

static Result<std::shared_ptr<Field>> Get(const FieldPath* path,
static Result<std::shared_ptr<Field>> Get(const FieldPath* path, const DataType* parent,
const FieldVector& fields) {
return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr<Field>& field) {
return &field->type()->fields();
});
return FieldPathGetImpl::Get(
path, parent, &fields,
[](const std::shared_ptr<Field>& field) { return &field->type()->fields(); });
}

static Result<std::shared_ptr<ArrayData>> Get(const FieldPath* path,
const DataType* parent,
const ArrayDataVector& child_data) {
return FieldPathGetImpl::Get(
path, &child_data,
path, parent, &child_data,
[](const std::shared_ptr<ArrayData>& data) -> const ArrayDataVector* {
if (data->type->id() != Type::STRUCT) {
return nullptr;
@@ -1120,19 +1141,20 @@
};

Result<std::shared_ptr<Field>> FieldPath::Get(const Schema& schema) const {
return FieldPathGetImpl::Get(this, schema.fields());
return FieldPathGetImpl::Get(this, nullptr, schema.fields());
}

Result<std::shared_ptr<Field>> FieldPath::Get(const Field& field) const {
return FieldPathGetImpl::Get(this, field.type()->fields());
return FieldPathGetImpl::Get(this, &*field.type(), field.type()->fields());
}

Result<std::shared_ptr<Field>> FieldPath::Get(const DataType& type) const {
return FieldPathGetImpl::Get(this, type.fields());
return FieldPathGetImpl::Get(this, &type, type.fields());
}

Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields) const {
return FieldPathGetImpl::Get(this, fields);
Result<std::shared_ptr<Field>> FieldPath::Get(const FieldVector& fields,
const DataType* parent) const {
return FieldPathGetImpl::Get(this, parent, fields);
}

Result<std::shared_ptr<Schema>> FieldPath::GetAll(const Schema& schm,
@@ -1147,7 +1169,8 @@ Result<std::shared_ptr<Schema>> FieldPath::GetAll(const Schema& schm,
}

Result<std::shared_ptr<Array>> FieldPath::Get(const RecordBatch& batch) const {
ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data()));
ARROW_ASSIGN_OR_RAISE(auto data,
FieldPathGetImpl::Get(this, nullptr, batch.column_data()));
return MakeArray(std::move(data));
}

@@ -1160,7 +1183,7 @@ Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
if (data.type->id() != Type::STRUCT) {
return Status::NotImplemented("Get child data of non-struct array");
}
return FieldPathGetImpl::Get(this, data.child_data);
return FieldPathGetImpl::Get(this, &*data.type, data.child_data);
}

FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {}
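Editor's note: a small, hedged illustration of the parent-aware FieldPath::Get at the type level (not from the PR). With a list parent the leading index is not checked against num_fields(); the path simply descends into the list's value type, so any integer works there for type resolution:

```cpp
#include <arrow/api.h>
#include <iostream>

int main() {
  // list<struct<a: int32, b: utf8>>
  auto struct_type = arrow::struct_(
      {arrow::field("a", arrow::int32()), arrow::field("b", arrow::utf8())});
  auto list_type = arrow::list(struct_type);

  // The leading index addresses a list element (resolved to the value type here);
  // the trailing index then selects a struct field.
  arrow::FieldPath path({3, 1});
  auto maybe_field = path.Get(*list_type);
  if (maybe_field.ok()) {
    std::cout << (*maybe_field)->ToString() << std::endl;  // expected: b: string
  } else {
    std::cout << maybe_field.status().ToString() << std::endl;
  }
  return 0;
}
```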
@@ -1376,29 +1399,29 @@ std::vector<FieldPath> FieldRef::FindAll(const Schema& schema) const {
return internal::MapVector([](int i) { return FieldPath{i}; },
schema.GetAllFieldIndices(*name));
}
return FindAll(schema.fields());
return FindAll(schema.fields(), nullptr);
}

std::vector<FieldPath> FieldRef::FindAll(const Field& field) const {
return FindAll(field.type()->fields());
return FindAll(field.type()->fields(), &*field.type());
}

std::vector<FieldPath> FieldRef::FindAll(const DataType& type) const {
return FindAll(type.fields());
return FindAll(type.fields(), &type);
}

std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields,
const DataType* parent) const {
struct Visitor {
std::vector<FieldPath> operator()(const FieldPath& path) {
// skip long IndexError construction if path is out of range
int out_of_range_depth;
auto maybe_field = FieldPathGetImpl::Get(
&path, &fields_,
&path, parent_, &fields_,
[](const std::shared_ptr<Field>& field) { return &field->type()->fields(); },
&out_of_range_depth);

DCHECK_OK(maybe_field.status());

if (maybe_field.ValueOrDie() != nullptr) {
return {path};
}
@@ -1417,26 +1440,51 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
return out;
}

std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) {
DCHECK_GE(refs.size(), 1);
Matches matches(refs.front().FindAll(fields_, parent_), fields_, parent_);

for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) {
Matches next_matches;
for (size_t i = 0; i < matches.size(); ++i) {
const auto& referent = *matches.referents[i];

for (const FieldPath& match : ref_it->FindAll(referent)) {
next_matches.Add(matches.prefixes[i], match, referent.type()->fields(),
&*referent.type());
}
}
matches = std::move(next_matches);
}

return matches.prefixes;
}

struct Matches {
// referents[i] is referenced by prefixes[i]
std::vector<FieldPath> prefixes;
FieldVector referents;

Matches(std::vector<FieldPath> matches, const FieldVector& fields) {
Matches(std::vector<FieldPath> matches, const FieldVector& fields,
const DataType* parent) {
auto current_parent = parent;
for (auto& match : matches) {
Add({}, std::move(match), fields);
current_parent = Add({}, std::move(match), fields, current_parent);
}
}

Matches() = default;

size_t size() const { return referents.size(); }

void Add(const FieldPath& prefix, const FieldPath& suffix,
const FieldVector& fields) {
auto maybe_field = suffix.Get(fields);
const DataType* Add(const FieldPath& prefix, const FieldPath& suffix,
const FieldVector& fields, const DataType* parent = NULLPTR) {
auto maybe_field = suffix.Get(fields, parent);

DCHECK_OK(maybe_field.status());
referents.push_back(std::move(maybe_field).ValueOrDie());
auto field = maybe_field.ValueOrDie();
auto field_type = field->type();
referents.push_back(std::move(field));

std::vector<int> concatenated_indices(prefix.indices().size() +
suffix.indices().size());
@@ -1445,32 +1493,15 @@ std::vector<FieldPath> FieldRef::FindAll(const FieldVector& fields) const {
it = std::copy(path->indices().begin(), path->indices().end(), it);
}
prefixes.emplace_back(std::move(concatenated_indices));
return &*field_type;
}
};

std::vector<FieldPath> operator()(const std::vector<FieldRef>& refs) {
DCHECK_GE(refs.size(), 1);
Matches matches(refs.front().FindAll(fields_), fields_);

for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) {
Matches next_matches;
for (size_t i = 0; i < matches.size(); ++i) {
const auto& referent = *matches.referents[i];

for (const FieldPath& match : ref_it->FindAll(referent)) {
next_matches.Add(matches.prefixes[i], match, referent.type()->fields());
}
}
matches = std::move(next_matches);
}

return matches.prefixes;
}

const DataType* parent_;
const FieldVector& fields_;
};

return std::visit(Visitor{fields}, impl_);
return std::visit(Visitor{parent, fields}, impl_);
}
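Editor's note: a hedged sketch of the nested-FieldRef path through the updated Visitor (the types and the expected output are illustrative, not taken from the PR): a leading index ref is consumed by the list level, after which a name ref is resolved against the list's value struct:

```cpp
#include <arrow/api.h>
#include <iostream>

int main() {
  auto struct_type = arrow::struct_(
      {arrow::field("a", arrow::int32()), arrow::field("b", arrow::utf8())});
  auto list_type = arrow::list(struct_type);

  // Nested ref: first an index (consumed by the list level), then a name
  // resolved against the list's value struct.
  arrow::FieldRef ref(arrow::FieldPath({0}), "b");
  for (const auto& path : ref.FindAll(*list_type)) {
    std::cout << path.ToString() << std::endl;  // expected: a path equivalent to {0, 1}
  }
  return 0;
}
```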

std::vector<FieldPath> FieldRef::FindAll(const ArrayData& array) const {
6 changes: 4 additions & 2 deletions cpp/src/arrow/type.h
@@ -1624,7 +1624,8 @@ class ARROW_EXPORT FieldPath {
Result<std::shared_ptr<Field>> Get(const Schema& schema) const;
Result<std::shared_ptr<Field>> Get(const Field& field) const;
Result<std::shared_ptr<Field>> Get(const DataType& type) const;
Result<std::shared_ptr<Field>> Get(const FieldVector& fields) const;
Result<std::shared_ptr<Field>> Get(const FieldVector& fields,
const DataType* parent = NULLPTR) const;

static Result<std::shared_ptr<Schema>> GetAll(const Schema& schema,
const std::vector<FieldPath>& paths);
@@ -1762,7 +1763,8 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable<FieldRef> {
std::vector<FieldPath> FindAll(const Schema& schema) const;
std::vector<FieldPath> FindAll(const Field& field) const;
std::vector<FieldPath> FindAll(const DataType& type) const;
std::vector<FieldPath> FindAll(const FieldVector& fields) const;
std::vector<FieldPath> FindAll(const FieldVector& fields,
const DataType* parent = NULLPTR) const;

/// \brief Convenience function which applies FindAll to arg's type or schema.
std::vector<FieldPath> FindAll(const ArrayData& array) const;