|
10 | 10 | import collections |
11 | 11 | import functools |
12 | 12 | from typing import ( |
| 13 | + Callable, |
13 | 14 | Generic, |
14 | 15 | Hashable, |
15 | 16 | Iterator, |
|
29 | 30 | from pandas._typing import ( |
30 | 31 | ArrayLike, |
31 | 32 | DtypeObj, |
32 | | - F, |
33 | 33 | FrameOrSeries, |
34 | 34 | Shape, |
35 | 35 | npt, |
@@ -700,7 +700,7 @@ def get_iterator( |
700 | 700 | yield key, group.__finalize__(data, method="groupby") |
701 | 701 |
|
702 | 702 | @final |
703 | | - def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> DataSplitter: |
| 703 | + def _get_splitter(self, data: NDFrame, axis: int = 0) -> DataSplitter: |
704 | 704 | """ |
705 | 705 | Returns |
706 | 706 | ------- |
@@ -732,7 +732,9 @@ def group_keys_seq(self): |
732 | 732 | return get_flattened_list(ids, ngroups, self.levels, self.codes) |
733 | 733 |
|
734 | 734 | @final |
735 | | - def apply(self, f: F, data: FrameOrSeries, axis: int = 0) -> tuple[list, bool]: |
| 735 | + def apply( |
| 736 | + self, f: Callable, data: DataFrame | Series, axis: int = 0 |
| 737 | + ) -> tuple[list, bool]: |
736 | 738 | mutated = self.mutated |
737 | 739 | splitter = self._get_splitter(data, axis=axis) |
738 | 740 | group_keys = self.group_keys_seq |
@@ -918,7 +920,7 @@ def _cython_operation( |
918 | 920 |
|
919 | 921 | @final |
920 | 922 | def agg_series( |
921 | | - self, obj: Series, func: F, preserve_dtype: bool = False |
| 923 | + self, obj: Series, func: Callable, preserve_dtype: bool = False |
922 | 924 | ) -> ArrayLike: |
923 | 925 | """ |
924 | 926 | Parameters |
@@ -960,7 +962,7 @@ def agg_series( |
960 | 962 |
|
961 | 963 | @final |
962 | 964 | def _aggregate_series_pure_python( |
963 | | - self, obj: Series, func: F |
| 965 | + self, obj: Series, func: Callable |
964 | 966 | ) -> npt.NDArray[np.object_]: |
965 | 967 | ids, _, ngroups = self.group_info |
966 | 968 |
|
@@ -1061,7 +1063,7 @@ def _get_grouper(self): |
1061 | 1063 | """ |
1062 | 1064 | return self |
1063 | 1065 |
|
1064 | | - def get_iterator(self, data: FrameOrSeries, axis: int = 0): |
| 1066 | + def get_iterator(self, data: NDFrame, axis: int = 0): |
1065 | 1067 | """ |
1066 | 1068 | Groupby iterator |
1067 | 1069 |
|
@@ -1142,7 +1144,7 @@ def groupings(self) -> list[grouper.Grouping]: |
1142 | 1144 | ping = grouper.Grouping(lev, lev, in_axis=False, level=None) |
1143 | 1145 | return [ping] |
1144 | 1146 |
|
1145 | | - def _aggregate_series_fast(self, obj: Series, func: F) -> np.ndarray: |
| 1147 | + def _aggregate_series_fast(self, obj: Series, func: Callable) -> np.ndarray: |
1146 | 1148 | # -> np.ndarray[object] |
1147 | 1149 | raise NotImplementedError( |
1148 | 1150 | "This should not be reached; use _aggregate_series_pure_python" |
@@ -1241,7 +1243,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame: |
1241 | 1243 |
|
1242 | 1244 |
|
1243 | 1245 | def get_splitter( |
1244 | | - data: FrameOrSeries, labels: np.ndarray, ngroups: int, axis: int = 0 |
| 1246 | + data: NDFrame, labels: np.ndarray, ngroups: int, axis: int = 0 |
1245 | 1247 | ) -> DataSplitter: |
1246 | 1248 | if isinstance(data, Series): |
1247 | 1249 | klass: type[DataSplitter] = SeriesSplitter |
|
0 commit comments