11""" miscellaneous sorting / groupby utilities """
2- from typing import Callable , Dict , Union
2+ from typing import Callable , Optional
33
44import numpy as np
55
@@ -299,32 +299,20 @@ def nargsort(
299299 return indexer
300300
301301
302- def apply_key_name (values , key , name ):
303- if isinstance (key , dict ):
304- key = key .get (name , None )
305-
306- if key is None :
307- return values
308-
309- return key (values )
310-
311-
312- def ensure_key_mapped_dataframe (df , key , levels = None , axis = 0 ):
302+ def ensure_key_mapped_dataframe (df , key : Callable , levels = None , axis = 0 ):
313303 """
314304 Returns a new DataFrame in which key has been applied
315305 to all levels specified in levels (or all levels if levels
316- is None). Used for key sorting for DataFrames.
306+ is None) along an axis . Used for key sorting for DataFrames.
317307
318308 Parameters
319309 ----------
320310 df : DataFrame
321311 DataFrame to which to apply the key function on the
322312 specified levels.
323- key : Callable or Dict[Any, Callable]
324- If Callable, function that takes a Series and returns
325- a Series of the same shape. This key is applied to each
326- level separately. If dict, name or index of each column
327- or row is used to index the key object to get a Callable.
313+ key : Callable
314+ Function that takes a Series and returns a Series of
315+ the same shape. This key is applied to each level separately.
328316 levels : list-like, int or str, default None
329317 Level or list of levels to apply the key function to.
330318 If None, key function is applied to all levels. Other
@@ -345,9 +333,9 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
345333 )
346334
347335 if axis == 0 :
348- axis_levels = df .columns ._values
336+ axis_levels = list ( df .columns ._values ) # makes mypy happy
349337 else :
350- axis_levels = df .index ._values
338+ axis_levels = list ( df .index ._values )
351339
352340 if levels is not None :
353341 if isinstance (levels , (str , int )):
@@ -357,17 +345,16 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
357345 else :
358346 sort_levels = axis_levels
359347
360- new_levels = [
361- ensure_key_mapped (
362- Series (df ._get_label_or_level_values (name , axis = axis ), name = name ),
363- key ,
364- name = name ,
365- )
366- if name in sort_levels
367- else df ._get_label_or_level_values (name , axis = axis )
348+ values = [
349+ (name , Series (df ._get_label_or_level_values (name , axis = axis ), name = name ))
368350 for name in axis_levels
369351 ]
370352
353+ new_levels = [
354+ ensure_key_mapped (series , key ) if name in sort_levels else series
355+ for (name , series ) in values
356+ ]
357+
371358 if axis == 0 :
372359 new_df = DataFrame ._from_arrays (new_levels , df .columns , df .index )
373360 else :
@@ -419,11 +406,11 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
419406 else :
420407 sort_levels = list (range (index .nlevels )) # satisfies mypy
421408
409+ values = [(level , index ._get_level_values (level )) for level in range (index .nlevels )]
410+
422411 mapped = [
423- ensure_key_mapped (index ._get_level_values (level ), key , name = level )
424- if level in sort_levels
425- else index ._get_level_values (level )
426- for level in range (index .nlevels )
412+ ensure_key_mapped (idx , key ) if level in sort_levels else idx
413+ for (level , idx ) in values
427414 ]
428415
429416 labels = MultiIndex .from_arrays (mapped )
@@ -432,7 +419,7 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
432419
433420
434421def ensure_key_mapped (
435- values , key : Union [ Dict , Callable ], levels = None , name = None , axis = 0 ,
422+ values , key : Optional [ Callable ], levels = None , axis = 0 ,
436423):
437424
438425 """
@@ -443,19 +430,19 @@ def ensure_key_mapped(
443430 Parameters
444431 ----------
445432 values : Series, DataFrame, Index subclass, or ndarray
446- key : Union[Callable, Dict[Union[str, int], Callable]
447- key to be called on the values array. If dict, indexed by
448- name or number of columns in values. Dict supported only for
449- DataFrame or MultiIndex. Expected to return an object of the
450- same shape and compatible with the original type.
433+ key : Optional[Callable]
434+ key to be called on the values array. Expected to
435+ take a level or column from values and return an
436+ object of the same shape and compatible with the original type.
437+ For MultiIndex and DataFrame, applied to rows or columns of the
438+ values array. For Series and Index, applied directly.
451439 levels : list-like, int or str, default None
452440 For MultiIndex values, level or list of levels to apply the key
453441 function to. If None, key function is applied to all levels. Other
454442 levels are left unchanged.
455- name : str or int, default None
456- Name used to index the key function if a dictionary.
457443 axis : int, default 0
458- Axis to use for applying the key to DataFrame values level by level.
444+ Axis to use for applying the key to DataFrame values. 0 applies the
445+ key to columns, 1 to rows.
459446 """
460447 from pandas .core .indexes .api import Index
461448 from pandas import DataFrame
@@ -469,18 +456,18 @@ def ensure_key_mapped(
469456 if isinstance (values , DataFrame ): # apply the key to select levels
470457 return ensure_key_mapped_dataframe (values , key , levels = levels , axis = axis )
471458
472- result = apply_key_name (values .copy (), key , name )
459+ result = key (values .copy ())
473460
474461 if len (result ) != len (values ):
475462 raise ValueError (
476463 "User-provided `key` function must not change the shape of the array."
477464 )
478465
479466 try :
480- if isinstance (values , Index ):
467+ if isinstance (values , Index ): # allow a new Index class
481468 result = Index (result )
482469 else :
483- result = type (values )(result )
470+ result = type (values )(result ) # try to recover otherwise
484471 except TypeError :
485472 raise TypeError (
486473 "User-provided `key` function returned an invalid type {} \
0 commit comments