Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions cpp/src/arrow/type-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,40 @@ TEST(TestStructType, Basics) {
// TODO(wesm): out of bounds for field(...)
}

// Name-based child lookup must hand back the very same Field pointer that was
// passed to the constructor, and a null pointer for an unknown name.
TEST(TestStructType, GetChildByName) {
  auto f0 = field("f0", int32());
  auto f1 = field("f1", uint8(), false);
  auto f2 = field("f2", utf8());
  auto f3 = field("f3", list(int16()));

  StructType struct_type({f0, f1, f2, f3});

  // Hits: compare against the original shared_ptr instances.
  std::shared_ptr<Field> second_child = struct_type.GetChildByName("f1");
  ASSERT_EQ(f1, second_child);

  std::shared_ptr<Field> last_child = struct_type.GetChildByName("f3");
  ASSERT_EQ(f3, last_child);

  // Miss: no child carries this name.
  std::shared_ptr<Field> missing = struct_type.GetChildByName("not-found");
  ASSERT_EQ(missing, nullptr);
}

// Every child name must map to its position in the constructor's field list;
// an unknown name maps to -1.
TEST(TestStructType, GetChildIndex) {
  const std::vector<std::shared_ptr<Field>> fields = {
      field("f0", int32()), field("f1", uint8(), false), field("f2", utf8()),
      field("f3", list(int16()))};

  StructType struct_type(fields);

  for (size_t pos = 0; pos < fields.size(); ++pos) {
    ASSERT_EQ(static_cast<int>(pos), struct_type.GetChildIndex(fields[pos]->name()));
  }
  ASSERT_EQ(-1, struct_type.GetChildIndex("not-found"));
}

TEST(TypesTest, TestDecimal128Small) {
Decimal128Type t1(8, 4);

Expand Down
51 changes: 37 additions & 14 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,6 @@ std::string FixedSizeBinaryType::ToString() const {
return ss.str();
}

// Renders the type as "struct<name0: type0, name1: type1, ...>".
std::string StructType::ToString() const {
  std::stringstream out;
  out << "struct<";
  const int child_count = this->num_children();
  for (int pos = 0; pos < child_count; ++pos) {
    // Separator goes before every entry except the first.
    const char* separator = (pos == 0) ? "" : ", ";
    std::shared_ptr<Field> entry = this->child(pos);
    out << separator << entry->name() << ": " << entry->type()->ToString();
  }
  out << ">";
  return out.str();
}

// ----------------------------------------------------------------------
// Date types

Expand Down Expand Up @@ -206,6 +192,43 @@ std::string UnionType::ToString() const {
return s.str();
}

// ----------------------------------------------------------------------
// Struct type

// Produces a human-readable description of the struct type in the form
// "struct<name0: type0, name1: type1, ...>", recursing into each child type.
std::string StructType::ToString() const {
  std::stringstream buffer;
  buffer << "struct<";
  int pos = 0;
  while (pos < this->num_children()) {
    if (pos != 0) {
      buffer << ", ";
    }
    std::shared_ptr<Field> member = this->child(pos);
    buffer << member->name() << ": " << member->type()->ToString();
    ++pos;
  }
  buffer << ">";
  return buffer.str();
}

std::shared_ptr<Field> StructType::GetChildByName(const std::string& name) const {
int i = GetChildIndex(name);
return i == -1 ? nullptr : children_[i];
}

// Returns the position of the child named `name`, or -1 if there is none.
//
// The name -> position table is filled on the first call made while the type
// has children and the table is still empty. If several children share a
// name, the later child overwrites the earlier entry.
// NOTE(review): the fill mutates a mutable member from a const method and no
// synchronization is visible here -- confirm concurrent first calls cannot race.
int StructType::GetChildIndex(const std::string& name) const {
  if (name_to_index_.empty() && !children_.empty()) {
    int position = 0;
    for (const auto& child : children_) {
      name_to_index_[child->name()] = position;
      ++position;
    }
  }

  const auto found = name_to_index_.find(name);
  return found == name_to_index_.end() ? -1 : found->second;
}

// ----------------------------------------------------------------------
// DictionaryType

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,16 @@ class ARROW_EXPORT StructType : public NestedType {
Status Accept(TypeVisitor* visitor) const override;
std::string ToString() const override;
std::string name() const override { return "struct"; }

/// Look up a child field by name.
///
/// The position is resolved through GetChildIndex() and the matching child
/// is returned. Returns null if name not found.
std::shared_ptr<Field> GetChildByName(const std::string& name) const;

/// Look up the position of a child field by name.
///
/// Returns the index of the child within this type's children, or -1 if
/// name not found.
int GetChildIndex(const std::string& name) const;

private:
/// Lazily initialized mapping from child field name to child index.
/// It is populated by the first GetChildIndex() call made while the type has
/// children; it is mutable so that the const lookup methods can fill it.
mutable std::unordered_map<std::string, int> name_to_index_;
};

class ARROW_EXPORT DecimalType : public FixedSizeBinaryType {
Expand Down Expand Up @@ -777,6 +787,8 @@ class ARROW_EXPORT Schema {

private:
std::vector<std::shared_ptr<Field>> fields_;

/// Lazily initialized mapping
mutable std::unordered_map<std::string, int> name_to_index_;

std::shared_ptr<const KeyValueMetadata> metadata_;
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
# Cython declaration of the C++ class arrow::StructType.
cdef cppclass CStructType" arrow::StructType"(CDataType):
# Builds a struct type from its child fields.
CStructType(const vector[shared_ptr[CField]]& fields)

# Child field with the given name; the C++ side returns null if the name
# is not found.
shared_ptr[CField] GetChildByName(const c_string& name)
# Position of the child with the given name; the C++ side returns -1 if the
# name is not found.
int GetChildIndex(const c_string& name)

cdef cppclass CUnionType" arrow::UnionType"(CDataType):
CUnionType(const vector[shared_ptr[CField]]& fields,
const vector[uint8_t]& type_codes, UnionMode mode)
Expand Down
6 changes: 6 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,19 @@ cdef class ListValue(ArrayValue):
cdef int64_t length(self)


# Value class registered for _Type_STRUCT in scalar.pxi.
cdef class StructValue(ArrayValue):
cdef:
# Raw pointer to the struct array; assigned in _set_array (scalar.pxi)
# from the shared_ptr that the value also stores.
CStructArray* ap


cdef class UnionValue(ArrayValue):
cdef:
# NOTE(review): presumably the raw pointer to the union array this value
# belongs to, mirroring StructValue.ap -- confirm in the implementation.
CUnionArray* ap
# NOTE(review): contents are not visible from this declaration; presumably
# one entry per union child type -- confirm.
list value_types

# C-level accessor taking a 64-bit index; the implementation lives outside
# this declaration file.
cdef getitem(self, int64_t i)


# Declares no C-level attributes beyond those inherited from ArrayValue.
cdef class StringValue(ArrayValue):
pass

Expand Down
24 changes: 20 additions & 4 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,29 @@ cdef class FixedSizeBinaryValue(ArrayValue):

cdef class StructValue(ArrayValue):

cdef void _set_array(self, const shared_ptr[CArray]& sp_array):
    # Cache a typed raw pointer for direct CStructArray access, and store the
    # shared_ptr it was taken from on the value as well.
    self.ap = <CStructArray*> sp_array.get()
    self.sp_array = sp_array

def __getitem__(self, key):
    """
    Return this struct value's entry for the field named *key*.

    The field name is resolved to a child position through the struct
    type; the corresponding child array is wrapped and indexed at this
    value's own position.

    Raises
    ------
    KeyError
        If the struct type has no field named *key*.
    """
    cdef CStructType* struct_type = <CStructType*> self.type.type
    cdef int child_pos = struct_type.GetChildIndex(tobytes(key))

    # GetChildIndex reports an unknown name with a negative index.
    if child_pos < 0:
        raise KeyError(key)

    child_array = pyarrow_wrap_array(self.ap.field(child_pos))
    return child_array[self.index]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are passing an int64_t to a function that takes a (32-bit) int, hence the warning about truncation which is turned into an error.


def as_py(self):
cdef:
CStructArray* ap
vector[shared_ptr[CField]] child_fields = self.type.type.children()

ap = <CStructArray*> self.sp_array.get()
wrapped_arrays = [pyarrow_wrap_array(ap.field(i))
for i in range(ap.num_fields())]
wrapped_arrays = [pyarrow_wrap_array(self.ap.field(i))
for i in range(self.ap.num_fields())]
child_names = [child.get().name() for child in child_fields]
# Return the struct as a dict
return {
Expand Down Expand Up @@ -415,6 +430,7 @@ cdef dict _scalar_classes = {
_Type_STRUCT: StructValue,
}


cdef object box_scalar(DataType type, const shared_ptr[CArray]& sp_array,
int64_t index):
cdef ArrayValue val
Expand Down
18 changes: 18 additions & 0 deletions python/pyarrow/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,21 @@ def test_array_to_set(self):
set_from_array = set(arr)
assert isinstance(set_from_array, set)
assert set_from_array == {1, 2}

def test_struct_value_subscripting(self):
    # Struct values can be subscripted by field name.
    fields = [pa.field('x', pa.int16()),
              pa.field('y', pa.float32())]
    rows = [(1, 2.5), (3, 4.5), (5, 6.5)]
    arr = pa.array(rows, type=pa.struct(fields))

    for pos, (expected_x, expected_y) in enumerate(rows):
        assert arr[pos]['x'] == expected_x
        assert arr[pos]['y'] == expected_y

    # Row 4 is past the end of the three-element array.
    with pytest.raises(IndexError):
        arr[4]['non-existent']

    # Valid row, but no field carries this name.
    with pytest.raises(KeyError):
        arr[0]['non-existent']