Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def array(object sequence, DataType type=None, MemoryPool memory_pool=None,
return pyarrow_wrap_array(sp_array)


def _normalize_slice(object arrow_obj, slice key):
    """
    Apply a Python slice object to an Arrow container.

    Parameters
    ----------
    arrow_obj : object
        Any object that supports ``len()`` and exposes
        ``slice(offset, length)``.
    key : slice
        The slice to apply. Only a step of 1 (or an unset step) is
        supported.

    Returns
    -------
    The result of ``arrow_obj.slice(offset, length)``.

    Raises
    ------
    IndexError
        If the slice has a step other than 1.
    """
    cdef Py_ssize_t n = len(arrow_obj)

    # An unset (or zero) step is treated as 1, as before.
    step = key.step or 1
    if step != 1:
        raise IndexError('only slices with step 1 supported')

    # slice.indices() resolves negative bounds and clamps both ends into
    # [0, n]. Adding n in a loop never terminates when n == 0 and wraps
    # out-of-range negative bounds around instead of clamping them.
    start, stop, _ = slice(key.start, key.stop, 1).indices(n)

    # An empty range (stop before start) must yield length 0, not a
    # negative length.
    return arrow_obj.slice(start, max(stop - start, 0))


cdef class Array:

Expand Down Expand Up @@ -230,23 +247,10 @@ cdef class Array:
raise NotImplemented

def __getitem__(self, key):
cdef:
Py_ssize_t n = len(self)
cdef Py_ssize_t n = len(self)

if PySlice_Check(key):
start = key.start or 0
while start < 0:
start += n

stop = key.stop if key.stop is not None else n
while stop < 0:
stop += n

step = key.step or 1
if step != 1:
raise IndexError('only slices with step 1 supported')
else:
return self.slice(start, stop - start)
return _normalize_slice(self, key)

while key < 0:
key += len(self)
Expand Down
9 changes: 7 additions & 2 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,13 @@ cdef class RecordBatch:
)
return pyarrow_wrap_array(self.batch.column(i))

def __getitem__(self, i):
return self.column(i)
def __getitem__(self, key):
    """
    Return a row slice for a slice key, or a column otherwise.

    A ``slice`` key is handed to ``_normalize_slice``; any other key is
    passed unchanged to ``self.column``.
    """
    if isinstance(key, slice):
        return _normalize_slice(self, key)
    return self.column(key)

def slice(self, offset=0, length=None):
"""
Expand Down
11 changes: 9 additions & 2 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_recordbatch_basics():
batch[2]


def test_recordbatch_slice():
def test_recordbatch_slice_getitem():
data = [
pa.array(range(5)),
pa.array([-10, -5, 0, 5, 10])
Expand All @@ -90,7 +90,6 @@ def test_recordbatch_slice():
batch = pa.RecordBatch.from_arrays(data, names)

sliced = batch.slice(2)

assert sliced.num_rows == 3

expected = pa.RecordBatch.from_arrays(
Expand All @@ -111,6 +110,14 @@ def test_recordbatch_slice():
with pytest.raises(IndexError):
batch.slice(-1)

# Check __getitem__-based slicing
assert batch.slice(0, 0).equals(batch[:0])
assert batch.slice(0, 2).equals(batch[:2])
assert batch.slice(2, 2).equals(batch[2:4])
assert batch.slice(2, len(batch) - 2).equals(batch[2:])
assert batch.slice(len(batch) - 2, 2).equals(batch[-2:])
assert batch.slice(len(batch) - 4, 2).equals(batch[-4:-2])


def test_recordbatch_from_to_pandas():
data = pd.DataFrame({
Expand Down