11import functools
22import itertools
33import operator
4- from typing import Any , Optional , Tuple , Union
4+ from typing import Any , Callable , Optional , Tuple , Union
55
66import numpy as np
77
88from pandas ._config import get_option
99
1010from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
11+ from pandas ._typing import Dtype , Scalar
1112from pandas .compat ._optional import import_optional_dependency
1213
1314from pandas .core .dtypes .cast import _int64_max , maybe_upcast_putmask
3738_USE_BOTTLENECK = False
3839
3940
40- def set_use_bottleneck (v = True ):
41+ def set_use_bottleneck (v : bool = True ) -> None :
4142 # set/unset to use bottleneck
4243 global _USE_BOTTLENECK
4344 if _BOTTLENECK_INSTALLED :
@@ -55,7 +56,7 @@ def __init__(self, *dtypes):
5556 def check (self , obj ) -> bool :
5657 return hasattr (obj , "dtype" ) and issubclass (obj .dtype .type , self .dtypes )
5758
58- def __call__ (self , f ):
59+ def __call__ (self , f ) -> Callable :
5960 @functools .wraps (f )
6061 def _f (* args , ** kwargs ):
6162 obj_iter = itertools .chain (args , kwargs .values ())
@@ -80,11 +81,11 @@ def _f(*args, **kwargs):
8081
8182
8283class bottleneck_switch :
83- def __init__ (self , name = None , ** kwargs ):
84+ def __init__ (self , name : Optional [ str ] = None , ** kwargs ):
8485 self .name = name
8586 self .kwargs = kwargs
8687
87- def __call__ (self , alt ) :
88+ def __call__ (self , alt : Callable ) -> Callable :
8889 bn_name = self .name or alt .__name__
8990
9091 try :
@@ -93,7 +94,9 @@ def __call__(self, alt):
9394 bn_func = None
9495
9596 @functools .wraps (alt )
96- def f (values , axis = None , skipna = True , ** kwds ):
97+ def f (
98+ values : np .ndarray , axis : Optional [int ] = None , skipna : bool = True , ** kwds
99+ ):
97100 if len (self .kwargs ) > 0 :
98101 for k , v in self .kwargs .items ():
99102 if k not in kwds :
@@ -129,7 +132,7 @@ def f(values, axis=None, skipna=True, **kwds):
129132 return f
130133
131134
132- def _bn_ok_dtype (dt , name : str ) -> bool :
135+ def _bn_ok_dtype (dt : Dtype , name : str ) -> bool :
133136 # Bottleneck chokes on datetime64
134137 if not is_object_dtype (dt ) and not (
135138 is_datetime_or_timedelta_dtype (dt ) or is_datetime64tz_dtype (dt )
@@ -163,7 +166,9 @@ def _has_infs(result) -> bool:
163166 return False
164167
165168
166- def _get_fill_value (dtype , fill_value = None , fill_value_typ = None ):
169+ def _get_fill_value (
170+ dtype : Dtype , fill_value : Any = None , fill_value_typ : Optional [str ] = None
171+ ):
167172 """ return the correct fill value for the dtype of the values """
168173 if fill_value is not None :
169174 return fill_value
@@ -326,12 +331,12 @@ def _get_values(
326331 return values , mask , dtype , dtype_max , fill_value
327332
328333
329- def _na_ok_dtype (dtype ):
334+ def _na_ok_dtype (dtype ) -> bool :
330335 # TODO: what about datetime64tz? PeriodDtype?
331336 return not issubclass (dtype .type , (np .integer , np .timedelta64 , np .datetime64 ))
332337
333338
334- def _wrap_results (result , dtype , fill_value = None ):
339+ def _wrap_results (result , dtype : Dtype , fill_value = None ):
335340 """ wrap our results if needed """
336341
337342 if is_datetime64_dtype (dtype ) or is_datetime64tz_dtype (dtype ):
@@ -362,7 +367,9 @@ def _wrap_results(result, dtype, fill_value=None):
362367 return result
363368
364369
365- def _na_for_min_count (values , axis : Optional [int ]):
370+ def _na_for_min_count (
371+ values : np .ndarray , axis : Optional [int ]
372+ ) -> Union [Scalar , np .ndarray ]:
366373 """
367374 Return the missing value for `values`.
368375
@@ -393,7 +400,12 @@ def _na_for_min_count(values, axis: Optional[int]):
393400 return result
394401
395402
396- def nanany (values , axis = None , skipna : bool = True , mask = None ):
403+ def nanany (
404+ values : np .ndarray ,
405+ axis : Optional [int ] = None ,
406+ skipna : bool = True ,
407+ mask : Optional [np .ndarray ] = None ,
408+ ) -> bool :
397409 """
398410 Check if any elements along an axis evaluate to True.
399411
@@ -425,7 +437,12 @@ def nanany(values, axis=None, skipna: bool = True, mask=None):
425437 return values .any (axis )
426438
427439
428- def nanall (values , axis = None , skipna : bool = True , mask = None ):
440+ def nanall (
441+ values : np .ndarray ,
442+ axis : Optional [int ] = None ,
443+ skipna : bool = True ,
444+ mask : Optional [np .ndarray ] = None ,
445+ ) -> bool :
429446 """
430447 Check if all elements along an axis evaluate to True.
431448
@@ -458,7 +475,13 @@ def nanall(values, axis=None, skipna: bool = True, mask=None):
458475
459476
460477@disallow ("M8" )
461- def nansum (values , axis = None , skipna = True , min_count = 0 , mask = None ):
478+ def nansum (
479+ values : np .ndarray ,
480+ axis : Optional [int ] = None ,
481+ skipna : bool = True ,
482+ min_count : int = 0 ,
483+ mask : Optional [np .ndarray ] = None ,
484+ ) -> Dtype :
462485 """
463486 Sum the elements along an axis ignoring NaNs
464487
@@ -629,7 +652,7 @@ def _get_counts_nanvar(
629652 mask : Optional [np .ndarray ],
630653 axis : Optional [int ],
631654 ddof : int ,
632- dtype = float ,
655+ dtype : Dtype = float ,
633656) -> Tuple [Union [int , np .ndarray ], Union [int , np .ndarray ]]:
634657 """ Get the count of non-null values along an axis, accounting
635658 for degrees of freedom.
@@ -776,7 +799,13 @@ def nanvar(values, axis=None, skipna=True, ddof=1, mask=None):
776799
777800
778801@disallow ("M8" , "m8" )
779- def nansem (values , axis = None , skipna = True , ddof = 1 , mask = None ):
802+ def nansem (
803+ values : np .ndarray ,
804+ axis : Optional [int ] = None ,
805+ skipna : bool = True ,
806+ ddof : int = 1 ,
807+ mask : Optional [np .ndarray ] = None ,
808+ ) -> float :
780809 """
781810 Compute the standard error in the mean along given axis while ignoring NaNs
782811
@@ -819,9 +848,14 @@ def nansem(values, axis=None, skipna=True, ddof=1, mask=None):
819848 return np .sqrt (var ) / np .sqrt (count )
820849
821850
822- def _nanminmax (meth , fill_value_typ ) :
851+ def _nanminmax (meth : str , fill_value_typ : str ) -> Callable :
823852 @bottleneck_switch (name = "nan" + meth )
824- def reduction (values , axis = None , skipna = True , mask = None ):
853+ def reduction (
854+ values : np .ndarray ,
855+ axis : Optional [int ] = None ,
856+ skipna : bool = True ,
857+ mask : Optional [np .ndarray ] = None ,
858+ ) -> np .ndarray :
825859
826860 values , mask , dtype , dtype_max , fill_value = _get_values (
827861 values , skipna , fill_value_typ = fill_value_typ , mask = mask
@@ -847,7 +881,12 @@ def reduction(values, axis=None, skipna=True, mask=None):
847881
848882
849883@disallow ("O" )
850- def nanargmax (values , axis = None , skipna = True , mask = None ):
884+ def nanargmax (
885+ values : np .ndarray ,
886+ axis : Optional [int ] = None ,
887+ skipna : bool = True ,
888+ mask : Optional [np .ndarray ] = None ,
889+ ) -> int :
851890 """
852891 Parameters
853892 ----------
@@ -878,7 +917,12 @@ def nanargmax(values, axis=None, skipna=True, mask=None):
878917
879918
880919@disallow ("O" )
881- def nanargmin (values , axis = None , skipna = True , mask = None ):
920+ def nanargmin (
921+ values : np .ndarray ,
922+ axis : Optional [int ] = None ,
923+ skipna : bool = True ,
924+ mask : Optional [np .ndarray ] = None ,
925+ ) -> int :
882926 """
883927 Parameters
884928 ----------
@@ -909,7 +953,12 @@ def nanargmin(values, axis=None, skipna=True, mask=None):
909953
910954
911955@disallow ("M8" , "m8" )
912- def nanskew (values , axis = None , skipna = True , mask = None ):
956+ def nanskew (
957+ values : np .ndarray ,
958+ axis : Optional [int ] = None ,
959+ skipna : bool = True ,
960+ mask : Optional [np .ndarray ] = None ,
961+ ) -> float :
913962 """ Compute the sample skewness.
914963
915964 The statistic computed here is the adjusted Fisher-Pearson standardized
@@ -987,7 +1036,12 @@ def nanskew(values, axis=None, skipna=True, mask=None):
9871036
9881037
9891038@disallow ("M8" , "m8" )
990- def nankurt (values , axis = None , skipna = True , mask = None ):
1039+ def nankurt (
1040+ values : np .ndarray ,
1041+ axis : Optional [int ] = None ,
1042+ skipna : bool = True ,
1043+ mask : Optional [np .ndarray ] = None ,
1044+ ) -> float :
9911045 """
9921046 Compute the sample excess kurtosis
9931047
@@ -1075,7 +1129,13 @@ def nankurt(values, axis=None, skipna=True, mask=None):
10751129
10761130
10771131@disallow ("M8" , "m8" )
1078- def nanprod (values , axis = None , skipna = True , min_count = 0 , mask = None ):
1132+ def nanprod (
1133+ values : np .ndarray ,
1134+ axis : Optional [int ] = None ,
1135+ skipna : bool = True ,
1136+ min_count : int = 0 ,
1137+ mask : Optional [np .ndarray ] = None ,
1138+ ) -> Dtype :
10791139 """
10801140 Parameters
10811141 ----------
@@ -1138,7 +1198,7 @@ def _get_counts(
11381198 values_shape : Tuple [int ],
11391199 mask : Optional [np .ndarray ],
11401200 axis : Optional [int ],
1141- dtype = float ,
1201+ dtype : Dtype = float ,
11421202) -> Union [int , np .ndarray ]:
11431203 """ Get the count of non-null values along an axis
11441204
@@ -1218,7 +1278,12 @@ def _zero_out_fperr(arg):
12181278
12191279
12201280@disallow ("M8" , "m8" )
1221- def nancorr (a , b , method = "pearson" , min_periods = None ):
1281+ def nancorr (
1282+ a : np .ndarray ,
1283+ b : np .ndarray ,
1284+ method : str = "pearson" ,
1285+ min_periods : Optional [int ] = None ,
1286+ ):
12221287 """
12231288 a, b: ndarrays
12241289 """
@@ -1240,7 +1305,7 @@ def nancorr(a, b, method="pearson", min_periods=None):
12401305 return f (a , b )
12411306
12421307
1243- def get_corr_func (method ):
1308+ def get_corr_func (method : str ):
12441309 if method in ["kendall" , "spearman" ]:
12451310 from scipy .stats import kendalltau , spearmanr
12461311 elif callable (method ):
@@ -1262,7 +1327,7 @@ def _spearman(a, b):
12621327
12631328
12641329@disallow ("M8" , "m8" )
1265- def nancov (a , b , min_periods = None ):
1330+ def nancov (a : np . ndarray , b : np . ndarray , min_periods : Optional [ int ] = None ):
12661331 if len (a ) != len (b ):
12671332 raise AssertionError ("Operands to nancov must have same size" )
12681333
@@ -1308,7 +1373,7 @@ def _ensure_numeric(x):
13081373# NA-friendly array comparisons
13091374
13101375
1311- def make_nancomp (op ):
1376+ def make_nancomp (op ) -> Callable :
13121377 def f (x , y ):
13131378 xmask = isna (x )
13141379 ymask = isna (y )
@@ -1335,7 +1400,9 @@ def f(x, y):
13351400nanne = make_nancomp (operator .ne )
13361401
13371402
1338- def _nanpercentile_1d (values , mask , q , na_value , interpolation ):
1403+ def _nanpercentile_1d (
1404+ values : np .ndarray , mask : np .ndarray , q , na_value : Scalar , interpolation : str
1405+ ) -> Union [Scalar , np .ndarray ]:
13391406 """
13401407 Wrapper for np.percentile that skips missing values, specialized to
13411408 1-dimensional case.
@@ -1366,7 +1433,15 @@ def _nanpercentile_1d(values, mask, q, na_value, interpolation):
13661433 return np .percentile (values , q , interpolation = interpolation )
13671434
13681435
1369- def nanpercentile (values , q , axis , na_value , mask , ndim , interpolation ):
1436+ def nanpercentile (
1437+ values : np .ndarray ,
1438+ q ,
1439+ axis : int ,
1440+ na_value ,
1441+ mask : np .ndarray ,
1442+ ndim : int ,
1443+ interpolation : str ,
1444+ ):
13701445 """
13711446 Wrapper for np.percentile that skips missing values.
13721447
0 commit comments