@@ -100,6 +100,11 @@ def data_for_grouping(request):
100100 return SparseArray ([1 , 1 , np .nan , np .nan , 2 , 2 , 1 , 3 ], fill_value = request .param )
101101
102102
103+ @pytest .fixture (params = [0 , np .nan ])
104+ def data_for_compare (request ):
105+ return SparseArray ([0 , 0 , np .nan , - 2 , - 1 , 4 , 2 , 3 , 0 , 0 ], fill_value = request .param )
106+
107+
103108class BaseSparseTests :
104109 def _check_unsupported (self , data ):
105110 if data .dtype == SparseDtype (int , 0 ):
@@ -432,32 +437,48 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError):
432437 super ()._check_divmod_op (ser , op , other , exc = None )
433438
434439
435- class TestComparisonOps (BaseSparseTests , base . BaseComparisonOpsTests ):
436- def _compare_other (self , s , data , comparison_op , other ):
440+ class TestComparisonOps (BaseSparseTests ):
441+ def _compare_other (self , data_for_compare : SparseArray , comparison_op , other ):
437442 op = comparison_op
438443
439- # array
440- result = pd .Series (op (data , other ))
441- # hard to test the fill value, since we don't know what expected
442- # is in general.
443- # Rely on tests in `tests/sparse` to validate that.
444- assert isinstance (result .dtype , SparseDtype )
445- assert result .dtype .subtype == np .dtype ("bool" )
446-
447- with np .errstate (all = "ignore" ):
448- expected = pd .Series (
449- SparseArray (
450- op (np .asarray (data ), np .asarray (other )),
451- fill_value = result .values .fill_value ,
452- )
444+ result = op (data_for_compare , other )
445+ assert isinstance (result , SparseArray )
446+ assert result .dtype .subtype == np .bool_
447+
448+ if isinstance (other , SparseArray ):
449+ expected = SparseArray (
450+ op (data_for_compare .to_dense (), np .asarray (other )),
451+ fill_value = op (data_for_compare .fill_value , other .fill_value ),
452+ dtype = np .bool_ ,
453+ )
454+ else :
455+ expected = SparseArray (
456+ op (data_for_compare .to_dense (), np .asarray (other )),
457+ fill_value = np .all (
458+ op (np .asarray (data_for_compare .fill_value ), np .asarray (other ))
459+ ),
460+ dtype = np .bool_ ,
453461 )
454462
455- tm .assert_series_equal (result , expected )
463+ tm .assert_sp_array_equal (result , expected )
456464
457- # series
458- ser = pd .Series (data )
459- result = op (ser , other )
460- tm .assert_series_equal (result , expected )
465+ def test_scalar (self , data_for_compare : SparseArray , comparison_op ):
466+ self ._compare_other (data_for_compare , comparison_op , 0 )
467+ self ._compare_other (data_for_compare , comparison_op , 1 )
468+ self ._compare_other (data_for_compare , comparison_op , - 1 )
469+ self ._compare_other (data_for_compare , comparison_op , np .nan )
470+
471+ @pytest .mark .xfail (reason = "Wrong indices" )
472+ def test_array (self , data_for_compare : SparseArray , comparison_op ):
473+ arr = np .linspace (- 4 , 5 , 10 )
474+ self ._compare_other (data_for_compare , comparison_op , arr )
475+
476+ @pytest .mark .xfail (reason = "Wrong indices" )
477+ def test_sparse_array (self , data_for_compare : SparseArray , comparison_op ):
478+ arr = data_for_compare + 1
479+ self ._compare_other (data_for_compare , comparison_op , arr )
480+ arr = data_for_compare * 2
481+ self ._compare_other (data_for_compare , comparison_op , arr )
461482
462483
463484class TestPrinting (BaseSparseTests , base .BasePrintingTests ):
0 commit comments