diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index efbe36f80b3..67418aa5eac 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -89,6 +89,23 @@ def array(object sequence, DataType type=None, MemoryPool memory_pool=None, return pyarrow_wrap_array(sp_array) +def _normalize_slice(object arrow_obj, slice key): + cdef Py_ssize_t n = len(arrow_obj) + + start = key.start or 0 + while start < 0: + start += n + + stop = key.stop if key.stop is not None else n + while stop < 0: + stop += n + + step = key.step or 1 + if step != 1: + raise IndexError('only slices with step 1 supported') + else: + return arrow_obj.slice(start, stop - start) + cdef class Array: @@ -230,23 +247,10 @@ cdef class Array: raise NotImplemented def __getitem__(self, key): - cdef: - Py_ssize_t n = len(self) + cdef Py_ssize_t n = len(self) if PySlice_Check(key): - start = key.start or 0 - while start < 0: - start += n - - stop = key.stop if key.stop is not None else n - while stop < 0: - stop += n - - step = key.step or 1 - if step != 1: - raise IndexError('only slices with step 1 supported') - else: - return self.slice(start, stop - start) + return _normalize_slice(self, key) while key < 0: key += len(self) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 6188e90616b..a9cb06480cd 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -475,8 +475,13 @@ cdef class RecordBatch: ) return pyarrow_wrap_array(self.batch.column(i)) - def __getitem__(self, i): - return self.column(i) + def __getitem__(self, key): + cdef: + Py_ssize_t start, stop + if isinstance(key, slice): + return _normalize_slice(self, key) + else: + return self.column(key) def slice(self, offset=0, length=None): """ diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index c2aeda9b2df..28b98f0952a 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -80,7 +80,7 @@ def test_recordbatch_basics(): batch[2] -def test_recordbatch_slice(): +def test_recordbatch_slice_getitem(): data = [ pa.array(range(5)), pa.array([-10, -5, 0, 5, 10]) @@ -90,7 +90,6 @@ def test_recordbatch_slice(): batch = pa.RecordBatch.from_arrays(data, names) sliced = batch.slice(2) - assert sliced.num_rows == 3 expected = pa.RecordBatch.from_arrays( @@ -111,6 +110,14 @@ def test_recordbatch_slice(): with pytest.raises(IndexError): batch.slice(-1) + # Check __getitem__-based slicing + assert batch.slice(0, 0).equals(batch[:0]) + assert batch.slice(0, 2).equals(batch[:2]) + assert batch.slice(2, 2).equals(batch[2:4]) + assert batch.slice(2, len(batch) - 2).equals(batch[2:]) + assert batch.slice(len(batch) - 2, 2).equals(batch[-2:]) + assert batch.slice(len(batch) - 4, 2).equals(batch[-4:-2]) + def test_recordbatch_from_to_pandas(): data = pd.DataFrame({