11""" common utilities """
2-
32import itertools
3+ from typing import Dict , Hashable , Union
44from warnings import catch_warnings , filterwarnings
55
66import numpy as np
@@ -29,7 +29,9 @@ def _axify(obj, key, axis):
2929class Base :
3030 """ indexing comprehensive base class """
3131
32- _objs = {"series" , "frame" }
32+ frame = None # type: Dict[str, DataFrame]
33+ series = None # type: Dict[str, Series]
34+ _kinds = {"series" , "frame" }
3335 _typs = {
3436 "ints" ,
3537 "uints" ,
@@ -101,13 +103,12 @@ def setup_method(self, method):
101103 self .series_empty = Series ()
102104
103105 # form agglomerates
104- for o in self ._objs :
105-
106- d = dict ()
107- for t in self ._typs :
108- d [t ] = getattr (self , "{o}_{t}" .format (o = o , t = t ), None )
106+ for kind in self ._kinds :
107+ d = dict () # type: Dict[str, Union[DataFrame, Series]]
108+ for typ in self ._typs :
109+ d [typ ] = getattr (self , "{kind}_{typ}" .format (kind = kind , typ = typ ))
109110
110- setattr (self , o , d )
111+ setattr (self , kind , d )
111112
112113 def generate_indices (self , f , values = False ):
113114 """ generate the indices
@@ -117,7 +118,7 @@ def generate_indices(self, f, values=False):
117118
118119 axes = f .axes
119120 if values :
120- axes = (list (range (len (a ))) for a in axes )
121+ axes = (list (range (len (ax ))) for ax in axes )
121122
122123 return itertools .product (* axes )
123124
@@ -186,34 +187,42 @@ def check_result(
186187 method2 ,
187188 key2 ,
188189 typs = None ,
189- objs = None ,
190+ kinds = None ,
190191 axes = None ,
191192 fails = None ,
192193 ):
193- def _eq (t , o , a , obj , k1 , k2 ):
194- """ compare equal for these 2 keys """
195194
196- if a is not None and a > obj .ndim - 1 :
195+ def _eq (
196+ typ : str ,
197+ kind : str ,
198+ axis : int ,
199+ obj : Union [DataFrame , Series ],
200+ key1 : Hashable ,
201+ key2 : Hashable ,
202+ ) -> None :
203+ """ compare equal for these 2 keys """
204+ if axis > obj .ndim - 1 :
197205 return
198206
199207 def _print (result , error = None ):
200- if error is not None :
201- error = str (error )
202- v = (
208+ err = str (error ) if error is not None else ""
209+ msg = (
203210 "%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
204211 "key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s"
205- % (name , result , t , o , method1 , method2 , a , error or "" )
212+ % (name , result , typ , kind , method1 , method2 , axis , err )
206213 )
207214 if _verbose :
208- pprint_thing (v )
215+ pprint_thing (msg )
209216
210217 try :
211- rs = getattr (obj , method1 ).__getitem__ (_axify (obj , k1 , a ))
218+ rs = getattr (obj , method1 ).__getitem__ (_axify (obj , key1 , axis ))
212219
213220 with catch_warnings (record = True ):
214221 filterwarnings ("ignore" , "\\ n.ix" , FutureWarning )
215222 try :
216- xp = self .get_result (obj , method2 , k2 , a )
223+ xp = self .get_result (
224+ obj = obj , method = method2 , key = key2 , axis = axis
225+ )
217226 except (KeyError , IndexError ):
218227 # TODO: why is this allowed?
219228 result = "no comp"
@@ -228,8 +237,8 @@ def _print(result, error=None):
228237 else :
229238 tm .assert_equal (rs , xp )
230239 result = "ok"
231- except AssertionError as e :
232- detail = str (e )
240+ except AssertionError as exc :
241+ detail = str (exc )
233242 result = "fail"
234243
235244 # reverse the checks
@@ -258,36 +267,25 @@ def _print(result, error=None):
258267 if typs is None :
259268 typs = self ._typs
260269
261- if objs is None :
262- objs = self ._objs
270+ if kinds is None :
271+ kinds = self ._kinds
263272
264- if axes is not None :
265- if not isinstance (axes , (tuple , list )):
266- axes = [axes ]
267- else :
268- axes = list (axes )
269- else :
273+ if axes is None :
270274 axes = [0 , 1 ]
275+ elif not isinstance (axes , (tuple , list )):
276+ assert isinstance (axes , int )
277+ axes = [axes ]
271278
272279 # check
273- for o in objs :
274- if o not in self ._objs :
280+ for kind in kinds : # type: str
281+ if kind not in self ._kinds :
275282 continue
276283
277- d = getattr (self , o )
278- for a in axes :
279- for t in typs :
280- if t not in self ._typs :
284+ d = getattr (self , kind ) # type: Dict[str, Union[DataFrame, Series]]
285+ for ax in axes :
286+ for typ in typs :
287+ if typ not in self ._typs :
281288 continue
282289
283- obj = d [t ]
284- if obj is None :
285- continue
286-
287- def _call (obj = obj ):
288- obj = obj .copy ()
289-
290- k2 = key2
291- _eq (t , o , a , obj , key1 , k2 )
292-
293- _call ()
290+ obj = d [typ ]
291+ _eq (typ = typ , kind = kind , axis = ax , obj = obj , key1 = key1 , key2 = key2 )
0 commit comments