diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index 46e94b4f4b3..c132269c563 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -244,12 +244,40 @@ cdef class Schema:
     def __len__(self):
         return self.schema.num_fields()
 
-    def __getitem__(self, i):
-        if i < 0 or i >= len(self):
-            raise IndexError("{0} is out of bounds".format(i))
+    def __getitem__(self, int64_t i):
+        """
+        Select a field by its numeric index.
+
+        Parameters
+        ----------
+        i : int
+            Position of the field; a negative value counts back from
+            the last field.
+
+        Returns
+        -------
+        Field
+
+        Raises
+        ------
+        IndexError
+            If ``i`` is outside ``[-len(self), len(self))``.
+        """
+        cdef:
+            Field result = Field()
+            int64_t num_fields = self.schema.num_fields()
+            int64_t index
+
+        if not -num_fields <= i < num_fields:
+            raise IndexError(
+                'Schema field index {:d} is out of range'.format(i)
+            )
+
+        # Map a negative index onto its non-negative equivalent.
+        index = i if i >= 0 else num_fields + i
+        assert index >= 0
 
-        cdef Field result = Field()
-        result.init(self.schema.field(i))
+        result.init(self.schema.field(index))
         result.type = pyarrow_wrap_data_type(result.field.type())
         return result
 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 8dd18cf4136..bd8cce41400 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -762,24 +762,48 @@ cdef class Table:
         """
         return pyarrow_wrap_schema(self.table.schema())
 
-    def column(self, index):
+    def column(self, int64_t i):
         """
         Select a column by its numeric index.
 
         Parameters
         ----------
-        index: int
+        i : int
+            Position of the column; a negative value counts back from
+            the last column.
 
         Returns
         -------
         pyarrow.Column
+
+        Raises
+        ------
+        IndexError
+            If ``i`` is outside ``[-num_columns, num_columns)``.
         """
         self._check_nullptr()
-        cdef Column column = Column()
+
+        cdef:
+            Column column = Column()
+            int64_t num_columns = self.num_columns
+            int64_t index
+
+        if not -num_columns <= i < num_columns:
+            raise IndexError(
+                'Table column index {:d} is out of range'.format(i)
+            )
+
+        # Map a negative index onto its non-negative equivalent.
+        index = i if i >= 0 else num_columns + i
+        assert index >= 0
+
         column.init(self.table.column(index))
         return column
 
-    def __getitem__(self, i):
+    def __getitem__(self, int64_t i):
+        """
+        Select a column by its numeric index; see ``Table.column``.
+        """
         return self.column(i)
 
     def itercolumns(self):
diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py
index 2d98865b56e..a6fe1a5df0e 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -212,3 +212,25 @@ def test_schema_equals():
     del fields[-1]
     sch3 = pa.schema(fields)
     assert not sch1.equals(sch3)
+
+
+def test_schema_negative_indexing():
+    fields = [
+        pa.field('foo', pa.int32()),
+        pa.field('bar', pa.string()),
+        pa.field('baz', pa.list_(pa.int8()))
+    ]
+
+    schema = pa.schema(fields)
+
+    # Negative indices count back from the last field.
+    assert schema[-1].equals(schema[2])
+    assert schema[-2].equals(schema[1])
+    assert schema[-3].equals(schema[0])
+
+    # One step past either end must be rejected.
+    with pytest.raises(IndexError):
+        schema[-4]
+
+    with pytest.raises(IndexError):
+        schema[3]
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 0567e8aba68..72ce6967edf 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -254,3 +254,27 @@ def test_table_pandas():
     assert set(df.columns) == set(('a', 'b'))
     assert df.shape == (5, 2)
     assert df.loc[0, 'b'] == -10
+
+
+def test_table_negative_indexing():
+    # Every column holds five values so the table is well-formed.
+    data = [
+        pa.array(range(5)),
+        pa.array([-10, -5, 0, 5, 10]),
+        pa.array([1.0, 2.0, 3.0, 4.0, 5.0]),
+        pa.array(['ab', 'bc', 'cd', 'de', 'ef']),
+    ]
+    table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+    # Negative indices count back from the last column.
+    assert table[-1].equals(table[3])
+    assert table[-2].equals(table[2])
+    assert table[-3].equals(table[1])
+    assert table[-4].equals(table[0])
+
+    # One step past either end must be rejected.
+    with pytest.raises(IndexError):
+        table[-5]
+
+    with pytest.raises(IndexError):
+        table[4]