66import warnings
77import copy
88from textwrap import dedent
9+ from contextlib import contextmanager
910
1011from pandas .compat import (
1112 zip , range , lzip ,
@@ -549,6 +550,16 @@ def f(self):
549550 return attr
550551
551552
553+ @contextmanager
554+ def _group_selection_context (groupby ):
555+ """
556+ set / reset the _group_selection_context
557+ """
558+ groupby ._set_group_selection ()
559+ yield groupby
560+ groupby ._reset_group_selection ()
561+
562+
552563class _GroupBy (PandasObject , SelectionMixin ):
553564 _group_selection = None
554565 _apply_whitelist = frozenset ([])
@@ -704,6 +715,8 @@ def _set_group_selection(self):
704715 """
705716 Create group based selection. Used when selection is not passed
706717 directly but instead via a grouper.
718+
719+ NOTE: this should be paired with a call to _reset_group_selection
707720 """
708721 grp = self .grouper
709722 if not (self .as_index and
@@ -785,10 +798,10 @@ def _make_wrapper(self, name):
785798 type (self ).__name__ ))
786799 raise AttributeError (msg )
787800
788- # need to setup the selection
789- # as are not passed directly but in the grouper
790801 self ._set_group_selection ()
791802
803+ # need to setup the selection
804+ # as are not passed directly but in the grouper
792805 f = getattr (self ._selected_obj , name )
793806 if not isinstance (f , types .MethodType ):
794807 return self .apply (lambda self : getattr (self , name ))
@@ -913,9 +926,8 @@ def f(g):
913926 # fails on *some* columns, e.g. a numeric operation
914927 # on a string grouper column
915928
916- self ._set_group_selection ()
917- result = self ._python_apply_general (f )
918- self ._reset_group_selection ()
929+ with _group_selection_context (self ):
930+ return self ._python_apply_general (f )
919931
920932 return result
921933
@@ -1295,9 +1307,9 @@ def mean(self, *args, **kwargs):
12951307 except GroupByError :
12961308 raise
12971309 except Exception : # pragma: no cover
1298- self . _set_group_selection ()
1299- f = lambda x : x .mean (axis = self .axis , ** kwargs )
1300- return self ._python_agg_general (f )
1310+ with _group_selection_context ( self ):
1311+ f = lambda x : x .mean (axis = self .axis , ** kwargs )
1312+ return self ._python_agg_general (f )
13011313
13021314 @Substitution (name = 'groupby' )
13031315 @Appender (_doc_template )
@@ -1313,13 +1325,12 @@ def median(self, **kwargs):
13131325 raise
13141326 except Exception : # pragma: no cover
13151327
1316- self ._set_group_selection ()
1317-
13181328 def f (x ):
13191329 if isinstance (x , np .ndarray ):
13201330 x = Series (x )
13211331 return x .median (axis = self .axis , ** kwargs )
1322- return self ._python_agg_general (f )
1332+ with _group_selection_context (self ):
1333+ return self ._python_agg_general (f )
13231334
13241335 @Substitution (name = 'groupby' )
13251336 @Appender (_doc_template )
@@ -1356,9 +1367,9 @@ def var(self, ddof=1, *args, **kwargs):
13561367 if ddof == 1 :
13571368 return self ._cython_agg_general ('var' , ** kwargs )
13581369 else :
1359- self ._set_group_selection ()
13601370 f = lambda x : x .var (ddof = ddof , ** kwargs )
1361- return self ._python_agg_general (f )
1371+ with _group_selection_context (self ):
1372+ return self ._python_agg_general (f )
13621373
13631374 @Substitution (name = 'groupby' )
13641375 @Appender (_doc_template )
@@ -1404,6 +1415,7 @@ def f(self, **kwargs):
14041415 kwargs ['numeric_only' ] = numeric_only
14051416 if 'min_count' not in kwargs :
14061417 kwargs ['min_count' ] = min_count
1418+
14071419 self ._set_group_selection ()
14081420 try :
14091421 return self ._cython_agg_general (
@@ -1797,13 +1809,12 @@ def ngroup(self, ascending=True):
17971809 .cumcount : Number the rows in each group.
17981810 """
17991811
1800- self ._set_group_selection ()
1801-
1802- index = self ._selected_obj .index
1803- result = Series (self .grouper .group_info [0 ], index )
1804- if not ascending :
1805- result = self .ngroups - 1 - result
1806- return result
1812+ with _group_selection_context (self ):
1813+ index = self ._selected_obj .index
1814+ result = Series (self .grouper .group_info [0 ], index )
1815+ if not ascending :
1816+ result = self .ngroups - 1 - result
1817+ return result
18071818
18081819 @Substitution (name = 'groupby' )
18091820 def cumcount (self , ascending = True ):
@@ -1854,11 +1865,10 @@ def cumcount(self, ascending=True):
18541865 .ngroup : Number the groups themselves.
18551866 """
18561867
1857- self ._set_group_selection ()
1858-
1859- index = self ._selected_obj .index
1860- cumcounts = self ._cumcount_array (ascending = ascending )
1861- return Series (cumcounts , index )
1868+ with _group_selection_context (self ):
1869+ index = self ._selected_obj .index
1870+ cumcounts = self ._cumcount_array (ascending = ascending )
1871+ return Series (cumcounts , index )
18621872
18631873 @Substitution (name = 'groupby' )
18641874 @Appender (_doc_template )
0 commit comments