@@ -393,32 +393,29 @@ def maybe_promote(dtype, fill_value=np.nan):
393393
394394 elif is_float (fill_value ):
395395 if issubclass (dtype .type , np .bool_ ):
396- dtype = np .object_
396+ dtype = np .dtype (np .object_ )
397+
397398 elif issubclass (dtype .type , np .integer ):
398399 dtype = np .dtype (np .float64 )
399- if not isna (fill_value ):
400- fill_value = dtype .type (fill_value )
401400
402401 elif dtype .kind == "f" :
403- if not np .can_cast (fill_value , dtype ):
404- # e.g. dtype is float32, need float64
405- dtype = np .min_scalar_type (fill_value )
402+ mst = np .min_scalar_type (fill_value )
403+ if mst > dtype :
404+ # e.g. mst is np.float64 and dtype is np.float32
405+ dtype = mst
406406
407407 elif dtype .kind == "c" :
408408 mst = np .min_scalar_type (fill_value )
409409 dtype = np .promote_types (dtype , mst )
410410
411- if dtype .kind == "c" and not np .isnan (fill_value ):
412- fill_value = dtype .type (fill_value )
413-
414411 elif is_bool (fill_value ):
415412 if not issubclass (dtype .type , np .bool_ ):
416- dtype = np .object_
417- else :
418- fill_value = np .bool_ (fill_value )
413+ dtype = np .dtype (np .object_ )
414+
419415 elif is_integer (fill_value ):
420416 if issubclass (dtype .type , np .bool_ ):
421417 dtype = np .dtype (np .object_ )
418+
422419 elif issubclass (dtype .type , np .integer ):
423420 if not np .can_cast (fill_value , dtype ):
424421 # upcast to prevent overflow
@@ -428,35 +425,20 @@ def maybe_promote(dtype, fill_value=np.nan):
428425 # Case where we disagree with numpy
429426 dtype = np .dtype (np .object_ )
430427
431- fill_value = dtype .type (fill_value )
432-
433- elif issubclass (dtype .type , np .floating ):
434- # check if we can cast
435- if _check_lossless_cast (fill_value , dtype ):
436- fill_value = dtype .type (fill_value )
437-
438- if dtype .kind in ["c" , "f" ]:
439- # e.g. if dtype is complex128 and fill_value is 1, we
440- # want np.complex128(1)
441- fill_value = dtype .type (fill_value )
442-
443428 elif is_complex (fill_value ):
444429 if issubclass (dtype .type , np .bool_ ):
445430 dtype = np .dtype (np .object_ )
431+
446432 elif issubclass (dtype .type , (np .integer , np .floating )):
447433 mst = np .min_scalar_type (fill_value )
448434 dtype = np .promote_types (dtype , mst )
449435
450436 elif dtype .kind == "c" :
451437 mst = np .min_scalar_type (fill_value )
452- if mst > dtype and mst . kind == "c" :
438+ if mst > dtype :
453439 # e.g. mst is np.complex128 and dtype is np.complex64
454440 dtype = mst
455441
456- if dtype .kind == "c" :
457- # make sure we have a np.complex and not python complex
458- fill_value = dtype .type (fill_value )
459-
460442 elif fill_value is None :
461443 if is_float_dtype (dtype ) or is_complex_dtype (dtype ):
462444 fill_value = np .nan
@@ -466,37 +448,48 @@ def maybe_promote(dtype, fill_value=np.nan):
466448 elif is_datetime_or_timedelta_dtype (dtype ):
467449 fill_value = dtype .type ("NaT" , "ns" )
468450 else :
469- dtype = np .object_
451+ dtype = np .dtype ( np . object_ )
470452 fill_value = np .nan
471453 else :
472- dtype = np .object_
454+ dtype = np .dtype ( np . object_ )
473455
474456 # in case we have a string that looked like a number
475457 if is_extension_array_dtype (dtype ):
476458 pass
477459 elif issubclass (np .dtype (dtype ).type , (bytes , str )):
478- dtype = np .object_
460+ dtype = np .dtype ( np . object_ )
479461
462+ fill_value = _ensure_dtype_type (fill_value , dtype )
480463 return dtype , fill_value
481464
482465
483- def _check_lossless_cast (value , dtype : np . dtype ) -> bool :
466+ def _ensure_dtype_type (value , dtype ) :
484467 """
485- Check if we can cast the given value to the given dtype _losslesly_.
468+ Ensure that the given value is an instance of the given dtype.
469+
470+ e.g. if out dtype is np.complex64, we should have an instance of that
471+ as opposed to a python complex object.
486472
487473 Parameters
488474 ----------
489475 value : object
490- dtype : np.dtype
476+ dtype : np.dtype or ExtensionDtype
491477
492478 Returns
493479 -------
494- bool
480+ object
495481 """
496- casted = dtype .type (value )
497- if casted == value :
498- return True
499- return False
482+
483+ # Start with exceptions in which we do _not_ cast to numpy types
484+ if is_extension_array_dtype (dtype ):
485+ return value
486+ elif dtype == np .object_ :
487+ return value
488+ elif isna (value ):
489+ # e.g. keep np.nan rather than try to cast to np.float32(np.nan)
490+ return value
491+
492+ return dtype .type (value )
500493
501494
502495def infer_dtype_from (val , pandas_dtype = False ):
0 commit comments