3939# dtype access #
4040# --------------- #
4141
42- def _ensure_data (values , dtype = None ):
42+ def _ensure_data (values , dtype = None , infer = True ):
4343 """
4444 routine to ensure that our data is of the correct
4545 input dtype for lower-level routines
@@ -57,10 +57,15 @@ def _ensure_data(values, dtype=None):
5757 values : array-like
5858 dtype : pandas_dtype, optional
5959 coerce to this dtype
60+ infer : boolean, default True
61+ infer object dtypes
6062
6163 Returns
6264 -------
63- (ndarray, pandas_dtype, algo dtype as a string)
65+ (ndarray,
66+ pandas_dtype,
67+ algo dtype as a string,
68+ inferred type as a string or None)
6469
6570 """
6671
@@ -69,28 +74,40 @@ def _ensure_data(values, dtype=None):
6974 if is_bool_dtype (values ) or is_bool_dtype (dtype ):
7075 # we are actually coercing to uint64
7176 # until our algos suppport uint8 directly (see TODO)
72- return np .asarray (values ).astype ('uint64' ), 'bool' , 'uint64'
77+ return np .asarray (values ).astype ('uint64' ), 'bool' , 'uint64' , None
7378 elif is_signed_integer_dtype (values ) or is_signed_integer_dtype (dtype ):
74- return _ensure_int64 (values ), 'int64' , 'int64'
79+ return _ensure_int64 (values ), 'int64' , 'int64' , None
7580 elif (is_unsigned_integer_dtype (values ) or
7681 is_unsigned_integer_dtype (dtype )):
77- return _ensure_uint64 (values ), 'uint64' , 'uint64'
82+ return _ensure_uint64 (values ), 'uint64' , 'uint64' , None
7883 elif is_float_dtype (values ) or is_float_dtype (dtype ):
79- return _ensure_float64 (values ), 'float64' , 'float64'
84+ return _ensure_float64 (values ), 'float64' , 'float64' , None
8085 elif is_object_dtype (values ) and dtype is None :
81- return _ensure_object (np .asarray (values )), 'object' , 'object'
86+
87+ # if we can infer a numeric then do this
88+ inferred = None
89+ if infer :
90+ inferred = lib .infer_dtype (values )
91+ if inferred in ['integer' ]:
92+ return _ensure_int64 (values ), 'int64' , 'int64' , inferred
93+ elif inferred in ['floating' ]:
94+ return (_ensure_float64 (values ),
95+ 'float64' , 'float64' , inferred )
96+
97+ return (_ensure_object (np .asarray (values )),
98+ 'object' , 'object' , inferred )
8299 elif is_complex_dtype (values ) or is_complex_dtype (dtype ):
83100
84101 # ignore the fact that we are casting to float
85102 # which discards complex parts
86103 with catch_warnings (record = True ):
87104 values = _ensure_float64 (values )
88- return values , 'float64' , 'float64'
105+ return values , 'float64' , 'float64' , None
89106
90- except (TypeError , ValueError ):
107+ except (TypeError , ValueError , OverflowError ):
91108 # if we are trying to coerce to a dtype
92109 # and it is incompat this will fall thru to here
93- return _ensure_object (values ), 'object' , 'object'
110+ return _ensure_object (values ), 'object' , 'object' , None
94111
95112 # datetimelike
96113 if (needs_i8_conversion (values ) or
@@ -111,7 +128,7 @@ def _ensure_data(values, dtype=None):
111128 values = DatetimeIndex (values )
112129 dtype = values .dtype
113130
114- return values .asi8 , dtype , 'int64'
131+ return values .asi8 , dtype , 'int64' , None
115132
116133 elif is_categorical_dtype (values ) or is_categorical_dtype (dtype ):
117134 values = getattr (values , 'values' , values )
@@ -122,11 +139,11 @@ def _ensure_data(values, dtype=None):
122139 # until our algos suppport int* directly (not all do)
123140 values = _ensure_int64 (values )
124141
125- return values , dtype , 'int64'
142+ return values , dtype , 'int64' , None
126143
127144 # we have failed, return object
128145 values = np .asarray (values )
129- return _ensure_object (values ), 'object' , 'object'
146+ return _ensure_object (values ), 'object' , 'object' , None
130147
131148
132149def _reconstruct_data (values , dtype , original ):
@@ -150,7 +167,13 @@ def _reconstruct_data(values, dtype, original):
150167 elif is_datetime64tz_dtype (dtype ) or is_period_dtype (dtype ):
151168 values = Index (original )._shallow_copy (values , name = None )
152169 elif dtype is not None :
153- values = values .astype (dtype )
170+
171+ # don't cast to object if we are numeric
172+ if is_object_dtype (dtype ):
173+ if not is_numeric_dtype (values ):
174+ values = values .astype (dtype )
175+ else :
176+ values = values .astype (dtype )
154177
155178 return values
156179
@@ -161,7 +184,7 @@ def _ensure_arraylike(values):
161184 """
162185 if not isinstance (values , (np .ndarray , ABCCategorical ,
163186 ABCIndexClass , ABCSeries )):
164- values = np .array (values )
187+ values = np .array (values , dtype = object )
165188 return values
166189
167190
@@ -174,11 +197,13 @@ def _ensure_arraylike(values):
174197}
175198
176199
177- def _get_hashtable_algo (values ):
200+ def _get_hashtable_algo (values , infer = False ):
178201 """
179202 Parameters
180203 ----------
181204 values : arraylike
205+ infer : boolean, default False
206+ infer object dtypes
182207
183208 Returns
184209 -------
@@ -188,12 +213,14 @@ def _get_hashtable_algo(values):
188213 dtype,
189214 ndtype)
190215 """
191- values , dtype , ndtype = _ensure_data (values )
216+ values , dtype , ndtype , inferred = _ensure_data (values , infer = infer )
192217
193218 if ndtype == 'object' :
194219
195220 # its cheaper to use a String Hash Table than Object
196- if lib .infer_dtype (values ) in ['string' ]:
221+ if inferred is None :
222+ inferred = lib .infer_dtype (values )
223+ if inferred in ['string' ]:
197224 ndtype = 'string'
198225 else :
199226 ndtype = 'object'
@@ -202,24 +229,43 @@ def _get_hashtable_algo(values):
202229 return (htable , table , values , dtype , ndtype )
203230
204231
205- def _get_data_algo (values , func_map ):
232+ def _get_data_algo (values , func_map , dtype = None , infer = False ):
233+ """
234+ Parameters
235+ ----------
236+ values : array-like
237+ func_map : an inferred -> function dict
238+ dtype : dtype, optional
239+ the requested dtype
240+ infer : boolean, default False
241+ infer object dtypes
242+
243+ Returns
244+ -------
245+ (function,
246+ values,
247+ ndtype)
248+ """
206249
207250 if is_categorical_dtype (values ):
208251 values = values ._values_for_rank ()
209252
210- values , dtype , ndtype = _ensure_data (values )
253+ values , dtype , ndtype , inferred = _ensure_data (
254+ values , dtype = dtype , infer = infer )
211255 if ndtype == 'object' :
212256
213257 # its cheaper to use a String Hash Table than Object
214- if lib .infer_dtype (values ) in ['string' ]:
258+ if inferred is None :
259+ inferred = lib .infer_dtype (values )
260+ if inferred in ['string' ]:
215261 try :
216262 f = func_map ['string' ]
217263 except KeyError :
218264 pass
219265
220266 f = func_map .get (ndtype , func_map ['object' ])
221267
222- return f , values
268+ return f , values , ndtype
223269
224270
225271# --------------- #
@@ -248,7 +294,7 @@ def match(to_match, values, na_sentinel=-1):
248294 """
249295 values = com ._asarray_tuplesafe (values )
250296 htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
251- to_match , _ , _ = _ensure_data (to_match , dtype )
297+ to_match , _ , _ , _ = _ensure_data (to_match , dtype )
252298 table = htable (min (len (to_match ), 1000000 ))
253299 table .map_locations (values )
254300 result = table .lookup (to_match )
@@ -344,7 +390,7 @@ def unique(values):
344390 return values .unique ()
345391
346392 original = values
347- htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
393+ htable , _ , values , dtype , ndtype = _get_hashtable_algo (values , infer = False )
348394
349395 table = htable (len (values ))
350396 uniques = table .unique (values )
@@ -389,8 +435,8 @@ def isin(comps, values):
389435 if not isinstance (values , (ABCIndex , ABCSeries , np .ndarray )):
390436 values = np .array (list (values ), dtype = 'object' )
391437
392- comps , dtype , _ = _ensure_data (comps )
393- values , _ , _ = _ensure_data (values , dtype = dtype )
438+ comps , dtype , _ , _ = _ensure_data (comps )
439+ values , _ , _ , _ = _ensure_data (values , dtype = dtype )
394440
395441 # GH11232
396442 # work-around for numpy < 1.8 and comparisions on py3
@@ -499,7 +545,7 @@ def sort_mixed(values):
499545
500546 if sorter is None :
501547 # mixed types
502- (hash_klass , _ ), values = _get_data_algo (values , _hashtables )
548+ (hash_klass , _ ), values , _ = _get_data_algo (values , _hashtables )
503549 t = hash_klass (len (values ))
504550 t .map_locations (values )
505551 sorter = _ensure_platform_int (t .lookup (ordered ))
@@ -545,8 +591,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
545591
546592 values = _ensure_arraylike (values )
547593 original = values
548- values , dtype , _ = _ensure_data (values )
549- (hash_klass , vec_klass ), values = _get_data_algo (values , _hashtables )
594+ values , dtype , _ , _ = _ensure_data (values )
595+ (hash_klass , vec_klass ), values , _ = _get_data_algo (values , _hashtables )
550596
551597 table = hash_klass (size_hint or len (values ))
552598 uniques = vec_klass ()
@@ -660,7 +706,7 @@ def _value_counts_arraylike(values, dropna):
660706 """
661707 values = _ensure_arraylike (values )
662708 original = values
663- values , dtype , ndtype = _ensure_data (values )
709+ values , dtype , ndtype , inferred = _ensure_data (values )
664710
665711 if needs_i8_conversion (dtype ):
666712 # i8
@@ -711,7 +757,7 @@ def duplicated(values, keep='first'):
711757 duplicated : ndarray
712758 """
713759
714- values , dtype , ndtype = _ensure_data (values )
760+ values , dtype , ndtype , inferred = _ensure_data (values )
715761 f = getattr (htable , "duplicated_{dtype}" .format (dtype = ndtype ))
716762 return f (values , keep = keep )
717763
@@ -741,7 +787,7 @@ def mode(values):
741787 return Series (values .values .mode (), name = values .name )
742788 return values .mode ()
743789
744- values , dtype , ndtype = _ensure_data (values )
790+ values , dtype , ndtype , inferred = _ensure_data (values )
745791
746792 # TODO: this should support float64
747793 if ndtype not in ['int64' , 'uint64' , 'object' ]:
@@ -785,11 +831,11 @@ def rank(values, axis=0, method='average', na_option='keep',
785831 (e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1).
786832 """
787833 if values .ndim == 1 :
788- f , values = _get_data_algo (values , _rank1d_functions )
834+ f , values , _ = _get_data_algo (values , _rank1d_functions )
789835 ranks = f (values , ties_method = method , ascending = ascending ,
790836 na_option = na_option , pct = pct )
791837 elif values .ndim == 2 :
792- f , values = _get_data_algo (values , _rank2d_functions )
838+ f , values , _ = _get_data_algo (values , _rank2d_functions )
793839 ranks = f (values , axis = axis , ties_method = method ,
794840 ascending = ascending , na_option = na_option , pct = pct )
795841 else :
@@ -1049,7 +1095,7 @@ def compute(self, method):
10491095 return dropped [slc ].sort_values (ascending = ascending ).head (n )
10501096
10511097 # fast method
1052- arr , _ , _ = _ensure_data (dropped .values )
1098+ arr , _ , _ , _ = _ensure_data (dropped .values )
10531099 if method == 'nlargest' :
10541100 arr = - arr
10551101
0 commit comments