1919has_c16 = hasattr (np , "complex128" )
2020
2121
22+ @pytest .fixture (params = [True , False ])
23+ def skipna (request ):
24+ """
25+ Fixture to pass skipna to nanops functions.
26+ """
27+ return request .param
28+
29+
2230class TestnanopsDataFrame :
2331 def setup_method (self , method ):
2432 np .random .seed (11235 )
@@ -89,38 +97,22 @@ def teardown_method(self, method):
8997
9098 def check_results (self , targ , res , axis , check_dtype = True ):
9199 res = getattr (res , "asm8" , res )
92- res = getattr (res , "values" , res )
93-
94- # timedeltas are a beast here
95- def _coerce_tds (targ , res ):
96- if hasattr (targ , "dtype" ) and targ .dtype == "m8[ns]" :
97- if len (targ ) == 1 :
98- targ = targ [0 ].item ()
99- res = res .item ()
100- else :
101- targ = targ .view ("i8" )
102- return targ , res
103100
104- try :
105- if (
106- axis != 0
107- and hasattr (targ , "shape" )
108- and targ .ndim
109- and targ .shape != res .shape
110- ):
111- res = np .split (res , [targ .shape [0 ]], axis = 0 )[0 ]
112- except (ValueError , IndexError ):
113- targ , res = _coerce_tds (targ , res )
101+ if (
102+ axis != 0
103+ and hasattr (targ , "shape" )
104+ and targ .ndim
105+ and targ .shape != res .shape
106+ ):
107+ res = np .split (res , [targ .shape [0 ]], axis = 0 )[0 ]
114108
115109 try :
116110 tm .assert_almost_equal (targ , res , check_dtype = check_dtype )
117111 except AssertionError :
118112
119113 # handle timedelta dtypes
120114 if hasattr (targ , "dtype" ) and targ .dtype == "m8[ns]" :
121- targ , res = _coerce_tds (targ , res )
122- tm .assert_almost_equal (targ , res , check_dtype = check_dtype )
123- return
115+ raise
124116
125117 # There are sometimes rounding errors with
126118 # complex and object dtypes.
@@ -149,29 +141,29 @@ def check_fun_data(
149141 targfunc ,
150142 testarval ,
151143 targarval ,
144+ skipna ,
152145 check_dtype = True ,
153146 empty_targfunc = None ,
154147 ** kwargs ,
155148 ):
156149 for axis in list (range (targarval .ndim )) + [None ]:
157- for skipna in [False , True ]:
158- targartempval = targarval if skipna else testarval
159- if skipna and empty_targfunc and isna (targartempval ).all ():
160- targ = empty_targfunc (targartempval , axis = axis , ** kwargs )
161- else :
162- targ = targfunc (targartempval , axis = axis , ** kwargs )
150+ targartempval = targarval if skipna else testarval
151+ if skipna and empty_targfunc and isna (targartempval ).all ():
152+ targ = empty_targfunc (targartempval , axis = axis , ** kwargs )
153+ else :
154+ targ = targfunc (targartempval , axis = axis , ** kwargs )
163155
164- res = testfunc (testarval , axis = axis , skipna = skipna , ** kwargs )
156+ res = testfunc (testarval , axis = axis , skipna = skipna , ** kwargs )
157+ self .check_results (targ , res , axis , check_dtype = check_dtype )
158+ if skipna :
159+ res = testfunc (testarval , axis = axis , ** kwargs )
160+ self .check_results (targ , res , axis , check_dtype = check_dtype )
161+ if axis is None :
162+ res = testfunc (testarval , skipna = skipna , ** kwargs )
163+ self .check_results (targ , res , axis , check_dtype = check_dtype )
164+ if skipna and axis is None :
165+ res = testfunc (testarval , ** kwargs )
165166 self .check_results (targ , res , axis , check_dtype = check_dtype )
166- if skipna :
167- res = testfunc (testarval , axis = axis , ** kwargs )
168- self .check_results (targ , res , axis , check_dtype = check_dtype )
169- if axis is None :
170- res = testfunc (testarval , skipna = skipna , ** kwargs )
171- self .check_results (targ , res , axis , check_dtype = check_dtype )
172- if skipna and axis is None :
173- res = testfunc (testarval , ** kwargs )
174- self .check_results (targ , res , axis , check_dtype = check_dtype )
175167
176168 if testarval .ndim <= 1 :
177169 return
@@ -184,12 +176,15 @@ def check_fun_data(
184176 targfunc ,
185177 testarval2 ,
186178 targarval2 ,
179+ skipna = skipna ,
187180 check_dtype = check_dtype ,
188181 empty_targfunc = empty_targfunc ,
189182 ** kwargs ,
190183 )
191184
192- def check_fun (self , testfunc , targfunc , testar , empty_targfunc = None , ** kwargs ):
185+ def check_fun (
186+ self , testfunc , targfunc , testar , skipna , empty_targfunc = None , ** kwargs
187+ ):
193188
194189 targar = testar
195190 if testar .endswith ("_nan" ) and hasattr (self , testar [:- 4 ]):
@@ -202,6 +197,7 @@ def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
202197 targfunc ,
203198 testarval ,
204199 targarval ,
200+ skipna = skipna ,
205201 empty_targfunc = empty_targfunc ,
206202 ** kwargs ,
207203 )
@@ -210,36 +206,37 @@ def check_funs(
210206 self ,
211207 testfunc ,
212208 targfunc ,
209+ skipna ,
213210 allow_complex = True ,
214211 allow_all_nan = True ,
215212 allow_date = True ,
216213 allow_tdelta = True ,
217214 allow_obj = True ,
218215 ** kwargs ,
219216 ):
220- self .check_fun (testfunc , targfunc , "arr_float" , ** kwargs )
221- self .check_fun (testfunc , targfunc , "arr_float_nan" , ** kwargs )
222- self .check_fun (testfunc , targfunc , "arr_int" , ** kwargs )
223- self .check_fun (testfunc , targfunc , "arr_bool" , ** kwargs )
217+ self .check_fun (testfunc , targfunc , "arr_float" , skipna , ** kwargs )
218+ self .check_fun (testfunc , targfunc , "arr_float_nan" , skipna , ** kwargs )
219+ self .check_fun (testfunc , targfunc , "arr_int" , skipna , ** kwargs )
220+ self .check_fun (testfunc , targfunc , "arr_bool" , skipna , ** kwargs )
224221 objs = [
225222 self .arr_float .astype ("O" ),
226223 self .arr_int .astype ("O" ),
227224 self .arr_bool .astype ("O" ),
228225 ]
229226
230227 if allow_all_nan :
231- self .check_fun (testfunc , targfunc , "arr_nan" , ** kwargs )
228+ self .check_fun (testfunc , targfunc , "arr_nan" , skipna , ** kwargs )
232229
233230 if allow_complex :
234- self .check_fun (testfunc , targfunc , "arr_complex" , ** kwargs )
235- self .check_fun (testfunc , targfunc , "arr_complex_nan" , ** kwargs )
231+ self .check_fun (testfunc , targfunc , "arr_complex" , skipna , ** kwargs )
232+ self .check_fun (testfunc , targfunc , "arr_complex_nan" , skipna , ** kwargs )
236233 if allow_all_nan :
237- self .check_fun (testfunc , targfunc , "arr_nan_nanj" , ** kwargs )
234+ self .check_fun (testfunc , targfunc , "arr_nan_nanj" , skipna , ** kwargs )
238235 objs += [self .arr_complex .astype ("O" )]
239236
240237 if allow_date :
241238 targfunc (self .arr_date )
242- self .check_fun (testfunc , targfunc , "arr_date" , ** kwargs )
239+ self .check_fun (testfunc , targfunc , "arr_date" , skipna , ** kwargs )
243240 objs += [self .arr_date .astype ("O" )]
244241
245242 if allow_tdelta :
@@ -248,7 +245,7 @@ def check_funs(
248245 except TypeError :
249246 pass
250247 else :
251- self .check_fun (testfunc , targfunc , "arr_tdelta" , ** kwargs )
248+ self .check_fun (testfunc , targfunc , "arr_tdelta" , skipna , ** kwargs )
252249 objs += [self .arr_tdelta .astype ("O" )]
253250
254251 if allow_obj :
@@ -260,7 +257,7 @@ def check_funs(
260257 targfunc = partial (
261258 self ._badobj_wrap , func = targfunc , allow_complex = allow_complex
262259 )
263- self .check_fun (testfunc , targfunc , "arr_obj" , ** kwargs )
260+ self .check_fun (testfunc , targfunc , "arr_obj" , skipna , ** kwargs )
264261
265262 def _badobj_wrap (self , value , func , allow_complex = True , ** kwargs ):
266263 if value .dtype .kind == "O" :
@@ -273,28 +270,22 @@ def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
273270 @pytest .mark .parametrize (
274271 "nan_op,np_op" , [(nanops .nanany , np .any ), (nanops .nanall , np .all )]
275272 )
276- def test_nan_funcs (self , nan_op , np_op ):
277- # TODO: allow tdelta, doesn't break tests
278- self .check_funs (
279- nan_op , np_op , allow_all_nan = False , allow_date = False , allow_tdelta = False
280- )
273+ def test_nan_funcs (self , nan_op , np_op , skipna ):
274+ self .check_funs (nan_op , np_op , skipna , allow_all_nan = False , allow_date = False )
281275
282- def test_nansum (self ):
276+ def test_nansum (self , skipna ):
283277 self .check_funs (
284278 nanops .nansum ,
285279 np .sum ,
280+ skipna ,
286281 allow_date = False ,
287282 check_dtype = False ,
288283 empty_targfunc = np .nansum ,
289284 )
290285
291- def test_nanmean (self ):
286+ def test_nanmean (self , skipna ):
292287 self .check_funs (
293- nanops .nanmean ,
294- np .mean ,
295- allow_complex = False , # TODO: allow this, doesn't break test
296- allow_obj = False ,
297- allow_date = False ,
288+ nanops .nanmean , np .mean , skipna , allow_obj = False , allow_date = False ,
298289 )
299290
300291 def test_nanmean_overflow (self ):
@@ -336,33 +327,36 @@ def test_returned_dtype(self, dtype):
336327 else :
337328 assert result .dtype == dtype
338329
339- def test_nanmedian (self ):
330+ def test_nanmedian (self , skipna ):
340331 with warnings .catch_warnings (record = True ):
341332 warnings .simplefilter ("ignore" , RuntimeWarning )
342333 self .check_funs (
343334 nanops .nanmedian ,
344335 np .median ,
336+ skipna ,
345337 allow_complex = False ,
346338 allow_date = False ,
347339 allow_obj = "convert" ,
348340 )
349341
350342 @pytest .mark .parametrize ("ddof" , range (3 ))
351- def test_nanvar (self , ddof ):
343+ def test_nanvar (self , ddof , skipna ):
352344 self .check_funs (
353345 nanops .nanvar ,
354346 np .var ,
347+ skipna ,
355348 allow_complex = False ,
356349 allow_date = False ,
357350 allow_obj = "convert" ,
358351 ddof = ddof ,
359352 )
360353
361354 @pytest .mark .parametrize ("ddof" , range (3 ))
362- def test_nanstd (self , ddof ):
355+ def test_nanstd (self , ddof , skipna ):
363356 self .check_funs (
364357 nanops .nanstd ,
365358 np .std ,
359+ skipna ,
366360 allow_complex = False ,
367361 allow_date = False ,
368362 allow_obj = "convert" ,
@@ -371,13 +365,14 @@ def test_nanstd(self, ddof):
371365
372366 @td .skip_if_no_scipy
373367 @pytest .mark .parametrize ("ddof" , range (3 ))
374- def test_nansem (self , ddof ):
368+ def test_nansem (self , ddof , skipna ):
375369 from scipy .stats import sem
376370
377371 with np .errstate (invalid = "ignore" ):
378372 self .check_funs (
379373 nanops .nansem ,
380374 sem ,
375+ skipna ,
381376 allow_complex = False ,
382377 allow_date = False ,
383378 allow_tdelta = False ,
@@ -388,10 +383,10 @@ def test_nansem(self, ddof):
388383 @pytest .mark .parametrize (
389384 "nan_op,np_op" , [(nanops .nanmin , np .min ), (nanops .nanmax , np .max )]
390385 )
391- def test_nanops_with_warnings (self , nan_op , np_op ):
386+ def test_nanops_with_warnings (self , nan_op , np_op , skipna ):
392387 with warnings .catch_warnings (record = True ):
393388 warnings .simplefilter ("ignore" , RuntimeWarning )
394- self .check_funs (nan_op , np_op , allow_obj = False )
389+ self .check_funs (nan_op , np_op , skipna , allow_obj = False )
395390
396391 def _argminmax_wrap (self , value , axis = None , func = None ):
397392 res = func (value , axis )
@@ -408,17 +403,17 @@ def _argminmax_wrap(self, value, axis=None, func=None):
408403 res = - 1
409404 return res
410405
411- def test_nanargmax (self ):
406+ def test_nanargmax (self , skipna ):
412407 with warnings .catch_warnings (record = True ):
413408 warnings .simplefilter ("ignore" , RuntimeWarning )
414409 func = partial (self ._argminmax_wrap , func = np .argmax )
415- self .check_funs (nanops .nanargmax , func , allow_obj = False )
410+ self .check_funs (nanops .nanargmax , func , skipna , allow_obj = False )
416411
417- def test_nanargmin (self ):
412+ def test_nanargmin (self , skipna ):
418413 with warnings .catch_warnings (record = True ):
419414 warnings .simplefilter ("ignore" , RuntimeWarning )
420415 func = partial (self ._argminmax_wrap , func = np .argmin )
421- self .check_funs (nanops .nanargmin , func , allow_obj = False )
416+ self .check_funs (nanops .nanargmin , func , skipna , allow_obj = False )
422417
423418 def _skew_kurt_wrap (self , values , axis = None , func = None ):
424419 if not isinstance (values .dtype .type , np .floating ):
@@ -433,21 +428,22 @@ def _skew_kurt_wrap(self, values, axis=None, func=None):
433428 return result
434429
435430 @td .skip_if_no_scipy
436- def test_nanskew (self ):
431+ def test_nanskew (self , skipna ):
437432 from scipy .stats import skew
438433
439434 func = partial (self ._skew_kurt_wrap , func = skew )
440435 with np .errstate (invalid = "ignore" ):
441436 self .check_funs (
442437 nanops .nanskew ,
443438 func ,
439+ skipna ,
444440 allow_complex = False ,
445441 allow_date = False ,
446442 allow_tdelta = False ,
447443 )
448444
449445 @td .skip_if_no_scipy
450- def test_nankurt (self ):
446+ def test_nankurt (self , skipna ):
451447 from scipy .stats import kurtosis
452448
453449 func1 = partial (kurtosis , fisher = True )
@@ -456,15 +452,17 @@ def test_nankurt(self):
456452 self .check_funs (
457453 nanops .nankurt ,
458454 func ,
455+ skipna ,
459456 allow_complex = False ,
460457 allow_date = False ,
461458 allow_tdelta = False ,
462459 )
463460
464- def test_nanprod (self ):
461+ def test_nanprod (self , skipna ):
465462 self .check_funs (
466463 nanops .nanprod ,
467464 np .prod ,
465+ skipna ,
468466 allow_date = False ,
469467 allow_tdelta = False ,
470468 empty_targfunc = np .nanprod ,
0 commit comments