11""" common utilities """
2-
32import itertools
3+ from typing import Dict , Hashable , Union
44from warnings import catch_warnings , filterwarnings
55
66import numpy as np
@@ -29,7 +29,9 @@ def _axify(obj, key, axis):
2929class Base :
3030 """ indexing comprehensive base class """
3131
32- _objs = {"series" , "frame" }
32+ frame = None # type: Dict[str, DataFrame]
33+ series = None # type: Dict[str, Series]
34+ _kinds = {"series" , "frame" }
3335 _typs = {
3436 "ints" ,
3537 "uints" ,
@@ -101,13 +103,12 @@ def setup_method(self, method):
101103 self .series_empty = Series ()
102104
103105 # form agglomerates
104- for o in self ._objs :
105-
106- d = dict ()
107- for t in self ._typs :
108- d [t ] = getattr (self , "{o}_{t}" .format (o = o , t = t ), None )
106+ for kind in self ._kinds :
107+ d = dict () # type: Dict[str, Union[DataFrame, Series]]
108+ for typ in self ._typs :
109+ d [typ ] = getattr (self , "{kind}_{typ}" .format (kind = kind , typ = typ ))
109110
110- setattr (self , o , d )
111+ setattr (self , kind , d )
111112
112113 def generate_indices (self , f , values = False ):
113114 """ generate the indices
@@ -117,7 +118,7 @@ def generate_indices(self, f, values=False):
117118
118119 axes = f .axes
119120 if values :
120- axes = (list (range (len (a ))) for a in axes )
121+ axes = (list (range (len (ax ))) for ax in axes )
121122
122123 return itertools .product (* axes )
123124
@@ -186,34 +187,41 @@ def check_result(
186187 method2 ,
187188 key2 ,
188189 typs = None ,
189- objs = None ,
190+ kinds = None ,
190191 axes = None ,
191192 fails = None ,
192193 ):
193- def _eq (t , o , a , obj , k1 , k2 ):
194+ def _eq (
195+ typ : str ,
196+ kind : str ,
197+ axis : int ,
198+ obj : Union [DataFrame , Series ],
199+ key1 : Hashable ,
200+ key2 : Hashable ,
201+ ) -> None :
194202 """ compare equal for these 2 keys """
195-
196- if a is not None and a > obj .ndim - 1 :
203+ if axis > obj .ndim - 1 :
197204 return
198205
199206 def _print (result , error = None ):
200- if error is not None :
201- error = str (error )
202- v = (
207+ err = str (error ) if error is not None else ""
208+ msg = (
203209 "%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
204210 "key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s"
205- % (name , result , t , o , method1 , method2 , a , error or "" )
211+ % (name , result , typ , kind , method1 , method2 , axis , err )
206212 )
207213 if _verbose :
208- pprint_thing (v )
214+ pprint_thing (msg )
209215
210216 try :
211- rs = getattr (obj , method1 ).__getitem__ (_axify (obj , k1 , a ))
217+ rs = getattr (obj , method1 ).__getitem__ (_axify (obj , key1 , axis ))
212218
213219 with catch_warnings (record = True ):
214220 filterwarnings ("ignore" , "\\ n.ix" , FutureWarning )
215221 try :
216- xp = self .get_result (obj , method2 , k2 , a )
222+ xp = self .get_result (
223+ obj = obj , method = method2 , key = key2 , axis = axis
224+ )
217225 except (KeyError , IndexError ):
218226 # TODO: why is this allowed?
219227 result = "no comp"
@@ -228,8 +236,8 @@ def _print(result, error=None):
228236 else :
229237 tm .assert_equal (rs , xp )
230238 result = "ok"
231- except AssertionError as e :
232- detail = str (e )
239+ except AssertionError as exc :
240+ detail = str (exc )
233241 result = "fail"
234242
235243 # reverse the checks
@@ -258,36 +266,25 @@ def _print(result, error=None):
258266 if typs is None :
259267 typs = self ._typs
260268
261- if objs is None :
262- objs = self ._objs
269+ if kinds is None :
270+ kinds = self ._kinds
263271
264- if axes is not None :
265- if not isinstance (axes , (tuple , list )):
266- axes = [axes ]
267- else :
268- axes = list (axes )
269- else :
272+ if axes is None :
270273 axes = [0 , 1 ]
274+ elif not isinstance (axes , (tuple , list )):
275+ assert isinstance (axes , int )
276+ axes = [axes ]
271277
272278 # check
273- for o in objs :
274- if o not in self ._objs :
279+ for kind in kinds : # type: str
280+ if kind not in self ._kinds :
275281 continue
276282
277- d = getattr (self , o )
278- for a in axes :
279- for t in typs :
280- if t not in self ._typs :
283+ d = getattr (self , kind ) # type: Dict[str, Union[DataFrame, Series]]
284+ for ax in axes :
285+ for typ in typs :
286+ if typ not in self ._typs :
281287 continue
282288
283- obj = d [t ]
284- if obj is None :
285- continue
286-
287- def _call (obj = obj ):
288- obj = obj .copy ()
289-
290- k2 = key2
291- _eq (t , o , a , obj , key1 , k2 )
292-
293- _call ()
289+ obj = d [typ ]
290+ _eq (typ = typ , kind = kind , axis = ax , obj = obj , key1 = key1 , key2 = key2 )
0 commit comments