diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc index 48982cad424..f62d14d049b 100644 --- a/cpp/src/arrow/type-test.cc +++ b/cpp/src/arrow/type-test.cc @@ -399,6 +399,40 @@ TEST(TestStructType, Basics) { // TODO(wesm): out of bounds for field(...) } +TEST(TestStructType, GetChildByName) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + StructType struct_type({f0, f1, f2, f3}); + std::shared_ptr result; + + result = struct_type.GetChildByName("f1"); + ASSERT_EQ(f1, result); + + result = struct_type.GetChildByName("f3"); + ASSERT_EQ(f3, result); + + result = struct_type.GetChildByName("not-found"); + ASSERT_EQ(result, nullptr); +} + +TEST(TestStructType, GetChildIndex) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + StructType struct_type({f0, f1, f2, f3}); + + ASSERT_EQ(0, struct_type.GetChildIndex(f0->name())); + ASSERT_EQ(1, struct_type.GetChildIndex(f1->name())); + ASSERT_EQ(2, struct_type.GetChildIndex(f2->name())); + ASSERT_EQ(3, struct_type.GetChildIndex(f3->name())); + ASSERT_EQ(-1, struct_type.GetChildIndex("not-found")); +} + TEST(TypesTest, TestDecimal128Small) { Decimal128Type t1(8, 4); diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 792d1bfd035..2eb9967976d 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -109,20 +109,6 @@ std::string FixedSizeBinaryType::ToString() const { return ss.str(); } -std::string StructType::ToString() const { - std::stringstream s; - s << "struct<"; - for (int i = 0; i < this->num_children(); ++i) { - if (i > 0) { - s << ", "; - } - std::shared_ptr field = this->child(i); - s << field->name() << ": " << field->type()->ToString(); - } - s << ">"; - return s.str(); -} - // ---------------------------------------------------------------------- // Date types @@ -206,6 +192,43 @@ std::string UnionType::ToString() const { return s.str(); } +// ---------------------------------------------------------------------- +// Struct type + +std::string StructType::ToString() const { + std::stringstream s; + s << "struct<"; + for (int i = 0; i < this->num_children(); ++i) { + if (i > 0) { + s << ", "; + } + std::shared_ptr field = this->child(i); + s << field->name() << ": " << field->type()->ToString(); + } + s << ">"; + return s.str(); +} + +std::shared_ptr StructType::GetChildByName(const std::string& name) const { + int i = GetChildIndex(name); + return i == -1 ? nullptr : children_[i]; +} + +int StructType::GetChildIndex(const std::string& name) const { + if (children_.size() > 0 && name_to_index_.size() == 0) { + for (size_t i = 0; i < children_.size(); ++i) { + name_to_index_[children_[i]->name()] = static_cast(i); + } + } + + auto it = name_to_index_.find(name); + if (it == name_to_index_.end()) { + return -1; + } else { + return it->second; + } +} + // ---------------------------------------------------------------------- // DictionaryType diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index e50760b55da..9cd1d8f86db 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -475,6 +475,16 @@ class ARROW_EXPORT StructType : public NestedType { Status Accept(TypeVisitor* visitor) const override; std::string ToString() const override; std::string name() const override { return "struct"; } + + /// Returns null if name not found + std::shared_ptr GetChildByName(const std::string& name) const; + + /// Returns -1 if name not found + int GetChildIndex(const std::string& name) const; + + private: + /// Lazily initialized mapping + mutable std::unordered_map name_to_index_; }; class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { @@ -777,6 +787,8 @@ class ARROW_EXPORT Schema { private: std::vector> fields_; + + /// Lazily initialized mapping mutable std::unordered_map name_to_index_; std::shared_ptr metadata_; diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 70eb9cbb205..0df7f34f33d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -253,6 +253,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CStructType" arrow::StructType"(CDataType): CStructType(const vector[shared_ptr[CField]]& fields) + shared_ptr[CField] GetChildByName(const c_string& name) + int GetChildIndex(const c_string& name) + cdef cppclass CUnionType" arrow::UnionType"(CDataType): CUnionType(const vector[shared_ptr[CField]]& fields, const vector[uint8_t]& type_codes, UnionMode mode) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 998eeafc630..c6864d2ca27 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -143,6 +143,11 @@ cdef class ListValue(ArrayValue): cdef int64_t length(self) +cdef class StructValue(ArrayValue): + cdef: + CStructArray* ap + + cdef class UnionValue(ArrayValue): cdef: CUnionArray* ap @@ -150,6 +155,7 @@ cdef class UnionValue(ArrayValue): cdef getitem(self, int64_t i) + cdef class StringValue(ArrayValue): pass diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index e1527713787..04ecc9cf46e 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -371,14 +371,29 @@ cdef class FixedSizeBinaryValue(ArrayValue): cdef class StructValue(ArrayValue): + cdef void _set_array(self, const shared_ptr[CArray]& sp_array): + self.sp_array = sp_array + self.ap = sp_array.get() + + def __getitem__(self, key): + cdef: + CStructType* type + int index + + type = self.type.type + index = type.GetChildIndex(tobytes(key)) + + if index < 0: + raise KeyError(key) + + return pyarrow_wrap_array(self.ap.field(index))[self.index] + def as_py(self): cdef: - CStructArray* ap vector[shared_ptr[CField]] child_fields = self.type.type.children() - ap = self.sp_array.get() - wrapped_arrays = [pyarrow_wrap_array(ap.field(i)) - for i in range(ap.num_fields())] + wrapped_arrays = [pyarrow_wrap_array(self.ap.field(i)) + for i in range(self.ap.num_fields())] child_names = [child.get().name() for child in child_fields] # Return the struct as a dict return { @@ -415,6 +430,7 @@ cdef dict _scalar_classes = { _Type_STRUCT: StructValue, } + cdef object box_scalar(DataType type, const shared_ptr[CArray]& sp_array, int64_t index): cdef ArrayValue val diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 41057a0a40b..9c86270c454 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -215,3 +215,21 @@ def test_array_to_set(self): set_from_array = set(arr) assert isinstance(set_from_array, set) assert set_from_array == {1, 2} + + def test_struct_value_subscripting(self): + ty = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + + assert arr[0]['x'] == 1 + assert arr[0]['y'] == 2.5 + assert arr[1]['x'] == 3 + assert arr[1]['y'] == 4.5 + assert arr[2]['x'] == 5 + assert arr[2]['y'] == 6.5 + + with pytest.raises(IndexError): + arr[4]['non-existent'] + + with pytest.raises(KeyError): + arr[0]['non-existent']