@@ -46,6 +46,7 @@ class providing the base-class of operations.
4646 ArrayLike ,
4747 IndexLabel ,
4848 NDFrameT ,
49+ PositionalIndexer ,
4950 RandomState ,
5051 Scalar ,
5152 T ,
@@ -65,6 +66,7 @@ class providing the base-class of operations.
6566 is_bool_dtype ,
6667 is_datetime64_dtype ,
6768 is_float_dtype ,
69+ is_integer ,
6870 is_integer_dtype ,
6971 is_numeric_dtype ,
7072 is_object_dtype ,
@@ -97,6 +99,7 @@ class providing the base-class of operations.
9799 numba_ ,
98100 ops ,
99101)
102+ from pandas .core .groupby .indexing import GroupByIndexingMixin
100103from pandas .core .indexes .api import (
101104 CategoricalIndex ,
102105 Index ,
@@ -555,7 +558,7 @@ def f(self):
555558]
556559
557560
558- class BaseGroupBy (PandasObject , SelectionMixin [NDFrameT ]):
561+ class BaseGroupBy (PandasObject , SelectionMixin [NDFrameT ], GroupByIndexingMixin ):
559562 _group_selection : IndexLabel | None = None
560563 _apply_allowlist : frozenset [str ] = frozenset ()
561564 _hidden_attrs = PandasObject ._hidden_attrs | {
@@ -2445,23 +2448,28 @@ def backfill(self, limit=None):
24452448 @Substitution (name = "groupby" )
24462449 @Substitution (see_also = _common_see_also )
24472450 def nth (
2448- self , n : int | list [int ], dropna : Literal ["any" , "all" , None ] = None
2451+ self ,
2452+ n : PositionalIndexer | tuple ,
2453+ dropna : Literal ["any" , "all" , None ] = None ,
24492454 ) -> NDFrameT :
24502455 """
2451- Take the nth row from each group if n is an int, or a subset of rows
2452- if n is a list of ints.
2456+ Take the nth row from each group if n is an int, otherwise a subset of rows.
24532457
24542458 If dropna, will take the nth non-null row, dropna is either
24552459 'all' or 'any'; this is equivalent to calling dropna(how=dropna)
24562460 before the groupby.
24572461
24582462 Parameters
24592463 ----------
2460- n : int or list of ints
2461- A single nth value for the row or a list of nth values.
2464+ n : int, slice or list of ints and slices
2465+ A single nth value for the row or a list of nth values or slices.
2466+
2467+ .. versionchanged:: 1.4.0
2468+ Added slice and lists containiing slices.
2469+
24622470 dropna : {'any', 'all', None}, default None
24632471 Apply the specified dropna operation before counting which row is
2464- the nth row.
2472+ the nth row. Only supported if n is an int.
24652473
24662474 Returns
24672475 -------
@@ -2496,6 +2504,12 @@ def nth(
24962504 1 2.0
24972505 2 3.0
24982506 2 5.0
2507+ >>> g.nth(slice(None, -1))
2508+ B
2509+ A
2510+ 1 NaN
2511+ 1 2.0
2512+ 2 3.0
24992513
25002514 Specifying `dropna` allows count ignoring ``NaN``
25012515
@@ -2520,33 +2534,16 @@ def nth(
25202534 1 1 2.0
25212535 4 2 5.0
25222536 """
2523- valid_containers = (set , list , tuple )
2524- if not isinstance (n , (valid_containers , int )):
2525- raise TypeError ("n needs to be an int or a list/set/tuple of ints" )
2526-
25272537 if not dropna :
2528-
2529- if isinstance (n , int ):
2530- nth_values = [n ]
2531- elif isinstance (n , valid_containers ):
2532- nth_values = list (set (n ))
2533-
2534- nth_array = np .array (nth_values , dtype = np .intp )
25352538 with self ._group_selection_context ():
2536-
2537- mask_left = np .in1d (self ._cumcount_array (), nth_array )
2538- mask_right = np .in1d (
2539- self ._cumcount_array (ascending = False ) + 1 , - nth_array
2540- )
2541- mask = mask_left | mask_right
2539+ mask = self ._make_mask_from_positional_indexer (n )
25422540
25432541 ids , _ , _ = self .grouper .group_info
25442542
25452543 # Drop NA values in grouping
25462544 mask = mask & (ids != - 1 )
25472545
25482546 out = self ._mask_selected_obj (mask )
2549-
25502547 if not self .as_index :
25512548 return out
25522549
@@ -2563,19 +2560,20 @@ def nth(
25632560 return out .sort_index (axis = self .axis ) if self .sort else out
25642561
25652562 # dropna is truthy
2566- if isinstance ( n , valid_containers ):
2567- raise ValueError ("dropna option with a list of nth values is not supported " )
2563+ if not is_integer ( n ):
2564+ raise ValueError ("dropna option only supported for an integer argument " )
25682565
25692566 if dropna not in ["any" , "all" ]:
25702567 # Note: when agg-ing picker doesn't raise this, just returns NaN
25712568 raise ValueError (
2572- "For a DataFrame groupby, dropna must be "
2569+ "For a DataFrame or Series groupby.nth , dropna must be "
25732570 "either None, 'any' or 'all', "
25742571 f"(was passed { dropna } )."
25752572 )
25762573
25772574 # old behaviour, but with all and any support for DataFrames.
25782575 # modified in GH 7559 to have better perf
2576+ n = cast (int , n )
25792577 max_len = n if n >= 0 else - 1 - n
25802578 dropped = self .obj .dropna (how = dropna , axis = self .axis )
25812579
@@ -3301,11 +3299,16 @@ def head(self, n=5):
33013299 from the original DataFrame with original index and order preserved
33023300 (``as_index`` flag is ignored).
33033301
3304- Does not work for negative values of `n`.
3302+ Parameters
3303+ ----------
3304+ n : int
3305+ If positive: number of entries to include from start of each group.
3306+ If negative: number of entries to exclude from end of each group.
33053307
33063308 Returns
33073309 -------
33083310 Series or DataFrame
3311+ Subset of original Series or DataFrame as determined by n.
33093312 %(see_also)s
33103313 Examples
33113314 --------
@@ -3317,12 +3320,11 @@ def head(self, n=5):
33173320 0 1 2
33183321 2 5 6
33193322 >>> df.groupby('A').head(-1)
3320- Empty DataFrame
3321- Columns: [A, B]
3322- Index: []
3323+ A B
3324+ 0 1 2
33233325 """
33243326 self ._reset_group_selection ()
3325- mask = self ._cumcount_array () < n
3327+ mask = self ._make_mask_from_positional_indexer ( slice ( None , n ))
33263328 return self ._mask_selected_obj (mask )
33273329
33283330 @final
@@ -3336,11 +3338,16 @@ def tail(self, n=5):
33363338 from the original DataFrame with original index and order preserved
33373339 (``as_index`` flag is ignored).
33383340
3339- Does not work for negative values of `n`.
3341+ Parameters
3342+ ----------
3343+ n : int
3344+ If positive: number of entries to include from end of each group.
3345+ If negative: number of entries to exclude from start of each group.
33403346
33413347 Returns
33423348 -------
33433349 Series or DataFrame
3350+ Subset of original Series or DataFrame as determined by n.
33443351 %(see_also)s
33453352 Examples
33463353 --------
@@ -3352,12 +3359,16 @@ def tail(self, n=5):
33523359 1 a 2
33533360 3 b 2
33543361 >>> df.groupby('A').tail(-1)
3355- Empty DataFrame
3356- Columns: [A, B]
3357- Index: []
3362+ A B
3363+ 1 a 2
3364+ 3 b 2
33583365 """
33593366 self ._reset_group_selection ()
3360- mask = self ._cumcount_array (ascending = False ) < n
3367+ if n :
3368+ mask = self ._make_mask_from_positional_indexer (slice (- n , None ))
3369+ else :
3370+ mask = self ._make_mask_from_positional_indexer ([])
3371+
33613372 return self ._mask_selected_obj (mask )
33623373
33633374 @final
0 commit comments