2424from pandas .core .dtypes .common import (
2525 ensure_int64 ,
2626 ensure_platform_int ,
27- is_extension_array_dtype ,
2827)
2928from pandas .core .dtypes .generic import (
3029 ABCMultiIndex ,
3635
3736if TYPE_CHECKING :
3837 from pandas ._typing import (
38+ ArrayLike ,
3939 AxisInt ,
4040 IndexKeyFunc ,
4141 Level ,
4545 npt ,
4646 )
4747
48- from pandas import MultiIndex
48+ from pandas import (
49+ MultiIndex ,
50+ Series ,
51+ )
4952 from pandas .core .arrays import ExtensionArray
5053 from pandas .core .indexes .base import Index
5154
@@ -79,7 +82,10 @@ def get_indexer_indexer(
7982 The indexer for the new index.
8083 """
8184
82- target = ensure_key_mapped (target , key , levels = level )
85+ # error: Incompatible types in assignment (expression has type
86+ # "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
87+ # type "Index")
88+ target = ensure_key_mapped (target , key , levels = level ) # type:ignore[assignment]
8389 target = target ._sort_levels_monotonic ()
8490
8591 if level is not None :
@@ -304,7 +310,7 @@ def indexer_from_factorized(
304310
305311
306312def lexsort_indexer (
307- keys ,
313+ keys : list [ ArrayLike ] | list [ Series ] ,
308314 orders = None ,
309315 na_position : str = "last" ,
310316 key : Callable | None = None ,
@@ -315,8 +321,9 @@ def lexsort_indexer(
315321
316322 Parameters
317323 ----------
318- keys : sequence of arrays
324+ keys : list[ArrayLike] | list[Series]
319325 Sequence of ndarrays to be sorted by the indexer
326+ list[Series] is only if key is not None.
320327 orders : bool or list of booleans, optional
321328 Determines the sorting order for each element in keys. If a list,
322329 it must be the same length as keys. This determines whether the
@@ -343,7 +350,10 @@ def lexsort_indexer(
343350 elif orders is None :
344351 orders = [True ] * len (keys )
345352
346- keys = [ensure_key_mapped (k , key ) for k in keys ]
353+ # error: Incompatible types in assignment (expression has type
354+ # "List[Union[ExtensionArray, ndarray[Any, Any], Index, Series]]", variable
355+ # has type "Union[List[Union[ExtensionArray, ndarray[Any, Any]]], List[Series]]")
356+ keys = [ensure_key_mapped (k , key ) for k in keys ] # type: ignore[assignment]
347357
348358 for k , order in zip (keys , orders ):
349359 if na_position not in ["last" , "first" ]:
@@ -354,7 +364,9 @@ def lexsort_indexer(
354364 codes = k .copy ()
355365 n = len (codes )
356366 mask_n = n
357- if mask .any ():
367+ # error: Item "ExtensionArray" of "Union[Any, ExtensionArray,
368+ # ndarray[Any, Any]]" has no attribute "any"
369+ if mask .any (): # type: ignore[union-attr]
358370 n -= 1
359371
360372 else :
@@ -369,14 +381,40 @@ def lexsort_indexer(
369381
370382 if order : # ascending
371383 if na_position == "last" :
372- codes = np .where (mask , n , codes )
384+ # error: Argument 1 to "where" has incompatible type "Union[Any,
385+ # ExtensionArray, ndarray[Any, Any]]"; expected
386+ # "Union[_SupportsArray[dtype[Any]],
387+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
388+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
389+ # complex, str, bytes]]]"
390+ codes = np .where (mask , n , codes ) # type: ignore[arg-type]
373391 elif na_position == "first" :
374- codes += 1
392+ # error: Incompatible types in assignment (expression has type
393+ # "Union[Any, int, ndarray[Any, dtype[signedinteger[Any]]]]",
394+ # variable has type "Union[Series, ExtensionArray, ndarray[Any, Any]]")
395+ # error: Unsupported operand types for + ("ExtensionArray" and "int")
396+ codes += 1 # type: ignore[operator,assignment]
375397 else : # not order means descending
376398 if na_position == "last" :
377- codes = np .where (mask , n , n - codes - 1 )
399+ # error: Unsupported operand types for - ("int" and "ExtensionArray")
400+ # error: Argument 1 to "where" has incompatible type "Union[Any,
401+ # ExtensionArray, ndarray[Any, Any]]"; expected
402+ # "Union[_SupportsArray[dtype[Any]],
403+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
404+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
405+ # complex, str, bytes]]]"
406+ codes = np .where (
407+ mask , n , n - codes - 1 # type: ignore[operator,arg-type]
408+ )
378409 elif na_position == "first" :
379- codes = np .where (mask , 0 , n - codes )
410+ # error: Unsupported operand types for - ("int" and "ExtensionArray")
411+ # error: Argument 1 to "where" has incompatible type "Union[Any,
412+ # ExtensionArray, ndarray[Any, Any]]"; expected
413+ # "Union[_SupportsArray[dtype[Any]],
414+ # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float,
415+ # complex, str, bytes, _NestedSequence[Union[bool, int, float,
416+ # complex, str, bytes]]]"
417+ codes = np .where (mask , 0 , n - codes ) # type: ignore[operator,arg-type]
380418
381419 shape .append (mask_n )
382420 labels .append (codes )
@@ -385,7 +423,7 @@ def lexsort_indexer(
385423
386424
387425def nargsort (
388- items ,
426+ items : ArrayLike | Index | Series ,
389427 kind : str = "quicksort" ,
390428 ascending : bool = True ,
391429 na_position : str = "last" ,
@@ -401,6 +439,7 @@ def nargsort(
401439
402440 Parameters
403441 ----------
442+ items : np.ndarray, ExtensionArray, Index, or Series
404443 kind : str, default 'quicksort'
405444 ascending : bool, default True
406445 na_position : {'first', 'last'}, default 'last'
@@ -414,6 +453,7 @@ def nargsort(
414453 """
415454
416455 if key is not None :
456+ # see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key
417457 items = ensure_key_mapped (items , key )
418458 return nargsort (
419459 items ,
@@ -425,16 +465,27 @@ def nargsort(
425465 )
426466
427467 if isinstance (items , ABCRangeIndex ):
428- return items .argsort (ascending = ascending ) # TODO: test coverage with key?
468+ return items .argsort (ascending = ascending )
429469 elif not isinstance (items , ABCMultiIndex ):
430470 items = extract_array (items )
471+ else :
472+ raise TypeError (
473+ "nargsort does not support MultiIndex. Use index.sort_values instead."
474+ )
475+
431476 if mask is None :
432- mask = np .asarray (isna (items )) # TODO: does this exclude MultiIndex too?
477+ mask = np .asarray (isna (items ))
433478
434- if is_extension_array_dtype (items ):
435- return items .argsort (ascending = ascending , kind = kind , na_position = na_position )
436- else :
437- items = np .asanyarray (items )
479+ if not isinstance (items , np .ndarray ):
480+ # i.e. ExtensionArray
481+ return items .argsort (
482+ ascending = ascending ,
483+ # error: Argument "kind" to "argsort" of "ExtensionArray" has
484+ # incompatible type "str"; expected "Literal['quicksort',
485+ # 'mergesort', 'heapsort', 'stable']"
486+ kind = kind , # type: ignore[arg-type]
487+ na_position = na_position ,
488+ )
438489
439490 idx = np .arange (len (items ))
440491 non_nans = items [~ mask ]
@@ -551,7 +602,9 @@ def _ensure_key_mapped_multiindex(
551602 return type (index ).from_arrays (mapped )
552603
553604
554- def ensure_key_mapped (values , key : Callable | None , levels = None ):
605+ def ensure_key_mapped (
606+ values : ArrayLike | Index | Series , key : Callable | None , levels = None
607+ ) -> ArrayLike | Index | Series :
555608 """
556609 Applies a callable key function to the values function and checks
557610 that the resulting value has the same shape. Can be called on Index
@@ -584,8 +637,10 @@ def ensure_key_mapped(values, key: Callable | None, levels=None):
584637 ): # convert to a new Index subclass, not necessarily the same
585638 result = Index (result )
586639 else :
640+ # try to revert to original type otherwise
587641 type_of_values = type (values )
588- result = type_of_values (result ) # try to revert to original type otherwise
642+ # error: Too many arguments for "ExtensionArray"
643+ result = type_of_values (result ) # type: ignore[call-arg]
589644 except TypeError :
590645 raise TypeError (
591646 f"User-provided `key` function returned an invalid type { type (result )} \
0 commit comments