diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 8415187a8..c532b5927 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -20,6 +20,7 @@ from typing import ( Optional, Protocol, Sequence, + Tuple, Type, TypeVar, Union, @@ -166,6 +167,12 @@ IndexingInt = Union[ int, np.int_, np.integer, np.unsignedinteger, np.signedinteger, np.int8 ] +# NDFrameT is stricter and ensures that the same subclass of NDFrame always is +# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a +# Series is passed into a function, a Series is always returned and if a DataFrame is +# passed in, a DataFrame is always returned. +NDFrameT = TypeVar("NDFrameT", bound=NDFrame) + # Interval closed type IntervalClosedType = Literal["left", "right", "both", "neither"] @@ -197,6 +204,7 @@ XMLParsers = Literal["lxml", "etree"] # Any plain Python or numpy function Function = Union[np.ufunc, Callable[..., Any]] -GroupByObject = Union[ - Label, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index +GroupByObjectNonScalar = Union[ + Tuple, List[Label], Function, Series, np.ndarray, Mapping[Label, Any], Index ] +GroupByObject = Union[Scalar, GroupByObjectNonScalar] diff --git a/pandas-stubs/core/base.pyi b/pandas-stubs/core/base.pyi index 780c7bc16..bd1ee2681 100644 --- a/pandas-stubs/core/base.pyi +++ b/pandas-stubs/core/base.pyi @@ -2,6 +2,7 @@ from __future__ import annotations from typing import ( Callable, + Generic, List, Literal, Optional, @@ -20,6 +21,7 @@ from pandas.core.arrays import ExtensionArray from pandas.core.arrays.categorical import Categorical from pandas._typing import ( + NDFrameT, Scalar, SeriesAxisType, ) @@ -34,7 +36,7 @@ class GroupByError(Exception): ... class DataError(GroupByError): ... class SpecificationError(GroupByError): ... -class SelectionMixin: +class SelectionMixin(Generic[NDFrameT]): def ndim(self) -> int: ... def __getitem__(self, key): ... def aggregate( diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6a946e267..b701c2540 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -30,7 +30,10 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import DataFrameGroupBy +from pandas.core.groupby.generic import ( + _DataFrameGroupByNonScalar, + _DataFrameGroupByScalar, +) from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index from pandas.core.indexing import ( @@ -54,7 +57,7 @@ from pandas._typing import ( DtypeNp, FilePathOrBuffer, FilePathOrBytesBuffer, - GroupByObject, + GroupByObjectNonScalar, IgnoreRaise, IndexingInt, IndexLabel, @@ -862,9 +865,23 @@ class DataFrame(NDFrame, OpsMixin): filter_func: Optional[Callable] = ..., errors: Union[_str, Literal["raise", "ignore"]] = ..., ) -> None: ... + @overload + def groupby( + self, + by: Scalar, + axis: AxisType = ..., + level: Optional[Level] = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> _DataFrameGroupByScalar: ... + @overload def groupby( self, - by: Optional[GroupByObject] = ..., + by: Optional[GroupByObjectNonScalar] = ..., axis: AxisType = ..., level: Optional[Level] = ..., as_index: _bool = ..., @@ -873,7 +890,7 @@ class DataFrame(NDFrame, OpsMixin): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy: ... + ) -> _DataFrameGroupByNonScalar: ... def pivot( self, index=..., diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index d27536c6a..0eb7de9b4 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -3,6 +3,7 @@ from typing import ( Callable, Dict, FrozenSet, + Iterator, List, Literal, NamedTuple, @@ -32,6 +33,7 @@ from pandas._typing import ( FrameOrSeries, FuncType, Level, + Scalar, ) AggScalar = Union[str, Callable[..., Any]] @@ -46,6 +48,12 @@ def pin_whitelisted_properties( klass: Type[FrameOrSeries], whitelist: FrozenSet[str] ): ... +class _SeriesGroupByScalar(SeriesGroupBy): + def __iter__(self) -> Iterator[Tuple[Scalar, Series]]: ... + +class _SeriesGroupByNonScalar(SeriesGroupBy): + def __iter__(self) -> Iterator[Tuple[Tuple, Series]]: ... + class SeriesGroupBy(GroupBy): def any(self, skipna: bool = ...) -> Series[bool]: ... def all(self, skipna: bool = ...) -> Series[bool]: ... @@ -100,6 +108,12 @@ class SeriesGroupBy(GroupBy): self, n: Union[int, Sequence[int]], dropna: Optional[str] = ... ) -> Series[S1]: ... +class _DataFrameGroupByScalar(DataFrameGroupBy): + def __iter__(self) -> Iterator[Tuple[Scalar, DataFrame]]: ... + +class _DataFrameGroupByNonScalar(DataFrameGroupBy): + def __iter__(self) -> Iterator[Tuple[Tuple, DataFrame]]: ... + class DataFrameGroupBy(GroupBy): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index b765a6e44..9ea65f9b7 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -1,11 +1,8 @@ from typing import ( - Any, Callable, Dict, - Generator, List, Optional, - Tuple, Union, ) @@ -13,6 +10,7 @@ from pandas.core.base import PandasObject from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame from pandas.core.groupby import ops +from pandas.core.groupby.indexing import GroupByIndexingMixin from pandas.core.indexes.api import Index from pandas.core.series import Series @@ -27,7 +25,7 @@ class GroupByPlot(PandasObject): def __call__(self, *args, **kwargs): ... def __getattr__(self, name: str): ... -class _GroupBy(PandasObject): +class _GroupBy(PandasObject, GroupByIndexingMixin): level = ... as_index = ... keys = ... @@ -67,7 +65,6 @@ class _GroupBy(PandasObject): def pipe(self, func: Callable, *args, **kwargs): ... plot = ... def get_group(self, name, obj: Optional[DataFrame] = ...) -> DataFrame: ... - def __iter__(self) -> Generator[Tuple[str, Any], None, None]: ... def apply(self, func: Callable, *args, **kwargs) -> FrameOrSeriesUnion: ... class GroupBy(_GroupBy): diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 81656dca1..a1a4ae9d4 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -34,7 +34,10 @@ from pandas import ( ) from pandas.core.arrays.base import ExtensionArray from pandas.core.arrays.categorical import CategoricalAccessor -from pandas.core.groupby.generic import SeriesGroupBy +from pandas.core.groupby.generic import ( + _SeriesGroupByNonScalar, + _SeriesGroupByScalar, +) from pandas.core.indexes.accessors import ( CombinedDatetimelikeProperties, DatetimeProperties, @@ -65,6 +68,7 @@ from pandas._typing import ( Dtype, DtypeNp, FilePathOrBuffer, + GroupByObjectNonScalar, IgnoreRaise, IndexingInt, Label, @@ -363,9 +367,23 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): def keys(self) -> List: ... def to_dict(self, into: Hashable = ...) -> Dict[Any, S1]: ... def to_frame(self, name: Optional[object] = ...) -> DataFrame: ... + @overload + def groupby( + self, + by: Scalar, + axis: SeriesAxisType = ..., + level: Optional[Level] = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> _SeriesGroupByScalar: ... + @overload def groupby( self, - by=..., + by: GroupByObjectNonScalar = ..., axis: SeriesAxisType = ..., level: Optional[Level] = ..., as_index: _bool = ..., @@ -374,7 +392,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> SeriesGroupBy: ... + ) -> _SeriesGroupByNonScalar: ... @overload def count(self, level: None = ...) -> int: ... @overload diff --git a/tests/test_frame.py b/tests/test_frame.py index 6c799c20e..ac437a707 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -9,6 +9,7 @@ Dict, Hashable, Iterable, + Iterator, List, Tuple, Union, @@ -20,6 +21,8 @@ import pytest from typing_extensions import assert_type +from pandas._typing import Scalar + from tests import check from pandas.io.parsers import TextFileReader @@ -1251,3 +1254,30 @@ def test_boolean_loc() -> None: df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False]) check(assert_type(df.loc[True], pd.Series), pd.Series) check(assert_type(df.loc[:, False], pd.Series), pd.Series) + + +def test_groupby_result() -> None: + # GH 142 + df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + iterator = df.groupby(["a", "b"]).__iter__() + assert_type(iterator, Iterator[Tuple[Tuple, pd.DataFrame]]) + index, value = next(iterator) + assert_type((index, value), Tuple[Tuple, pd.DataFrame]) + + check(assert_type(index, Tuple), tuple, np.int64) + check(assert_type(value, pd.DataFrame), pd.DataFrame) + + iterator2 = df.groupby("a").__iter__() + assert_type(iterator2, Iterator[Tuple[Scalar, pd.DataFrame]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[Scalar, pd.DataFrame]) + + check(assert_type(index2, Scalar), int) + check(assert_type(value2, pd.DataFrame), pd.DataFrame) + + # Want to make sure these cases are differentiated + for (k1, k2), g in df.groupby(["a", "b"]): + pass + + for kk, g in df.groupby("a"): + pass