diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 9c59c8c650a..eaabcf2f6b1 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -426,8 +426,6 @@ class ARROW_EXPORT Array { ARROW_DISALLOW_COPY_AND_ASSIGN(Array); }; -using ArrayVector = std::vector>; - namespace internal { /// Given a number of ArrayVectors, treat each ArrayVector as the diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h index 465ca4ba160..fbc7b6b2148 100644 --- a/cpp/src/arrow/buffer.h +++ b/cpp/src/arrow/buffer.h @@ -332,8 +332,6 @@ class ARROW_EXPORT Buffer { ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer); }; -using BufferVector = std::vector>; - /// \defgroup buffer-slicing-functions Functions for slicing buffers /// /// @{ diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 269f5ae6f30..910695fc5e1 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -54,9 +54,10 @@ static inline FragmentIterator GetFragmentsFromDatasets( inline std::shared_ptr SchemaFromColumnNames( const std::shared_ptr& input, const std::vector& column_names) { std::vector> columns; - for (const auto& name : column_names) { - if (auto field = input->GetFieldByName(name)) { - columns.push_back(std::move(field)); + for (FieldRef ref : column_names) { + auto maybe_field = ref.GetOne(*input); + if (maybe_field.ok()) { + columns.push_back(std::move(maybe_field).ValueOrDie()); } } diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index ea7ab869b1b..83a607308e4 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -898,7 +898,8 @@ Result> ScalarExpression::Validate(const Schema& schem } Result> FieldExpression::Validate(const Schema& schema) const { - if (auto field = schema.GetFieldByName(name_)) { + ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema)); + if (field != nullptr) { return field->type(); } return null(); diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index eea841f2c78..8879263de35 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -71,7 +71,7 @@ Result> SegmentDictionaryPartitioning::Parse( Result> KeyValuePartitioning::ConvertKey( const Key& key, const Schema& schema) { - auto field = schema.GetFieldByName(key.name); + ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(key.name).GetOneOrNone(schema)); if (field == nullptr) { return scalar(true); } @@ -141,10 +141,8 @@ class DirectoryPartitioningFactory : public PartitioningFactory { Result> Finish( const std::shared_ptr& schema) const override { - for (const auto& field_name : field_names_) { - if (schema->GetFieldIndex(field_name) == -1) { - return Status::TypeError("no field named '", field_name, "' in schema", *schema); - } + for (FieldRef ref : field_names_) { + RETURN_NOT_OK(ref.FindOne(*schema).status()); } // drop fields which aren't in field_names_ diff --git a/cpp/src/arrow/dataset/projector.cc b/cpp/src/arrow/dataset/projector.cc index 7148c333d84..e7f1b38381a 100644 --- a/cpp/src/arrow/dataset/projector.cc +++ b/cpp/src/arrow/dataset/projector.cc @@ -40,8 +40,15 @@ RecordBatchProjector::RecordBatchProjector(std::shared_ptr to) column_indices_(to_->num_fields(), kNoMatch), scalars_(to_->num_fields(), nullptr) {} -Status RecordBatchProjector::SetDefaultValue(int index, std::shared_ptr scalar) { +Status RecordBatchProjector::SetDefaultValue(FieldRef ref, + std::shared_ptr scalar) { DCHECK_NE(scalar, nullptr); + if (ref.IsNested()) { + return Status::NotImplemented("setting default values for nested columns"); + } + + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(*to_)); + auto index = match[0]; auto field_type = to_->field(index)->type(); if (!field_type->Equals(scalar->type)) { @@ -83,9 +90,19 @@ Status RecordBatchProjector::SetInputSchema(std::shared_ptr from, for (int i = 0; i < to_->num_fields(); ++i) { const auto& field = to_->field(i); - int matching_index = from_->GetFieldIndex(field->name()); + FieldRef ref(field->name()); + auto matches = ref.FindAll(*from_); + + if (matches.empty()) { + // Mark column i as missing by setting missing_columns_[i] + // to a non-null placeholder. + RETURN_NOT_OK( + MakeArrayOfNull(pool, to_->field(i)->type(), 0, &missing_columns_[i])); + column_indices_[i] = kNoMatch; + } else { + RETURN_NOT_OK(ref.CheckNonMultiple(matches, *from_)); + int matching_index = matches[0][0]; - if (matching_index != kNoMatch) { if (!from_->field(matching_index)->Equals(field)) { return Status::TypeError("fields had matching names but were not equivalent ", from_->field(matching_index)->ToString(), " vs ", @@ -94,14 +111,8 @@ Status RecordBatchProjector::SetInputSchema(std::shared_ptr from, // Mark column i as not missing by setting missing_columns_[i] to nullptr missing_columns_[i] = nullptr; - } else { - // Mark column i as missing by setting missing_columns_[i] - // to a non-null placeholder. - RETURN_NOT_OK( - MakeArrayOfNull(pool, to_->field(i)->type(), 0, &missing_columns_[i])); + column_indices_[i] = matching_index; } - - column_indices_[i] = matching_index; } return Status::OK(); } diff --git a/cpp/src/arrow/dataset/projector.h b/cpp/src/arrow/dataset/projector.h index b64f2923110..13a0ffb1938 100644 --- a/cpp/src/arrow/dataset/projector.h +++ b/cpp/src/arrow/dataset/projector.h @@ -18,8 +18,6 @@ #pragma once #include -#include -#include #include #include "arrow/dataset/type_fwd.h" @@ -48,7 +46,7 @@ class ARROW_DS_EXPORT RecordBatchProjector { /// If the indexed field is absent from a record batch it will be added to the projected /// record batch with all its slots equal to the provided scalar (instead of null). - Status SetDefaultValue(int index, std::shared_ptr scalar); + Status SetDefaultValue(FieldRef ref, std::shared_ptr scalar); Result> Project(const RecordBatch& batch, MemoryPool* pool = default_memory_pool()); @@ -63,6 +61,7 @@ class ARROW_DS_EXPORT RecordBatchProjector { std::shared_ptr from_, to_; int64_t missing_columns_length_ = 0; + // these vectors are indexed parallel to to_->fields() std::vector> missing_columns_; std::vector column_indices_; std::vector> scalars_; diff --git a/cpp/src/arrow/flight/perf_server.cc b/cpp/src/arrow/flight/perf_server.cc index 2d4492c2a09..c899076029d 100644 --- a/cpp/src/arrow/flight/perf_server.cc +++ b/cpp/src/arrow/flight/perf_server.cc @@ -54,8 +54,6 @@ namespace flight { } \ } while (0) -using ArrayVector = std::vector>; - // Create record batches with a unique "a" column so we can verify on the // client side that the results are correct class PerfDataStream : public FlightDataStream { diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index ffcbdc459d1..b1a1708ba6c 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -80,6 +80,8 @@ class SimpleRecordBatch : public RecordBatch { std::shared_ptr column_data(int i) const override { return columns_[i]; } + ArrayDataVector column_data() const override { return columns_; } + Status AddColumn(int i, const std::shared_ptr& field, const std::shared_ptr& column, std::shared_ptr* out) const override { diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 4a59c239b63..ada1ad7eef4 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -99,11 +99,14 @@ class ARROW_EXPORT RecordBatch { /// \return an Array or null if no field was found std::shared_ptr GetColumnByName(const std::string& name) const; - /// \brief Retrieve an array's internaldata from the record batch + /// \brief Retrieve an array's internal data from the record batch /// \param[in] i field index, does not boundscheck /// \return an internal ArrayData object virtual std::shared_ptr column_data(int i) const = 0; + /// \brief Retrieve all arrays' internal data from the record batch. + virtual ArrayDataVector column_data() const = 0; + /// \brief Add column to the record batch, producing a new RecordBatch /// /// \param[in] i field index, which will be boundschecked diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index e68f6811c59..2948a9299a7 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -41,25 +41,24 @@ using internal::checked_cast; // ---------------------------------------------------------------------- // ChunkedArray methods -ChunkedArray::ChunkedArray(const ArrayVector& chunks) : chunks_(chunks) { +ChunkedArray::ChunkedArray(ArrayVector chunks) : chunks_(std::move(chunks)) { length_ = 0; null_count_ = 0; - ARROW_CHECK_GT(chunks.size(), 0) + ARROW_CHECK_GT(chunks_.size(), 0) << "cannot construct ChunkedArray from empty vector and omitted type"; - type_ = chunks[0]->type(); - for (const std::shared_ptr& chunk : chunks) { + type_ = chunks_[0]->type(); + for (const std::shared_ptr& chunk : chunks_) { length_ += chunk->length(); null_count_ += chunk->null_count(); } } -ChunkedArray::ChunkedArray(const ArrayVector& chunks, - const std::shared_ptr& type) - : chunks_(chunks), type_(type) { +ChunkedArray::ChunkedArray(ArrayVector chunks, std::shared_ptr type) + : chunks_(std::move(chunks)), type_(std::move(type)) { length_ = 0; null_count_ = 0; - for (const std::shared_ptr& chunk : chunks) { + for (const std::shared_ptr& chunk : chunks_) { length_ += chunk->length(); null_count_ += chunk->null_count(); } diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index 4b106a16b9d..880573adc6e 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -30,8 +30,6 @@ namespace arrow { -using ArrayVector = std::vector>; - /// \class ChunkedArray /// \brief A data structure managing a list of primitive Arrow arrays logically /// as one large array @@ -41,7 +39,7 @@ class ARROW_EXPORT ChunkedArray { /// /// The vector must be non-empty and all its elements must have the same /// data type. - explicit ChunkedArray(const ArrayVector& chunks); + explicit ChunkedArray(ArrayVector chunks); /// \brief Construct a chunked array from a single Array explicit ChunkedArray(const std::shared_ptr& chunk) @@ -50,7 +48,7 @@ class ARROW_EXPORT ChunkedArray { /// \brief Construct a chunked array from a vector of arrays and a data type /// /// As the data type is passed explicitly, the vector may be empty. - ChunkedArray(const ArrayVector& chunks, const std::shared_ptr& type); + ChunkedArray(ArrayVector chunks, std::shared_ptr type); /// \return the total length of the chunked array; computed on construction int64_t length() const { return length_; } diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 364ecc58d2d..93ea12ddcf8 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -164,8 +164,6 @@ struct Datum; using Datum = compute::Datum; -using ArrayVector = std::vector>; - #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs)) #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs)) #define ASSERT_TABLES_EQUAL(lhs, rhs) AssertTablesEqual((lhs), (rhs)) diff --git a/cpp/src/arrow/testing/util.h b/cpp/src/arrow/testing/util.h index e67231158bd..3ddf097a3a8 100644 --- a/cpp/src/arrow/testing/util.h +++ b/cpp/src/arrow/testing/util.h @@ -39,14 +39,6 @@ namespace arrow { -class Array; -class ChunkedArray; -class MemoryPool; -class RecordBatch; -class Table; - -using ArrayVector = std::vector>; - template Status CopyBufferFromVector(const std::vector& values, MemoryPool* pool, std::shared_ptr* result) { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index b716033707c..108758927fe 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -30,9 +30,12 @@ #include "arrow/array.h" #include "arrow/compare.h" +#include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/table.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/hashing.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" @@ -608,6 +611,417 @@ std::string DictionaryType::ToString() const { std::string NullType::ToString() const { return name(); } +// ---------------------------------------------------------------------- +// FieldRef + +size_t FieldPath::hash() const { + return internal::ComputeStringHash<0>(data(), size() * sizeof(int)); +} + +std::string FieldPath::ToString() const { + std::string repr = "FieldPath("; + for (auto index : *this) { + repr += std::to_string(index) + " "; + } + repr.resize(repr.size() - 1); + repr += ")"; + return repr; +} + +struct FieldPathGetImpl { + static const DataType& GetType(const ArrayData& data) { return *data.type; } + + static const DataType& GetType(const ChunkedArray& array) { return *array.type(); } + + static void Summarize(const FieldVector& fields, std::stringstream* ss) { + *ss << "{ "; + for (const auto& field : fields) { + *ss << field->ToString() << ", "; + } + *ss << "}"; + } + + template + static void Summarize(const std::vector& columns, std::stringstream* ss) { + *ss << "{ "; + for (const auto& column : columns) { + *ss << GetType(*column) << ", "; + } + *ss << "}"; + } + + template + static Status IndexError(const FieldPath* path, int out_of_range_depth, + const std::vector& children) { + std::stringstream ss; + ss << "index out of range. "; + + ss << "indices=[ "; + int depth = 0; + for (int i : *path) { + if (depth != out_of_range_depth) { + ss << i << " "; + continue; + } + ss << ">" << i << "< "; + ++depth; + } + ss << "] "; + + if (std::is_same>::value) { + ss << "fields were: "; + } else { + ss << "columns had types: "; + } + Summarize(children, &ss); + + return Status::IndexError(ss.str()); + } + + template + static Result Get(const FieldPath* path, const std::vector* children, + GetChildren&& get_children, int* out_of_range_depth) { + if (path->empty()) { + return Status::Invalid("empty indices cannot be traversed"); + } + + int depth = 0; + const T* out; + for (int index : *path) { + if (index < 0 || static_cast(index) >= children->size()) { + *out_of_range_depth = depth; + return nullptr; + } + + out = &children->at(index); + children = get_children(*out); + ++depth; + } + + return *out; + } + + template + static Result Get(const FieldPath* path, const std::vector* children, + GetChildren&& get_children) { + int out_of_range_depth; + ARROW_ASSIGN_OR_RAISE(auto child, + Get(path, children, std::forward(get_children), + &out_of_range_depth)); + if (child != nullptr) { + return std::move(child); + } + return IndexError(path, out_of_range_depth, *children); + } + + static Result> Get(const FieldPath* path, + const FieldVector& fields) { + return FieldPathGetImpl::Get(path, &fields, [](const std::shared_ptr& field) { + return &field->type()->children(); + }); + } + + static Result> Get(const FieldPath* path, + const ArrayDataVector& child_data) { + return FieldPathGetImpl::Get( + path, &child_data, + [](const std::shared_ptr& data) { return &data->child_data; }); + } + + static Result> Get( + const FieldPath* path, const ChunkedArrayVector& columns_arg) { + ChunkedArrayVector columns = columns_arg; + + return FieldPathGetImpl::Get( + path, &columns, [&](const std::shared_ptr& a) { + columns.clear(); + + for (int i = 0; i < a->type()->num_children(); ++i) { + ArrayVector child_chunks; + + for (const auto& chunk : a->chunks()) { + auto child_chunk = MakeArray(chunk->data()->child_data[i]); + child_chunks.push_back(std::move(child_chunk)); + } + + auto child_column = std::make_shared( + std::move(child_chunks), a->type()->child(i)->type()); + + columns.emplace_back(std::move(child_column)); + } + + return &columns; + }); + } +}; + +Result> FieldPath::Get(const Schema& schema) const { + return FieldPathGetImpl::Get(this, schema.fields()); +} + +Result> FieldPath::Get(const Field& field) const { + return FieldPathGetImpl::Get(this, field.type()->children()); +} + +Result> FieldPath::Get(const DataType& type) const { + return FieldPathGetImpl::Get(this, type.children()); +} + +Result> FieldPath::Get(const FieldVector& fields) const { + return FieldPathGetImpl::Get(this, fields); +} + +Result> FieldPath::Get(const RecordBatch& batch) const { + ARROW_ASSIGN_OR_RAISE(auto data, FieldPathGetImpl::Get(this, batch.column_data())); + return MakeArray(data); +} + +Result> FieldPath::Get(const Table& table) const { + return FieldPathGetImpl::Get(this, table.columns()); +} + +FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { + DCHECK_GT(util::get(impl_).size(), 0); +} + +void FieldRef::Flatten(std::vector children) { + // flatten children + struct Visitor { + void operator()(std::string&& name) { *out++ = FieldRef(std::move(name)); } + + void operator()(FieldPath&& indices) { *out++ = FieldRef(std::move(indices)); } + + void operator()(std::vector&& children) { + for (auto& child : children) { + util::visit(*this, std::move(child.impl_)); + } + } + + std::back_insert_iterator> out; + }; + + std::vector out; + Visitor visitor{std::back_inserter(out)}; + visitor(std::move(children)); + + DCHECK(!out.empty()); + DCHECK(std::none_of(out.begin(), out.end(), + [](const FieldRef& ref) { return ref.IsNested(); })); + + if (out.size() == 1) { + impl_ = std::move(out[0].impl_); + } else { + impl_ = std::move(out); + } +} + +Result FieldRef::FromDotPath(const std::string& dot_path_arg) { + if (dot_path_arg.empty()) { + return Status::Invalid("Dot path was empty"); + } + + std::vector children; + + util::string_view dot_path = dot_path_arg; + + auto parse_name = [&] { + std::string name; + for (;;) { + auto segment_end = dot_path.find_first_of("\\[."); + if (segment_end == util::string_view::npos) { + // dot_path doesn't contain any other special characters; consume all + name.append(dot_path.begin(), dot_path.end()); + dot_path = ""; + break; + } + + if (dot_path[segment_end] != '\\') { + // segment_end points to a subscript for a new FieldRef + name.append(dot_path.begin(), segment_end); + dot_path = dot_path.substr(segment_end); + break; + } + + if (dot_path.size() == segment_end + 1) { + // dot_path ends with backslash; consume it all + name.append(dot_path.begin(), dot_path.end()); + dot_path = ""; + break; + } + + // append all characters before backslash, then the character which follows it + name.append(dot_path.begin(), segment_end); + name.push_back(dot_path[segment_end + 1]); + dot_path = dot_path.substr(segment_end + 2); + } + return name; + }; + + while (!dot_path.empty()) { + auto subscript = dot_path[0]; + dot_path = dot_path.substr(1); + switch (subscript) { + case '.': { + // next element is a name + children.emplace_back(parse_name()); + continue; + } + case '[': { + auto subscript_end = dot_path.find_first_not_of("0123456789"); + if (subscript_end == util::string_view::npos || dot_path[subscript_end] != ']') { + return Status::Invalid("Dot path '", dot_path_arg, + "' contained an unterminated index"); + } + children.emplace_back(std::atoi(dot_path.data())); + dot_path = dot_path.substr(subscript_end + 1); + continue; + } + default: + return Status::Invalid("Dot path must begin with '[' or '.', got '", dot_path_arg, + "'"); + } + } + + FieldRef out; + out.Flatten(std::move(children)); + return out; +} + +size_t FieldRef::hash() const { + struct Visitor : std::hash { + using std::hash::operator(); + + size_t operator()(const FieldPath& path) { return path.hash(); } + + size_t operator()(const std::vector& children) { + size_t hash = 0; + + for (const FieldRef& child : children) { + hash ^= child.hash(); + } + + return hash; + } + }; + + return util::visit(Visitor{}, impl_); +} + +std::string FieldRef::ToString() const { + struct Visitor { + std::string operator()(const FieldPath& path) { return path.ToString(); } + + std::string operator()(const std::string& name) { return "Name(" + name + ")"; } + + std::string operator()(const std::vector& children) { + std::string repr = "Nested("; + for (const auto& child : children) { + repr += child.ToString() + " "; + } + repr.resize(repr.size() - 1); + repr += ")"; + return repr; + } + }; + + return "FieldRef." + util::visit(Visitor{}, impl_); +} + +std::vector FieldRef::FindAll(const Schema& schema) const { + return FindAll(schema.fields()); +} + +std::vector FieldRef::FindAll(const Field& field) const { + return FindAll(field.type()->children()); +} + +std::vector FieldRef::FindAll(const DataType& type) const { + return FindAll(type.children()); +} + +std::vector FieldRef::FindAll(const FieldVector& fields) const { + struct Visitor { + std::vector operator()(const FieldPath& path) { + // skip long IndexError construction if path is out of range + int out_of_range_depth; + auto maybe_field = FieldPathGetImpl::Get( + &path, &fields_, + [](const std::shared_ptr& field) { return &field->type()->children(); }, + &out_of_range_depth); + + DCHECK_OK(maybe_field.status()); + + if (maybe_field.ValueOrDie() != nullptr) { + return {path}; + } + return {}; + } + + std::vector operator()(const std::string& name) { + std::vector out; + + for (int i = 0; i < static_cast(fields_.size()); ++i) { + if (fields_[i]->name() == name) { + out.push_back({i}); + } + } + + return out; + } + + struct Matches { + // referents[i] is referenced by prefixes[i] + std::vector prefixes; + FieldVector referents; + + Matches(std::vector matches, const FieldVector& fields) { + for (auto& match : matches) { + Add({}, std::move(match), fields); + } + } + + Matches() = default; + + size_t size() const { return referents.size(); } + + void Add(FieldPath prefix, const FieldPath& match, const FieldVector& fields) { + auto maybe_field = match.Get(fields); + DCHECK_OK(maybe_field.status()); + + prefix.resize(prefix.size() + match.size()); + std::copy(match.begin(), match.end(), prefix.end() - match.size()); + prefixes.push_back(std::move(prefix)); + referents.push_back(std::move(maybe_field).ValueOrDie()); + } + }; + + std::vector operator()(const std::vector& refs) { + DCHECK_GE(refs.size(), 1); + Matches matches(refs.front().FindAll(fields_), fields_); + + for (auto ref_it = refs.begin() + 1; ref_it != refs.end(); ++ref_it) { + Matches next_matches; + for (size_t i = 0; i < matches.size(); ++i) { + const auto& referent = *matches.referents[i]; + + for (const FieldPath& match : ref_it->FindAll(referent)) { + next_matches.Add(matches.prefixes[i], match, referent.type()->children()); + } + } + matches = std::move(next_matches); + } + + return matches.prefixes; + } + + const FieldVector& fields_; + }; + + return util::visit(Visitor{fields}, impl_); +} + +void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); } + // ---------------------------------------------------------------------- // Schema implementation @@ -636,7 +1050,7 @@ Schema::~Schema() {} int Schema::num_fields() const { return static_cast(impl_->fields_.size()); } -std::shared_ptr Schema::field(int i) const { +const std::shared_ptr& Schema::field(int i) const { DCHECK_GE(i, 0); DCHECK_LT(i, num_fields()); return impl_->fields_[i]; diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 7cd1b651994..7f58369b38b 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -26,9 +26,11 @@ #include #include +#include "arrow/result.h" #include "arrow/type_fwd.h" // IWYU pragma: export #include "arrow/util/checked_cast.h" #include "arrow/util/macros.h" +#include "arrow/util/variant.h" #include "arrow/util/visibility.h" #include "arrow/visitor.h" // IWYU pragma: keep @@ -243,7 +245,7 @@ class ARROW_EXPORT DataType : public detail::Fingerprintable { /// \brief Return whether the types are equal bool Equals(const std::shared_ptr& other) const; - std::shared_ptr child(int i) const { return children_[i]; } + const std::shared_ptr& child(int i) const { return children_[i]; } const std::vector>& children() const { return children_; } @@ -1398,6 +1400,238 @@ class ARROW_EXPORT DictionaryUnifier { std::shared_ptr* out_dict) = 0; }; +// ---------------------------------------------------------------------- +// FieldRef + +/// \class FieldPath +/// +/// Represents a path to a nested field using indices of child fields. +/// For example, given indices {5, 9, 3} the field would be retrieved with +/// schema->field(5)->type()->child(9)->type()->child(3) +/// +/// Attempting to retrieve a child field using a FieldPath which is not valid for +/// a given schema will raise an error. Invalid FieldPaths include: +/// - an index is out of range +/// - the path is empty (note: a default constructed FieldPath will be empty) +/// +/// FieldPaths provide a number of accessors for drilling down to potentially nested +/// children. They are overloaded for convenience to support Schema (returns a field), +/// DataType (returns a child field), Field (returns a child field of this field's type) +/// Array (returns a child array), RecordBatch (returns a column), ChunkedArray (returns a +/// ChunkedArray where each chunk is a child array of the corresponding original chunk) +/// and Table (returns a column). +class ARROW_EXPORT FieldPath : public std::vector { + public: + using std::vector::vector; + + FieldPath() = default; + + FieldPath(std::vector indices) // NOLINT runtime/explicit + : std::vector(std::move(indices)) {} + + std::string ToString() const; + + size_t hash() const; + + explicit operator bool() const { return !empty(); } + + /// \brief Retrieve the referenced child Field from a Schema, Field, or DataType + Result> Get(const Schema& schema) const; + Result> Get(const Field& field) const; + Result> Get(const DataType& type) const; + Result> Get(const FieldVector& fields) const; + + /// \brief Retrieve the referenced column from a RecordBatch or Table + Result> Get(const RecordBatch& batch) const; + Result> Get(const Table& table) const; + + /// \brief Retrieve the referenced child Array from an Array or ChunkedArray + Result> Get(const Array& array) const; + Result> Get(const ChunkedArray& array) const; +}; + +/// \class FieldRef +/// \brief Descriptor of a (potentially nested) field within a schema. +/// +/// Unlike FieldPath (which exclusively uses indices of child fields), FieldRef may +/// reference a field by name. It is intended to replace parameters like `int field_index` +/// and `const std::string& field_name`; it can be implicitly constructed from either a +/// field index or a name. +/// +/// Nested fields can be referenced as well. Given +/// schema({field("a", struct_({field("n", null())})), field("b", int32())}) +/// +/// the following all indicate the nested field named "n": +/// FieldRef ref1(0, 0); +/// FieldRef ref2("a", 0); +/// FieldRef ref3("a", "n"); +/// FieldRef ref4(0, "n"); +/// ARROW_ASSIGN_OR_RAISE(FieldRef ref5, +/// FieldRef::FromDotPath(".a[0]")); +/// +/// FieldPaths matching a FieldRef are retrieved using the member function FindAll. +/// Multiple matches are possible because field names may be duplicated within a schema. +/// For example: +/// Schema a_is_ambiguous({field("a", int32()), field("a", float32())}); +/// auto matches = FieldRef("a").FindAll(a_is_ambiguous); +/// assert(matches.size() == 2); +/// assert(matches[0].Get(a_is_ambiguous)->Equals(a_is_ambiguous.field(0))); +/// assert(matches[1].Get(a_is_ambiguous)->Equals(a_is_ambiguous.field(1))); +/// +/// Convenience accessors are available which raise a helpful error if the field is not +/// found or ambiguous, and for immediately calling FieldPath::Get to retrieve any +/// matching children: +/// auto maybe_match = FieldRef("struct", "field_i32").FindOneOrNone(schema); +/// auto maybe_column = FieldRef("struct", "field_i32").GetOne(some_table); +class ARROW_EXPORT FieldRef { + public: + FieldRef() = default; + + /// Construct a FieldRef using a string of indices. The reference will be retrieved as: + /// schema.fields[self.indices[0]].type.fields[self.indices[1]] ... + /// + /// Empty indices are not valid. + FieldRef(FieldPath indices); // NOLINT runtime/explicit + + /// Construct a by-name FieldRef. Multiple fields may match a by-name FieldRef: + /// [f for f in schema.fields where f.name == self.name] + FieldRef(std::string name) : impl_(std::move(name)) {} // NOLINT runtime/explicit + + /// Equivalent to a single index string of indices. + FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit + + /// Convenience constructor for nested FieldRefs: each argument will be used to + /// construct a FieldRef + template + FieldRef(A0&& a0, A1&& a1, A&&... a) { + Flatten({// cpplint thinks the following are constructor decls + FieldRef(std::forward(a0)), // NOLINT runtime/explicit + FieldRef(std::forward(a1)), // NOLINT runtime/explicit + FieldRef(std::forward(a))...}); // NOLINT runtime/explicit + } + + /// Parse a dot path into a FieldRef. + /// + /// dot_path = '.' name + /// | '[' digit+ ']' + /// | dot_path+ + /// + /// Examples: + /// ".alpha" => FieldRef("alpha") + /// "[2]" => FieldRef(2) + /// ".beta[3]" => FieldRef("beta", 3) + /// "[5].gamma.delta[7]" => FieldRef(5, "gamma", "delta", 7) + /// ".hello world" => FieldRef("hello world") + /// R"(.\[y\]\\tho\.\)" => FieldRef(R"([y]\tho.\)") + /// + /// Note: When parsing a name, a '\' preceding any other character will be dropped from + /// the resulting name. Therefore if a name must contain the characters '.', '\', or '[' + /// those must be escaped with a preceding '\'. + static Result FromDotPath(const std::string& dot_path); + + bool Equals(const FieldRef& other) const { return impl_ == other.impl_; } + bool operator==(const FieldRef& other) const { return Equals(other); } + + std::string ToString() const; + + size_t hash() const; + + bool IsFieldPath() const { return util::holds_alternative(impl_); } + bool IsName() const { return util::holds_alternative(impl_); } + bool IsNested() const { + if (IsName()) return false; + if (IsFieldPath()) return util::get(impl_).size() > 1; + return true; + } + + /// \brief Retrieve FieldPath of every child field which matches this FieldRef. + std::vector FindAll(const Schema& schema) const; + std::vector FindAll(const Field& field) const; + std::vector FindAll(const DataType& type) const; + std::vector FindAll(const FieldVector& fields) const; + + /// \brief Convenience function: raise an error if matches is empty. + template + Status CheckNonEmpty(const std::vector& matches, const T& root) const { + if (matches.empty()) { + return Status::Invalid("No match for ", ToString(), " in ", root.ToString()); + } + return Status::OK(); + } + + /// \brief Convenience function: raise an error if matches contains multiple FieldPaths. + template + Status CheckNonMultiple(const std::vector& matches, const T& root) const { + if (matches.size() > 1) { + return Status::Invalid("Multiple matches for ", ToString(), " in ", + root.ToString()); + } + return Status::OK(); + } + + /// \brief Retrieve FieldPath of a single child field which matches this + /// FieldRef. Emit an error if none or multiple match. + template + Result FindOne(const T& root) const { + auto matches = FindAll(root); + ARROW_RETURN_NOT_OK(CheckNonEmpty(matches, root)); + ARROW_RETURN_NOT_OK(CheckNonMultiple(matches, root)); + return std::move(matches[0]); + } + + /// \brief Retrieve FieldPath of a single child field which matches this + /// FieldRef. Emit an error if multiple match. An empty (invalid) FieldPath + /// will be returned if none match. + template + Result FindOneOrNone(const T& root) const { + auto matches = FindAll(root); + ARROW_RETURN_NOT_OK(CheckNonMultiple(matches, root)); + if (matches.empty()) { + return FieldPath(); + } + return std::move(matches[0]); + } + + template + using GetType = decltype(std::declval().Get(std::declval()).ValueOrDie()); + + /// \brief Get all children matching this FieldRef. + template + std::vector> GetAll(const T& root) const { + std::vector> out; + for (const auto& match : FindAll(root)) { + out.push_back(match.Get(root).ValueOrDie()); + } + return out; + } + + /// \brief Get the single child matching this FieldRef. + /// Emit an error if none or multiple match. + template + Result> GetOne(const T& root) const { + ARROW_ASSIGN_OR_RAISE(auto match, FindOne(root)); + return match.Get(root).ValueOrDie(); + } + + /// \brief Get the single child matching this FieldRef. + /// Return nullptr if none match, emit an error if multiple match. + template + Result> GetOneOrNone(const T& root) const { + ARROW_ASSIGN_OR_RAISE(auto match, FindOneOrNone(root)); + if (match) { + return match.Get(root).ValueOrDie(); + } + return NULLPTR; + } + + private: + void Flatten(std::vector children); + + util::variant> impl_; + + ARROW_EXPORT friend void PrintTo(const FieldRef& ref, std::ostream* os); +}; + // ---------------------------------------------------------------------- // Schema @@ -1423,7 +1657,7 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, int num_fields() const; /// Return the ith schema element. Does not boundscheck - std::shared_ptr field(int i) const; + const std::shared_ptr& field(int i) const; const std::vector>& fields() const; diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 4ea81ee1937..918b28960d5 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -#ifndef ARROW_TYPE_FWD_H -#define ARROW_TYPE_FWD_H +#pragma once #include +#include #include "arrow/util/visibility.h" @@ -39,21 +39,31 @@ class MemoryPool; class MutableBuffer; class ResizableBuffer; +using BufferVector = std::vector>; + class DataType; class Field; +class FieldRef; class KeyValueMetadata; class Schema; +using FieldVector = std::vector>; + class Array; struct ArrayData; class ArrayBuilder; class Tensor; struct Scalar; +using ArrayDataVector = std::vector>; +using ArrayVector = std::vector>; + class ChunkedArray; class RecordBatch; class Table; +using ChunkedArrayVector = std::vector>; +using RecordBatchVector = std::vector>; using RecordBatchIterator = Iterator>; class DictionaryType; @@ -261,5 +271,3 @@ ARROW_EXPORT MemoryPool* default_memory_pool(); #endif } // namespace arrow - -#endif // ARROW_TYPE_FWD_H diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index f0a12dc7514..fd715f2b4fb 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -22,6 +22,8 @@ #include #include +#include + #include "arrow/memory_pool.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" @@ -32,6 +34,8 @@ namespace arrow { +using testing::ElementsAre; + using internal::checked_cast; using internal::checked_pointer_cast; @@ -311,6 +315,101 @@ TEST(TestField, TestMerge) { } } +TEST(TestFieldPath, Basics) { + auto f0 = field("alpha", int32()); + auto f1 = field("beta", int32()); + auto f2 = field("alpha", int32()); + auto f3 = field("beta", int32()); + Schema s({f0, f1, f2, f3}); + + // retrieving a field with single-element FieldPath is equivalent to Schema::field + for (int index = 0; index < s.num_fields(); ++index) { + ASSERT_OK_AND_EQ(s.field(index), FieldPath({index}).Get(s)); + } + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + testing::HasSubstr("empty indices cannot be traversed"), + FieldPath().Get(s)); + EXPECT_RAISES_WITH_MESSAGE_THAT(IndexError, testing::HasSubstr("index out of range"), + FieldPath({s.num_fields() * 2}).Get(s)); +} + +TEST(TestFieldRef, Basics) { + auto f0 = field("alpha", int32()); + auto f1 = field("beta", int32()); + auto f2 = field("alpha", int32()); + auto f3 = field("beta", int32()); + Schema s({f0, f1, f2, f3}); + + // lookup by index returns Indices{index} + for (int index = 0; index < s.num_fields(); ++index) { + EXPECT_THAT(FieldRef(index).FindAll(s), ElementsAre(FieldPath{index})); + } + // out of range index results in a failure to match + EXPECT_THAT(FieldRef(s.num_fields() * 2).FindAll(s), ElementsAre()); + + // lookup by name returns the Indices of both matching fields + EXPECT_THAT(FieldRef("alpha").FindAll(s), ElementsAre(FieldPath{0}, FieldPath{2})); + EXPECT_THAT(FieldRef("beta").FindAll(s), ElementsAre(FieldPath{1}, FieldPath{3})); +} + +TEST(TestFieldRef, FromDotPath) { + ASSERT_OK_AND_EQ(FieldRef("alpha"), FieldRef::FromDotPath(R"(.alpha)")); + + ASSERT_OK_AND_EQ(FieldRef("", ""), FieldRef::FromDotPath(R"(..)")); + + ASSERT_OK_AND_EQ(FieldRef(2), FieldRef::FromDotPath(R"([2])")); + + ASSERT_OK_AND_EQ(FieldRef("beta", 3), FieldRef::FromDotPath(R"(.beta[3])")); + + ASSERT_OK_AND_EQ(FieldRef(5, "gamma", "delta", 7), + FieldRef::FromDotPath(R"([5].gamma.delta[7])")); + + ASSERT_OK_AND_EQ(FieldRef("hello world"), FieldRef::FromDotPath(R"(.hello world)")); + + ASSERT_OK_AND_EQ(FieldRef(R"([y]\tho.\)"), FieldRef::FromDotPath(R"(.\[y\]\\tho\.\)")); + + ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"()")); + ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"(alpha)")); + ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([134234)")); + ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([1stuf])")); +} + +TEST(TestFieldPath, Nested) { + auto f0 = field("alpha", int32()); + auto f1_0 = field("alpha", int32()); + auto f1 = field("beta", struct_({f1_0})); + auto f2_0 = field("alpha", int32()); + auto f2_1_0 = field("alpha", int32()); + auto f2_1_1 = field("alpha", int32()); + auto f2_1 = field("gamma", struct_({f2_1_0, f2_1_1})); + auto f2 = field("beta", struct_({f2_0, f2_1})); + Schema s({f0, f1, f2}); + + // retrieving fields with nested indices + EXPECT_EQ(FieldPath({0}).Get(s), f0); + EXPECT_EQ(FieldPath({1, 0}).Get(s), f1_0); + EXPECT_EQ(FieldPath({2, 0}).Get(s), f2_0); + EXPECT_EQ(FieldPath({2, 1, 0}).Get(s), f2_1_0); + EXPECT_EQ(FieldPath({2, 1, 1}).Get(s), f2_1_1); +} + +TEST(TestFieldRef, Nested) { + auto f0 = field("alpha", int32()); + auto f1_0 = field("alpha", int32()); + auto f1 = field("beta", struct_({f1_0})); + auto f2_0 = field("alpha", int32()); + auto f2_1_0 = field("alpha", int32()); + auto f2_1_1 = field("alpha", int32()); + auto f2_1 = field("gamma", struct_({f2_1_0, f2_1_1})); + auto f2 = field("beta", struct_({f2_0, f2_1})); + Schema s({f0, f1, f2}); + + EXPECT_THAT(FieldRef("beta", "alpha").FindAll(s), + ElementsAre(FieldPath{1, 0}, FieldPath{2, 0})); + EXPECT_THAT(FieldRef("beta", "gamma", "alpha").FindAll(s), + ElementsAre(FieldPath{2, 1, 0}, FieldPath{2, 1, 1})); +} + using TestSchema = ::testing::Test; TEST_F(TestSchema, Basics) {