22Functions for arithmetic and comparison operations on NumPy arrays and
33ExtensionArrays.
44"""
5+ from datetime import timedelta
56from functools import partial
67import operator
7- from typing import Any , Optional
8+ from typing import Any , Optional , Tuple
89
910import numpy as np
1011
2425 is_object_dtype ,
2526 is_scalar ,
2627)
27- from pandas .core .dtypes .generic import (
28- ABCDatetimeArray ,
29- ABCExtensionArray ,
30- ABCIndex ,
31- ABCSeries ,
32- ABCTimedeltaArray ,
33- )
28+ from pandas .core .dtypes .generic import ABCExtensionArray , ABCIndex , ABCSeries
3429from pandas .core .dtypes .missing import isna , notna
3530
3631from pandas .core .ops import missing
37- from pandas .core .ops .dispatch import dispatch_to_extension_op , should_extension_dispatch
32+ from pandas .core .ops .dispatch import should_extension_dispatch
3833from pandas .core .ops .invalid import invalid_comparison
3934from pandas .core .ops .roperator import rpow
4035
@@ -199,23 +194,15 @@ def arithmetic_op(left: ArrayLike, right: Any, op, str_rep: str):
199194 ndarrray or ExtensionArray
200195 Or a 2-tuple of these in the case of divmod or rdivmod.
201196 """
202- from pandas .core .ops import maybe_upcast_for_op
203197
204198 # NB: We assume that extract_array has already been called
205199 # on `left` and `right`.
206- lvalues = left
207- rvalues = right
200+ lvalues = maybe_upcast_datetimelike_array ( left )
201+ rvalues = maybe_upcast_for_op ( right , lvalues . shape )
208202
209- rvalues = maybe_upcast_for_op (rvalues , lvalues .shape )
210-
211- if should_extension_dispatch (left , rvalues ) or isinstance (
212- rvalues , (ABCTimedeltaArray , ABCDatetimeArray , Timestamp , Timedelta )
213- ):
214- # TimedeltaArray, DatetimeArray, and Timestamp are included here
215- # because they have `freq` attribute which is handled correctly
216- # by dispatch_to_extension_op.
203+ if should_extension_dispatch (lvalues , rvalues ) or isinstance (rvalues , Timedelta ):
217204 # Timedelta is included because numexpr will fail on it, see GH#31457
218- res_values = dispatch_to_extension_op ( op , lvalues , rvalues )
205+ res_values = op ( lvalues , rvalues )
219206
220207 else :
221208 with np .errstate (all = "ignore" ):
@@ -287,7 +274,7 @@ def comparison_op(
287274 ndarray or ExtensionArray
288275 """
289276 # NB: We assume extract_array has already been called on left and right
290- lvalues = left
277+ lvalues = maybe_upcast_datetimelike_array ( left )
291278 rvalues = right
292279
293280 rvalues = lib .item_from_zerodim (rvalues )
@@ -307,7 +294,8 @@ def comparison_op(
307294 )
308295
309296 if should_extension_dispatch (lvalues , rvalues ):
310- res_values = dispatch_to_extension_op (op , lvalues , rvalues )
297+ # Call the method on lvalues
298+ res_values = op (lvalues , rvalues )
311299
312300 elif is_scalar (rvalues ) and isna (rvalues ):
313301 # numpy does not like comparisons vs None
@@ -406,11 +394,12 @@ def fill_bool(x, left=None):
406394 right = construct_1d_object_array_from_listlike (right )
407395
408396 # NB: We assume extract_array has already been called on left and right
409- lvalues = left
397+ lvalues = maybe_upcast_datetimelike_array ( left )
410398 rvalues = right
411399
412400 if should_extension_dispatch (lvalues , rvalues ):
413- res_values = dispatch_to_extension_op (op , lvalues , rvalues )
401+ # Call the method on lvalues
402+ res_values = op (lvalues , rvalues )
414403
415404 else :
416405 if isinstance (rvalues , np .ndarray ):
@@ -453,3 +442,87 @@ def get_array_op(op, str_rep: Optional[str] = None):
453442 return partial (logical_op , op = op )
454443 else :
455444 return partial (arithmetic_op , op = op , str_rep = str_rep )
445+
446+
447+ def maybe_upcast_datetimelike_array (obj : ArrayLike ) -> ArrayLike :
448+ """
449+ If we have an ndarray that is either datetime64 or timedelta64, wrap in EA.
450+
451+ Parameters
452+ ----------
453+ obj : ndarray or ExtensionArray
454+
455+ Returns
456+ -------
457+ ndarray or ExtensionArray
458+ """
459+ if isinstance (obj , np .ndarray ):
460+ if obj .dtype .kind == "m" :
461+ from pandas .core .arrays import TimedeltaArray
462+
463+ return TimedeltaArray ._from_sequence (obj )
464+ if obj .dtype .kind == "M" :
465+ from pandas .core .arrays import DatetimeArray
466+
467+ return DatetimeArray ._from_sequence (obj )
468+
469+ return obj
470+
471+
472+ def maybe_upcast_for_op (obj , shape : Tuple [int , ...]):
473+ """
474+ Cast non-pandas objects to pandas types to unify behavior of arithmetic
475+ and comparison operations.
476+
477+ Parameters
478+ ----------
479+ obj: object
480+ shape : tuple[int]
481+
482+ Returns
483+ -------
484+ out : object
485+
486+ Notes
487+ -----
488+ Be careful to call this *after* determining the `name` attribute to be
489+ attached to the result of the arithmetic operation.
490+ """
491+ from pandas .core .arrays import DatetimeArray , TimedeltaArray
492+
493+ if type (obj ) is timedelta :
494+ # GH#22390 cast up to Timedelta to rely on Timedelta
495+ # implementation; otherwise operation against numeric-dtype
496+ # raises TypeError
497+ return Timedelta (obj )
498+ elif isinstance (obj , np .datetime64 ):
499+ # GH#28080 numpy casts integer-dtype to datetime64 when doing
500+ # array[int] + datetime64, which we do not allow
501+ if isna (obj ):
502+ # Avoid possible ambiguities with pd.NaT
503+ obj = obj .astype ("datetime64[ns]" )
504+ right = np .broadcast_to (obj , shape )
505+ return DatetimeArray (right )
506+
507+ return Timestamp (obj )
508+
509+ elif isinstance (obj , np .timedelta64 ):
510+ if isna (obj ):
511+ # wrapping timedelta64("NaT") in Timedelta returns NaT,
512+ # which would incorrectly be treated as a datetime-NaT, so
513+ # we broadcast and wrap in a TimedeltaArray
514+ obj = obj .astype ("timedelta64[ns]" )
515+ right = np .broadcast_to (obj , shape )
516+ return TimedeltaArray (right )
517+
518+ # In particular non-nanosecond timedelta64 needs to be cast to
519+ # nanoseconds, or else we get undesired behavior like
520+ # np.timedelta64(3, 'D') / 2 == np.timedelta64(1, 'D')
521+ return Timedelta (obj )
522+
523+ elif isinstance (obj , np .ndarray ) and obj .dtype .kind == "m" :
524+ # GH#22390 Unfortunately we need to special-case right-hand
525+ # timedelta64 dtypes because numpy casts integer dtypes to
526+ # timedelta64 when operating with timedelta64
527+ return TimedeltaArray ._from_sequence (obj )
528+ return obj
0 commit comments