11""" define the IntervalIndex """
22from __future__ import annotations
33
4- from functools import wraps
54from operator import (
65 le ,
76 lt ,
6362)
6463from pandas .core .dtypes .dtypes import IntervalDtype
6564
66- from pandas .core .algorithms import (
67- take_nd ,
68- unique ,
69- )
65+ from pandas .core .algorithms import take_nd
7066from pandas .core .arrays .interval import (
7167 IntervalArray ,
7268 _interval_shared_docs ,
9389 TimedeltaIndex ,
9490 timedelta_range ,
9591)
96- from pandas .core .ops import get_op_result_name
9792
9893if TYPE_CHECKING :
9994 from pandas import CategoricalIndex
@@ -151,59 +146,6 @@ def _new_IntervalIndex(cls, d):
151146 return cls .from_arrays (** d )
152147
153148
154- def setop_check (method ):
155- """
156- This is called to decorate the set operations of IntervalIndex
157- to perform the type check in advance.
158- """
159- op_name = method .__name__
160-
161- @wraps (method )
162- def wrapped (self , other , sort = False ):
163- self ._validate_sort_keyword (sort )
164- self ._assert_can_do_setop (other )
165- other , result_name = self ._convert_can_do_setop (other )
166-
167- if op_name == "difference" :
168- if not isinstance (other , IntervalIndex ):
169- result = getattr (self .astype (object ), op_name )(other , sort = sort )
170- return result .astype (self .dtype )
171-
172- elif not self ._should_compare (other ):
173- # GH#19016: ensure set op will not return a prohibited dtype
174- result = getattr (self .astype (object ), op_name )(other , sort = sort )
175- return result .astype (self .dtype )
176-
177- return method (self , other , sort )
178-
179- return wrapped
180-
181-
182- def _setop (op_name : str ):
183- """
184- Implement set operation.
185- """
186-
187- def func (self , other , sort = None ):
188- # At this point we are assured
189- # isinstance(other, IntervalIndex)
190- # other.closed == self.closed
191-
192- result = getattr (self ._multiindex , op_name )(other ._multiindex , sort = sort )
193- result_name = get_op_result_name (self , other )
194-
195- # GH 19101: ensure empty results have correct dtype
196- if result .empty :
197- result = result ._values .astype (self .dtype .subtype )
198- else :
199- result = result ._values
200-
201- return type (self ).from_tuples (result , closed = self .closed , name = result_name )
202-
203- func .__name__ = op_name
204- return setop_check (func )
205-
206-
207149@Appender (
208150 _interval_shared_docs ["class" ]
209151 % {
@@ -859,11 +801,11 @@ def _intersection(self, other, sort):
859801 """
860802 # For IntervalIndex we also know other.closed == self.closed
861803 if self .left .is_unique and self .right .is_unique :
862- taken = self . _intersection_unique (other )
804+ return super (). _intersection (other , sort = sort )
863805 elif other .left .is_unique and other .right .is_unique and self .isna ().sum () <= 1 :
864806 # Swap other/self if other is unique and self does not have
865807 # multiple NaNs
866- taken = other . _intersection_unique ( self )
808+ return super (). _intersection ( other , sort = sort )
867809 else :
868810 # duplicates
869811 taken = self ._intersection_non_unique (other )
@@ -873,29 +815,6 @@ def _intersection(self, other, sort):
873815
874816 return taken
875817
876- def _intersection_unique (self , other : IntervalIndex ) -> IntervalIndex :
877- """
878- Used when the IntervalIndex does not have any common endpoint,
879- no matter left or right.
880- Return the intersection with another IntervalIndex.
881-
882- Parameters
883- ----------
884- other : IntervalIndex
885-
886- Returns
887- -------
888- IntervalIndex
889- """
890- lindexer = self .left .get_indexer (other .left )
891- rindexer = self .right .get_indexer (other .right )
892-
893- match = (lindexer == rindexer ) & (lindexer != - 1 )
894- indexer = lindexer .take (match .nonzero ()[0 ])
895- indexer = unique (indexer )
896-
897- return self .take (indexer )
898-
899818 def _intersection_non_unique (self , other : IntervalIndex ) -> IntervalIndex :
900819 """
901820 Used when the IntervalIndex does have some common endpoints,
@@ -923,9 +842,6 @@ def _intersection_non_unique(self, other: IntervalIndex) -> IntervalIndex:
923842
924843 return self [mask ]
925844
926- _union = _setop ("union" )
927- _difference = _setop ("difference" )
928-
929845 # --------------------------------------------------------------------
930846
931847 @property
0 commit comments