diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 230ef7c0651..cb4b5f82144 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -17,6 +17,27 @@ import warnings +from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE + + +cdef str _op_to_function_name(int op): + cdef str function_name + + if op == Py_EQ: + function_name = "equal" + elif op == Py_NE: + function_name = "not_equal" + elif op == Py_GT: + function_name = "greater" + elif op == Py_GE: + function_name = "greater_equal" + elif op == Py_LT: + function_name = "less" + elif op == Py_LE: + function_name = "less_equal" + + return function_name + cdef _sequence_to_array(object sequence, object mask, object size, DataType type, CMemoryPool* pool, c_bool from_pandas): @@ -602,14 +623,14 @@ cdef class Array(_PandasConvertible): self.ap = sp_array.get() self.type = pyarrow_wrap_data_type(self.sp_array.get().type()) - def __eq__(self, other): - raise NotImplementedError('Comparisons with pyarrow.Array are not ' - 'implemented') - def _debug_print(self): with nogil: check_status(DebugPrint(deref(self.ap), 0)) + def __richcmp__(self, other, int op): + function_name = _op_to_function_name(op) + return _pc().call_function(function_name, [self, other]) + def diff(self, Array other): """ Compare contents of this array against another one. diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 322023dde8a..612d05f4554 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -41,6 +41,10 @@ cdef class ChunkedArray(_PandasConvertible): def __reduce__(self): return chunked_array, (self.chunks, self.type) + def __richcmp__(self, other, int op): + function_name = _op_to_function_name(op) + return _pc().call_function(function_name, [self, other]) + @property def data(self): import warnings @@ -173,12 +177,6 @@ cdef class ChunkedArray(_PandasConvertible): else: index -= self.chunked_array.chunk(j).get().length() - def __eq__(self, other): - try: - return self.equals(other) - except TypeError: - return NotImplemented - def equals(self, ChunkedArray other): """ Return whether the contents of two chunked arrays are equal. diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 9cc80ed99c1..858a57e09b8 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -395,16 +395,6 @@ def test_array_ref_to_ndarray_base(): assert sys.getrefcount(arr) == (refcount + 1) -def test_array_eq_raises(): - # ARROW-2150: we are raising when comparing arrays until we define the - # behavior to either be elementwise comparisons or data equality - arr1 = pa.array([1, 2, 3], type=pa.int32()) - arr2 = pa.array([1, 2, 3], type=pa.int32()) - - with pytest.raises(NotImplementedError): - arr1 == arr2 - - def test_array_from_buffers(): values_buf = pa.py_buffer(np.int16([4, 5, 6, 7])) nulls_buf = pa.py_buffer(np.uint8([0b00001101])) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 09c4d024070..fd262e12e79 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -238,3 +238,78 @@ def test_filter_errors(): with pytest.raises(pa.ArrowInvalid, match="must all be the same length"): obj.filter(mask) + + +@pytest.mark.parametrize("typ", ["array", "chunked_array"]) +def test_compare_array(typ): + if typ == "array": + def con(values): return pa.array(values) + else: + def con(values): return pa.chunked_array([values]) + + arr1 = con([1, 2, 3, 4, None]) + arr2 = con([1, 1, 4, None, 4]) + + result = arr1 == arr2 + assert result.equals(con([True, False, False, None, None])) + + result = arr1 != arr2 + assert result.equals(con([False, True, True, None, None])) + + result = arr1 < arr2 + assert result.equals(con([False, False, True, None, None])) + + result = arr1 <= arr2 + assert result.equals(con([True, False, True, None, None])) + + result = arr1 > arr2 + assert result.equals(con([False, True, False, None, None])) + + result = arr1 >= arr2 + assert result.equals(con([True, True, False, None, None])) + + +@pytest.mark.parametrize("typ", ["array", "chunked_array"]) +def test_compare_scalar(typ): + if typ == "array": + def con(values): return pa.array(values) + else: + def con(values): return pa.chunked_array([values]) + + arr = con([1, 2, 3, None]) + # TODO this is a hacky way to construct a scalar .. + scalar = pa.array([2]).sum() + + result = arr == scalar + assert result.equals(con([False, True, False, None])) + + result = arr != scalar + assert result.equals(con([True, False, True, None])) + + result = arr < scalar + assert result.equals(con([True, False, False, None])) + + result = arr <= scalar + assert result.equals(con([True, True, False, None])) + + result = arr > scalar + assert result.equals(con([False, False, True, None])) + + result = arr >= scalar + assert result.equals(con([False, True, True, None])) + + +def test_compare_chunked_array_mixed(): + + arr = pa.array([1, 2, 3, 4, None]) + arr_chunked = pa.chunked_array([[1, 2, 3], [4, None]]) + arr_chunked2 = pa.chunked_array([[1, 2], [3, 4, None]]) + + expected = pa.chunked_array([[True, True, True, True, None]]) + + for result in [ + arr == arr_chunked, + arr_chunked == arr, + arr_chunked == arr_chunked2, + ]: + assert result.equals(expected) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 491cca01943..69ee4058fa2 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -128,8 +128,6 @@ def eq(xarrs, yarrs): y = pa.chunked_array(yarrs) assert x.equals(y) assert y.equals(x) - assert x == y - assert x != str(y) def ne(xarrs, yarrs): if isinstance(xarrs, pa.ChunkedArray): @@ -142,7 +140,6 @@ def ne(xarrs, yarrs): y = pa.chunked_array(yarrs) assert not x.equals(y) assert not y.equals(x) - assert x != y eq(pa.chunked_array([], type=pa.int32()), pa.chunked_array([], type=pa.int32()))