3636)
3737from pandas .core .dtypes .missing import _maybe_fill , isna
3838
39+ from pandas ._typing import FrameOrSeries
3940import pandas .core .algorithms as algorithms
4041from pandas .core .base import SelectionMixin
4142import pandas .core .common as com
@@ -89,12 +90,16 @@ def __init__(
8990
9091 self ._filter_empty_groups = self .compressed = len (groupings ) != 1
9192 self .axis = axis
92- self .groupings = groupings # type: Sequence [grouper.Grouping]
93+ self ._groupings = list ( groupings ) # type: List [grouper.Grouping]
9394 self .sort = sort
9495 self .group_keys = group_keys
9596 self .mutated = mutated
9697 self .indexer = indexer
9798
99+ @property
100+ def groupings (self ) -> List ["grouper.Grouping" ]:
101+ return self ._groupings
102+
98103 @property
99104 def shape (self ):
100105 return tuple (ping .ngroups for ping in self .groupings )
@@ -106,7 +111,7 @@ def __iter__(self):
106111 def nkeys (self ) -> int :
107112 return len (self .groupings )
108113
109- def get_iterator (self , data , axis = 0 ):
114+ def get_iterator (self , data : FrameOrSeries , axis : int = 0 ):
110115 """
111116 Groupby iterator
112117
@@ -120,7 +125,7 @@ def get_iterator(self, data, axis=0):
120125 for key , (i , group ) in zip (keys , splitter ):
121126 yield key , group
122127
123- def _get_splitter (self , data , axis = 0 ) :
128+ def _get_splitter (self , data : FrameOrSeries , axis : int = 0 ) -> "DataSplitter" :
124129 comp_ids , _ , ngroups = self .group_info
125130 return get_splitter (data , comp_ids , ngroups , axis = axis )
126131
@@ -142,13 +147,13 @@ def _get_group_keys(self):
142147 # provide "flattened" iterator for multi-group setting
143148 return get_flattened_iterator (comp_ids , ngroups , self .levels , self .codes )
144149
145- def apply (self , f , data , axis : int = 0 ):
150+ def apply (self , f , data : FrameOrSeries , axis : int = 0 ):
146151 mutated = self .mutated
147152 splitter = self ._get_splitter (data , axis = axis )
148153 group_keys = self ._get_group_keys ()
149154 result_values = None
150155
151- sdata = splitter ._get_sorted_data ()
156+ sdata = splitter ._get_sorted_data () # type: FrameOrSeries
152157 if sdata .ndim == 2 and np .any (sdata .dtypes .apply (is_extension_array_dtype )):
153158 # calling splitter.fast_apply will raise TypeError via apply_frame_axis0
154159 # if we pass EA instead of ndarray
@@ -157,7 +162,7 @@ def apply(self, f, data, axis: int = 0):
157162
158163 elif (
159164 com .get_callable_name (f ) not in base .plotting_methods
160- and hasattr (splitter , "fast_apply" )
165+ and isinstance (splitter , FrameSplitter )
161166 and axis == 0
162167 # with MultiIndex, apply_frame_axis0 would raise InvalidApply
163168 # TODO: can we make this check prettier?
@@ -229,8 +234,7 @@ def names(self):
229234
230235 def size (self ) -> Series :
231236 """
232- Compute group sizes
233-
237+ Compute group sizes.
234238 """
235239 ids , _ , ngroup = self .group_info
236240 ids = ensure_platform_int (ids )
@@ -292,7 +296,7 @@ def reconstructed_codes(self) -> List[np.ndarray]:
292296 return decons_obs_group_ids (comp_ids , obs_ids , self .shape , codes , xnull = True )
293297
294298 @cache_readonly
295- def result_index (self ):
299+ def result_index (self ) -> Index :
296300 if not self .compressed and len (self .groupings ) == 1 :
297301 return self .groupings [0 ].result_index .rename (self .names [0 ])
298302
@@ -629,7 +633,7 @@ def agg_series(self, obj: Series, func):
629633 raise
630634 return self ._aggregate_series_pure_python (obj , func )
631635
632- def _aggregate_series_fast (self , obj , func ):
636+ def _aggregate_series_fast (self , obj : Series , func ):
633637 # At this point we have already checked that
634638 # - obj.index is not a MultiIndex
635639 # - obj is backed by an ndarray, not ExtensionArray
@@ -648,7 +652,7 @@ def _aggregate_series_fast(self, obj, func):
648652 result , counts = grouper .get_result ()
649653 return result , counts
650654
651- def _aggregate_series_pure_python (self , obj , func ):
655+ def _aggregate_series_pure_python (self , obj : Series , func ):
652656
653657 group_index , _ , ngroups = self .group_info
654658
@@ -705,7 +709,12 @@ class BinGrouper(BaseGrouper):
705709 """
706710
707711 def __init__ (
708- self , bins , binlabels , filter_empty = False , mutated = False , indexer = None
712+ self ,
713+ bins ,
714+ binlabels ,
715+ filter_empty : bool = False ,
716+ mutated : bool = False ,
717+ indexer = None ,
709718 ):
710719 self .bins = ensure_int64 (bins )
711720 self .binlabels = ensure_index (binlabels )
@@ -739,7 +748,7 @@ def _get_grouper(self):
739748 """
740749 return self
741750
742- def get_iterator (self , data : NDFrame , axis : int = 0 ):
751+ def get_iterator (self , data : FrameOrSeries , axis : int = 0 ):
743752 """
744753 Groupby iterator
745754
@@ -811,11 +820,9 @@ def names(self):
811820 return [self .binlabels .name ]
812821
813822 @property
814- def groupings (self ):
815- from pandas .core .groupby .grouper import Grouping
816-
823+ def groupings (self ) -> "List[grouper.Grouping]" :
817824 return [
818- Grouping (lvl , lvl , in_axis = False , level = None , name = name )
825+ grouper . Grouping (lvl , lvl , in_axis = False , level = None , name = name )
819826 for lvl , name in zip (self .levels , self .names )
820827 ]
821828
@@ -856,7 +863,7 @@ def _is_indexed_like(obj, axes) -> bool:
856863
857864
858865class DataSplitter :
859- def __init__ (self , data , labels , ngroups , axis : int = 0 ):
866+ def __init__ (self , data : FrameOrSeries , labels , ngroups : int , axis : int = 0 ):
860867 self .data = data
861868 self .labels = ensure_int64 (labels )
862869 self .ngroups = ngroups
@@ -887,15 +894,15 @@ def __iter__(self):
887894 for i , (start , end ) in enumerate (zip (starts , ends )):
888895 yield i , self ._chop (sdata , slice (start , end ))
889896
890- def _get_sorted_data (self ):
897+ def _get_sorted_data (self ) -> FrameOrSeries :
891898 return self .data .take (self .sort_idx , axis = self .axis )
892899
893- def _chop (self , sdata , slice_obj : slice ):
900+ def _chop (self , sdata , slice_obj : slice ) -> NDFrame :
894901 raise AbstractMethodError (self )
895902
896903
897904class SeriesSplitter (DataSplitter ):
898- def _chop (self , sdata , slice_obj : slice ):
905+ def _chop (self , sdata : Series , slice_obj : slice ) -> Series :
899906 return sdata ._get_values (slice_obj )
900907
901908
@@ -907,14 +914,14 @@ def fast_apply(self, f, names):
907914 sdata = self ._get_sorted_data ()
908915 return libreduction .apply_frame_axis0 (sdata , f , names , starts , ends )
909916
910- def _chop (self , sdata , slice_obj : slice ):
917+ def _chop (self , sdata : DataFrame , slice_obj : slice ) -> DataFrame :
911918 if self .axis == 0 :
912919 return sdata .iloc [slice_obj ]
913920 else :
914921 return sdata ._slice (slice_obj , axis = 1 )
915922
916923
917- def get_splitter (data : NDFrame , * args , ** kwargs ):
924+ def get_splitter (data : FrameOrSeries , * args , ** kwargs ) -> DataSplitter :
918925 if isinstance (data , Series ):
919926 klass = SeriesSplitter # type: Type[DataSplitter]
920927 else :
0 commit comments