55"""
66import datetime
77import operator
8- from typing import Any , Callable , Tuple
8+ from typing import Any , Callable , Tuple , Union
99
1010import numpy as np
1111
3434 ABCIndexClass ,
3535 ABCSeries ,
3636 ABCSparseSeries ,
37+ ABCTimedeltaArray ,
38+ ABCTimedeltaIndex ,
3739)
3840from pandas .core .dtypes .missing import isna , notna
3941
40- import pandas as pd
4142from pandas ._typing import ArrayLike
4243from pandas .core .construction import array , extract_array
4344from pandas .core .ops .array_ops import comp_method_OBJECT_ARRAY , define_na_arithmetic_op
@@ -148,6 +149,8 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
148149 Be careful to call this *after* determining the `name` attribute to be
149150 attached to the result of the arithmetic operation.
150151 """
152+ from pandas .core .arrays import TimedeltaArray
153+
151154 if type (obj ) is datetime .timedelta :
152155 # GH#22390 cast up to Timedelta to rely on Timedelta
153156 # implementation; otherwise operation against numeric-dtype
@@ -157,12 +160,10 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
157160 if isna (obj ):
158161 # wrapping timedelta64("NaT") in Timedelta returns NaT,
159162 # which would incorrectly be treated as a datetime-NaT, so
160- # we broadcast and wrap in a Series
163+ # we broadcast and wrap in a TimedeltaArray
164+ obj = obj .astype ("timedelta64[ns]" )
161165 right = np .broadcast_to (obj , shape )
162-
163- # Note: we use Series instead of TimedeltaIndex to avoid having
164- # to worry about catching NullFrequencyError.
165- return pd .Series (right )
166+ return TimedeltaArray (right )
166167
167168 # In particular non-nanosecond timedelta64 needs to be cast to
168169 # nanoseconds, or else we get undesired behavior like
@@ -173,7 +174,7 @@ def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
173174 # GH#22390 Unfortunately we need to special-case right-hand
174175 # timedelta64 dtypes because numpy casts integer dtypes to
175176 # timedelta64 when operating with timedelta64
176- return pd . TimedeltaIndex (obj )
177+ return TimedeltaArray . _from_sequence (obj )
177178 return obj
178179
179180
@@ -520,13 +521,34 @@ def column_op(a, b):
520521 return result
521522
522523
523- def dispatch_to_extension_op (op , left , right ):
524+ def dispatch_to_extension_op (
525+ op ,
526+ left : Union [ABCExtensionArray , np .ndarray ],
527+ right : Any ,
528+ keep_null_freq : bool = False ,
529+ ):
524530 """
525531 Assume that left or right is a Series backed by an ExtensionArray,
526532 apply the operator defined by op.
533+
534+ Parameters
535+ ----------
536+ op : binary operator
537+ left : ExtensionArray or np.ndarray
538+ right : object
539+ keep_null_freq : bool, default False
540+ Whether to re-raise a NullFrequencyError unchanged, as opposed to
541+ catching and raising TypeError.
542+
543+ Returns
544+ -------
545+ ExtensionArray or np.ndarray
546+ 2-tuple of these if op is divmod or rdivmod
527547 """
548+ # NB: left and right should already be unboxed, so neither should be
549+ # a Series or Index.
528550
529- if left .dtype .kind in "mM" :
551+ if left .dtype .kind in "mM" and isinstance ( left , np . ndarray ) :
530552 # We need to cast datetime64 and timedelta64 ndarrays to
531553 # DatetimeArray/TimedeltaArray. But we avoid wrapping others in
532554 # PandasArray as that behaves poorly with e.g. IntegerArray.
@@ -535,15 +557,15 @@ def dispatch_to_extension_op(op, left, right):
535557 # The op calls will raise TypeError if the op is not defined
536558 # on the ExtensionArray
537559
538- # unbox Series and Index to arrays
539- new_left = extract_array (left , extract_numpy = True )
540- new_right = extract_array (right , extract_numpy = True )
541-
542560 try :
543- res_values = op (new_left , new_right )
561+ res_values = op (left , right )
544562 except NullFrequencyError :
545563 # DatetimeIndex and TimedeltaIndex with freq == None raise ValueError
546564 # on add/sub of integers (or int-like). We re-raise as a TypeError.
565+ if keep_null_freq :
566+ # TODO: remove keep_null_freq after Timestamp+int deprecation
567+ # GH#22535 is enforced
568+ raise
547569 raise TypeError (
548570 "incompatible type for a datetime/timedelta "
549571 "operation [{name}]" .format (name = op .__name__ )
@@ -615,25 +637,29 @@ def wrapper(left, right):
615637 if isinstance (right , ABCDataFrame ):
616638 return NotImplemented
617639
640+ keep_null_freq = isinstance (
641+ right ,
642+ (ABCDatetimeIndex , ABCDatetimeArray , ABCTimedeltaIndex , ABCTimedeltaArray ),
643+ )
644+
618645 left , right = _align_method_SERIES (left , right )
619646 res_name = get_op_result_name (left , right )
620- right = maybe_upcast_for_op (right , left .shape )
621647
622- if should_extension_dispatch (left , right ):
623- result = dispatch_to_extension_op ( op , left , right )
648+ lvalues = extract_array (left , extract_numpy = True )
649+ rvalues = extract_array ( right , extract_numpy = True )
624650
625- elif is_timedelta64_dtype (right ) or isinstance (
626- right , (ABCDatetimeArray , ABCDatetimeIndex )
627- ):
628- # We should only get here with td64 right with non-scalar values
629- # for right upcast by maybe_upcast_for_op
630- assert not isinstance (right , (np .timedelta64 , np .ndarray ))
631- result = op (left ._values , right )
651+ rvalues = maybe_upcast_for_op (rvalues , lvalues .shape )
632652
633- else :
634- lvalues = extract_array (left , extract_numpy = True )
635- rvalues = extract_array (right , extract_numpy = True )
653+ if should_extension_dispatch (lvalues , rvalues ):
654+ result = dispatch_to_extension_op (op , lvalues , rvalues , keep_null_freq )
655+
656+ elif is_timedelta64_dtype (rvalues ) or isinstance (rvalues , ABCDatetimeArray ):
657+ # We should only get here with td64 rvalues with non-scalar values
658+ # for rvalues upcast by maybe_upcast_for_op
659+ assert not isinstance (rvalues , (np .timedelta64 , np .ndarray ))
660+ result = dispatch_to_extension_op (op , lvalues , rvalues , keep_null_freq )
636661
662+ else :
637663 with np .errstate (all = "ignore" ):
638664 result = na_op (lvalues , rvalues )
639665
@@ -708,25 +734,25 @@ def wrapper(self, other, axis=None):
708734 if len (self ) != len (other ):
709735 raise ValueError ("Lengths must match to compare" )
710736
711- if should_extension_dispatch (self , other ):
712- res_values = dispatch_to_extension_op ( op , self , other )
737+ lvalues = extract_array (self , extract_numpy = True )
738+ rvalues = extract_array ( other , extract_numpy = True )
713739
714- elif is_scalar (other ) and isna (other ):
740+ if should_extension_dispatch (lvalues , rvalues ):
741+ res_values = dispatch_to_extension_op (op , lvalues , rvalues )
742+
743+ elif is_scalar (rvalues ) and isna (rvalues ):
715744 # numpy does not like comparisons vs None
716745 if op is operator .ne :
717- res_values = np .ones (len (self ), dtype = bool )
746+ res_values = np .ones (len (lvalues ), dtype = bool )
718747 else :
719- res_values = np .zeros (len (self ), dtype = bool )
748+ res_values = np .zeros (len (lvalues ), dtype = bool )
720749
721750 else :
722- lvalues = extract_array (self , extract_numpy = True )
723- rvalues = extract_array (other , extract_numpy = True )
724-
725751 with np .errstate (all = "ignore" ):
726752 res_values = na_op (lvalues , rvalues )
727753 if is_scalar (res_values ):
728754 raise TypeError (
729- "Could not compare {typ} type with Series" .format (typ = type (other ))
755+ "Could not compare {typ} type with Series" .format (typ = type (rvalues ))
730756 )
731757
732758 result = self ._constructor (res_values , index = self .index )
0 commit comments