From f4a9babe17a48b086c567f7388932981e7ee2d7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 24 Apr 2018 16:58:19 +0200 Subject: [PATCH 1/4] GetChildByName and GetChildIndex for StructType --- cpp/src/arrow/type.cc | 51 +++++++++++++++++++++++++++++++------------ cpp/src/arrow/type.h | 9 ++++++++ 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 792d1bfd035..e84b074ac68 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -109,20 +109,6 @@ std::string FixedSizeBinaryType::ToString() const { return ss.str(); } -std::string StructType::ToString() const { - std::stringstream s; - s << "struct<"; - for (int i = 0; i < this->num_children(); ++i) { - if (i > 0) { - s << ", "; - } - std::shared_ptr field = this->child(i); - s << field->name() << ": " << field->type()->ToString(); - } - s << ">"; - return s.str(); -} - // ---------------------------------------------------------------------- // Date types @@ -206,6 +192,43 @@ std::string UnionType::ToString() const { return s.str(); } +// ---------------------------------------------------------------------- +// Struct type + +std::string StructType::ToString() const { + std::stringstream s; + s << "struct<"; + for (int i = 0; i < this->num_children(); ++i) { + if (i > 0) { + s << ", "; + } + std::shared_ptr field = this->child(i); + s << field->name() << ": " << field->type()->ToString(); + } + s << ">"; + return s.str(); +} + +std::shared_ptr StructType::GetChildByName(const std::string& name) const { + int64_t i = GetChildIndex(name); + return i == -1 ? nullptr : children_[i]; +} + +int64_t StructType::GetChildIndex(const std::string& name) const { + if (children_.size() > 0 && name_to_index_.size() == 0) { + for (size_t i = 0; i < children_.size(); ++i) { + name_to_index_[children_[i]->name()] = static_cast(i); + } + } + + auto it = name_to_index_.find(name); + if (it == name_to_index_.end()) { + return -1; + } else { + return it->second; + } +} + // ---------------------------------------------------------------------- // DictionaryType diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index e50760b55da..0ed1f9914ec 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -475,6 +475,15 @@ class ARROW_EXPORT StructType : public NestedType { Status Accept(TypeVisitor* visitor) const override; std::string ToString() const override; std::string name() const override { return "struct"; } + + /// Returns null if name not found + std::shared_ptr GetChildByName(const std::string& name) const; + + /// Returns -1 if name not found + int64_t GetChildIndex(const std::string& name) const; + + private: + mutable std::unordered_map name_to_index_; }; class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { From 708e78f816f6782d043b77b79e8f419245fdbb36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 25 Apr 2018 08:28:43 +0200 Subject: [PATCH 2/4] cpp unittests --- cpp/src/arrow/type-test.cc | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc index 48982cad424..0da0747427c 100644 --- a/cpp/src/arrow/type-test.cc +++ b/cpp/src/arrow/type-test.cc @@ -399,6 +399,40 @@ TEST(TestStructType, Basics) { // TODO(wesm): out of bounds for field(...) } +TEST(TestStructType, GetChildByName) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + StructType struct_type({f0, f1, f2, f3}); + std::shared_ptr result; + + result = struct_type.GetChildByName("f1"); + ASSERT_TRUE(f1->Equals(result)); + + result = struct_type.GetChildByName("f3"); + ASSERT_TRUE(f3->Equals(result)); + + result = struct_type.GetChildByName("not-found"); + ASSERT_TRUE(result == nullptr); +} + +TEST(TestStructType, GetChildIndex) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f3", list(int16())); + + StructType struct_type({f0, f1, f2, f3}); + + ASSERT_EQ(0, struct_type.GetChildIndex(f0->name())); + ASSERT_EQ(1, struct_type.GetChildIndex(f1->name())); + ASSERT_EQ(2, struct_type.GetChildIndex(f2->name())); + ASSERT_EQ(3, struct_type.GetChildIndex(f3->name())); + ASSERT_EQ(-1, struct_type.GetChildIndex("not-found")); +} + TEST(TypesTest, TestDecimal128Small) { Decimal128Type t1(8, 4); From f84885836706a7883b50cad0be78184b62bf5cc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 25 Apr 2018 09:10:27 +0200 Subject: [PATCH 3/4] implement StructValue.__getitem__ --- python/pyarrow/includes/libarrow.pxd | 3 +++ python/pyarrow/lib.pxd | 6 ++++++ python/pyarrow/scalar.pxi | 24 ++++++++++++++++++++---- python/pyarrow/tests/test_scalars.py | 18 ++++++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 70eb9cbb205..d97e86ac280 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -253,6 +253,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CStructType" arrow::StructType"(CDataType): CStructType(const vector[shared_ptr[CField]]& fields) + shared_ptr[CField] GetChildByName(const c_string& name) + int64_t GetChildIndex(const c_string& name) + cdef cppclass CUnionType" arrow::UnionType"(CDataType): CUnionType(const vector[shared_ptr[CField]]& fields, const vector[uint8_t]& type_codes, UnionMode mode) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 998eeafc630..c6864d2ca27 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -143,6 +143,11 @@ cdef class ListValue(ArrayValue): cdef int64_t length(self) +cdef class StructValue(ArrayValue): + cdef: + CStructArray* ap + + cdef class UnionValue(ArrayValue): cdef: CUnionArray* ap @@ -150,6 +155,7 @@ cdef class UnionValue(ArrayValue): cdef getitem(self, int64_t i) + cdef class StringValue(ArrayValue): pass diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index e1527713787..2a6d0c66c0d 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -371,14 +371,29 @@ cdef class FixedSizeBinaryValue(ArrayValue): cdef class StructValue(ArrayValue): + cdef void _set_array(self, const shared_ptr[CArray]& sp_array): + self.sp_array = sp_array + self.ap = sp_array.get() + + def __getitem__(self, key): + cdef: + CStructType* type + int64_t index + + type = self.type.type + index = type.GetChildIndex(tobytes(key)) + + if index < 0: + raise KeyError(key) + + return pyarrow_wrap_array(self.ap.field(index))[self.index] + def as_py(self): cdef: - CStructArray* ap vector[shared_ptr[CField]] child_fields = self.type.type.children() - ap = self.sp_array.get() - wrapped_arrays = [pyarrow_wrap_array(ap.field(i)) - for i in range(ap.num_fields())] + wrapped_arrays = [pyarrow_wrap_array(self.ap.field(i)) + for i in range(self.ap.num_fields())] child_names = [child.get().name() for child in child_fields] # Return the struct as a dict return { @@ -415,6 +430,7 @@ cdef dict _scalar_classes = { _Type_STRUCT: StructValue, } + cdef object box_scalar(DataType type, const shared_ptr[CArray]& sp_array, int64_t index): cdef ArrayValue val diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 41057a0a40b..33d4a0cbc42 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -215,3 +215,21 @@ def test_array_to_set(self): set_from_array = set(arr) assert isinstance(set_from_array, set) assert set_from_array == {1, 2} + + def test_struct_array_subscripting(self): + ty = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + + assert arr[0]['x'] == 1 + assert arr[0]['y'] == 2.5 + assert arr[1]['x'] == 3 + assert arr[1]['y'] == 4.5 + assert arr[2]['x'] == 5 + assert arr[2]['y'] == 6.5 + + with pytest.raises(IndexError): + arr[4]['non-existent'] + + with pytest.raises(KeyError): + arr[0]['non-existent'] From bbd496c09688e1d087b19d0c70630ae42c84a98d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 26 Apr 2018 11:30:01 +0200 Subject: [PATCH 4/4] fix review issues --- cpp/src/arrow/type-test.cc | 6 +++--- cpp/src/arrow/type.cc | 4 ++-- cpp/src/arrow/type.h | 5 ++++- python/pyarrow/includes/libarrow.pxd | 2 +- python/pyarrow/scalar.pxi | 2 +- python/pyarrow/tests/test_scalars.py | 2 +- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc index 0da0747427c..f62d14d049b 100644 --- a/cpp/src/arrow/type-test.cc +++ b/cpp/src/arrow/type-test.cc @@ -409,13 +409,13 @@ TEST(TestStructType, GetChildByName) { std::shared_ptr result; result = struct_type.GetChildByName("f1"); - ASSERT_TRUE(f1->Equals(result)); + ASSERT_EQ(f1, result); result = struct_type.GetChildByName("f3"); - ASSERT_TRUE(f3->Equals(result)); + ASSERT_EQ(f3, result); result = struct_type.GetChildByName("not-found"); - ASSERT_TRUE(result == nullptr); + ASSERT_EQ(result, nullptr); } TEST(TestStructType, GetChildIndex) { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index e84b074ac68..2eb9967976d 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -210,11 +210,11 @@ std::string StructType::ToString() const { } std::shared_ptr StructType::GetChildByName(const std::string& name) const { - int64_t i = GetChildIndex(name); + int i = GetChildIndex(name); return i == -1 ? nullptr : children_[i]; } -int64_t StructType::GetChildIndex(const std::string& name) const { +int StructType::GetChildIndex(const std::string& name) const { if (children_.size() > 0 && name_to_index_.size() == 0) { for (size_t i = 0; i < children_.size(); ++i) { name_to_index_[children_[i]->name()] = static_cast(i); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 0ed1f9914ec..9cd1d8f86db 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -480,9 +480,10 @@ class ARROW_EXPORT StructType : public NestedType { std::shared_ptr GetChildByName(const std::string& name) const; /// Returns -1 if name not found - int64_t GetChildIndex(const std::string& name) const; + int GetChildIndex(const std::string& name) const; private: + /// Lazily initialized mapping mutable std::unordered_map name_to_index_; }; @@ -786,6 +787,8 @@ class ARROW_EXPORT Schema { private: std::vector> fields_; + + /// Lazily initialized mapping mutable std::unordered_map name_to_index_; std::shared_ptr metadata_; diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index d97e86ac280..0df7f34f33d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -254,7 +254,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CStructType(const vector[shared_ptr[CField]]& fields) shared_ptr[CField] GetChildByName(const c_string& name) - int64_t GetChildIndex(const c_string& name) + int GetChildIndex(const c_string& name) cdef cppclass CUnionType" arrow::UnionType"(CDataType): CUnionType(const vector[shared_ptr[CField]]& fields, diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 2a6d0c66c0d..04ecc9cf46e 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -378,7 +378,7 @@ cdef class StructValue(ArrayValue): def __getitem__(self, key): cdef: CStructType* type - int64_t index + int index type = self.type.type index = type.GetChildIndex(tobytes(key)) diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py index 33d4a0cbc42..9c86270c454 100644 --- a/python/pyarrow/tests/test_scalars.py +++ b/python/pyarrow/tests/test_scalars.py @@ -216,7 +216,7 @@ def test_array_to_set(self): assert isinstance(set_from_array, set) assert set_from_array == {1, 2} - def test_struct_array_subscripting(self): + def test_struct_value_subscripting(self): ty = pa.struct([pa.field('x', pa.int16()), pa.field('y', pa.float32())]) arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)