Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,27 @@

import warnings

from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE


cdef str _op_to_function_name(int op):
cdef str function_name

if op == Py_EQ:
function_name = "equal"
elif op == Py_NE:
function_name = "not_equal"
elif op == Py_GT:
function_name = "greater"
elif op == Py_GE:
function_name = "greater_equal"
elif op == Py_LT:
function_name = "less"
elif op == Py_LE:
function_name = "less_equal"

return function_name


cdef _sequence_to_array(object sequence, object mask, object size,
DataType type, CMemoryPool* pool, c_bool from_pandas):
Expand Down Expand Up @@ -602,14 +623,14 @@ cdef class Array(_PandasConvertible):
self.ap = sp_array.get()
self.type = pyarrow_wrap_data_type(self.sp_array.get().type())

def __eq__(self, other):
raise NotImplementedError('Comparisons with pyarrow.Array are not '
'implemented')

def _debug_print(self):
with nogil:
check_status(DebugPrint(deref(self.ap), 0))

def __richcmp__(self, other, int op):
function_name = _op_to_function_name(op)
return _pc().call_function(function_name, [self, other])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe factor this out into a helper function to avoid the code duplication?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the op to function name conversion into a helper function


def diff(self, Array other):
"""
Compare contents of this array against another one.
Expand Down
10 changes: 4 additions & 6 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ cdef class ChunkedArray(_PandasConvertible):
def __reduce__(self):
return chunked_array, (self.chunks, self.type)

def __richcmp__(self, other, int op):
function_name = _op_to_function_name(op)
return _pc().call_function(function_name, [self, other])

@property
def data(self):
import warnings
Expand Down Expand Up @@ -173,12 +177,6 @@ cdef class ChunkedArray(_PandasConvertible):
else:
index -= self.chunked_array.chunk(j).get().length()

def __eq__(self, other):
try:
return self.equals(other)
except TypeError:
return NotImplemented

def equals(self, ChunkedArray other):
"""
Return whether the contents of two chunked arrays are equal.
Expand Down
10 changes: 0 additions & 10 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,16 +395,6 @@ def test_array_ref_to_ndarray_base():
assert sys.getrefcount(arr) == (refcount + 1)


def test_array_eq_raises():
# ARROW-2150: we are raising when comparing arrays until we define the
# behavior to either be elementwise comparisons or data equality
arr1 = pa.array([1, 2, 3], type=pa.int32())
arr2 = pa.array([1, 2, 3], type=pa.int32())

with pytest.raises(NotImplementedError):
arr1 == arr2


def test_array_from_buffers():
values_buf = pa.py_buffer(np.int16([4, 5, 6, 7]))
nulls_buf = pa.py_buffer(np.uint8([0b00001101]))
Expand Down
75 changes: 75 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,78 @@ def test_filter_errors():
with pytest.raises(pa.ArrowInvalid,
match="must all be the same length"):
obj.filter(mask)


@pytest.mark.parametrize("typ", ["array", "chunked_array"])
def test_compare_array(typ):
if typ == "array":
def con(values): return pa.array(values)
else:
def con(values): return pa.chunked_array([values])

arr1 = con([1, 2, 3, 4, None])
arr2 = con([1, 1, 4, None, 4])

result = arr1 == arr2
assert result.equals(con([True, False, False, None, None]))

result = arr1 != arr2
assert result.equals(con([False, True, True, None, None]))

result = arr1 < arr2
assert result.equals(con([False, False, True, None, None]))

result = arr1 <= arr2
assert result.equals(con([True, False, True, None, None]))

result = arr1 > arr2
assert result.equals(con([False, True, False, None, None]))

result = arr1 >= arr2
assert result.equals(con([True, True, False, None, None]))


@pytest.mark.parametrize("typ", ["array", "chunked_array"])
def test_compare_scalar(typ):
if typ == "array":
def con(values): return pa.array(values)
else:
def con(values): return pa.chunked_array([values])

arr = con([1, 2, 3, None])
# TODO this is a hacky way to construct a scalar ..
scalar = pa.array([2]).sum()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a JIRA about adding a pyarrow.scalar function (in order to do pa.scalar(2) here)?


result = arr == scalar
assert result.equals(con([False, True, False, None]))

result = arr != scalar
assert result.equals(con([True, False, True, None]))

result = arr < scalar
assert result.equals(con([True, False, False, None]))

result = arr <= scalar
assert result.equals(con([True, True, False, None]))

result = arr > scalar
assert result.equals(con([False, False, True, None]))

result = arr >= scalar
assert result.equals(con([False, True, True, None]))


def test_compare_chunked_array_mixed():

arr = pa.array([1, 2, 3, 4, None])
arr_chunked = pa.chunked_array([[1, 2, 3], [4, None]])
arr_chunked2 = pa.chunked_array([[1, 2], [3, 4, None]])

expected = pa.chunked_array([[True, True, True, True, None]])

for result in [
arr == arr_chunked,
arr_chunked == arr,
arr_chunked == arr_chunked2,
]:
assert result.equals(expected)
3 changes: 0 additions & 3 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ def eq(xarrs, yarrs):
y = pa.chunked_array(yarrs)
assert x.equals(y)
assert y.equals(x)
assert x == y
assert x != str(y)

def ne(xarrs, yarrs):
if isinstance(xarrs, pa.ChunkedArray):
Expand All @@ -142,7 +140,6 @@ def ne(xarrs, yarrs):
y = pa.chunked_array(yarrs)
assert not x.equals(y)
assert not y.equals(x)
assert x != y

eq(pa.chunked_array([], type=pa.int32()),
pa.chunked_array([], type=pa.int32()))
Expand Down