Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,22 @@ cdef class Schema:
def __len__(self):
return self.schema.num_fields()

def __getitem__(self, i):
if i < 0 or i >= len(self):
raise IndexError("{0} is out of bounds".format(i))
def __getitem__(self, int64_t i):

cdef:
Field result = Field()
int64_t num_fields = self.schema.num_fields()
int64_t index

if not -num_fields <= i < num_fields:
raise IndexError(
'Schema field index {:d} is out of range'.format(i)
)

index = i if i >= 0 else num_fields + i
assert index >= 0

cdef Field result = Field()
result.init(self.schema.field(i))
result.init(self.schema.field(index))
result.type = pyarrow_wrap_data_type(result.field.type())

return result
Expand Down
21 changes: 17 additions & 4 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -762,24 +762,37 @@ cdef class Table:
"""
return pyarrow_wrap_schema(self.table.schema())

def column(self, index):
def column(self, int64_t i):
"""
Select a column by its numeric index.

Parameters
----------
index: int
i : int

Returns
-------
pyarrow.Column
"""
self._check_nullptr()
cdef Column column = Column()

cdef:
Column column = Column()
int64_t num_columns = self.num_columns
int64_t index

if not -num_columns <= i < num_columns:
raise IndexError(
'Table column index {:d} is out of range'.format(i)
)

index = i if i >= 0 else num_columns + i
assert index >= 0

column.init(self.table.column(index))
return column

def __getitem__(self, i):
def __getitem__(self, int64_t i):
return self.column(i)

def itercolumns(self):
Expand Down
20 changes: 20 additions & 0 deletions python/pyarrow/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,23 @@ def test_schema_equals():
del fields[-1]
sch3 = pa.schema(fields)
assert not sch1.equals(sch3)


def test_schema_negative_indexing():
fields = [
pa.field('foo', pa.int32()),
pa.field('bar', pa.string()),
pa.field('baz', pa.list_(pa.int8()))
]

schema = pa.schema(fields)

assert schema[-1].equals(schema[2])
assert schema[-2].equals(schema[1])
assert schema[-3].equals(schema[0])

with pytest.raises(IndexError):
schema[-4]

with pytest.raises(IndexError):
schema[3]
21 changes: 21 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,24 @@ def test_table_pandas():
assert set(df.columns) == set(('a', 'b'))
assert df.shape == (5, 2)
assert df.loc[0, 'b'] == -10


def test_table_negative_indexing():
data = [
pa.array(range(5)),
pa.array([-10, -5, 0, 5, 10]),
pa.array([1.0, 2.0, 3.0]),
pa.array(['ab', 'bc', 'cd']),
]
table = pa.Table.from_arrays(data, names=tuple('abcd'))

assert table[-1].equals(table[3])
assert table[-2].equals(table[2])
assert table[-3].equals(table[1])
assert table[-4].equals(table[0])

with pytest.raises(IndexError):
table[-5]

with pytest.raises(IndexError):
table[4]