@@ -803,39 +803,29 @@ def test_transform_with_non_scalar_group():
803803
804804
805805@pytest .mark .parametrize (
806- "cols,exp,comp_func " ,
806+ "cols,expected " ,
807807 [
808- ("a" , Series ([1 , 1 , 1 ], name = "a" ), tm . assert_series_equal ),
808+ ("a" , Series ([1 , 1 , 1 ], name = "a" )),
809809 (
810810 ["a" , "c" ],
811811 DataFrame ({"a" : [1 , 1 , 1 ], "c" : [1 , 1 , 1 ]}),
812- tm .assert_frame_equal ,
813812 ),
814813 ],
815814)
816815@pytest .mark .parametrize ("agg_func" , ["count" , "rank" , "size" ])
817- def test_transform_numeric_ret (cols , exp , comp_func , agg_func , request ):
818- if agg_func == "size" and isinstance (cols , list ):
819- # https://github.com/pytest-dev/pytest/issues/6300
820- # workaround to xfail fixture/param permutations
821- reason = "'size' transformation not supported with NDFrameGroupy"
822- request .node .add_marker (pytest .mark .xfail (reason = reason ))
823-
824- # GH 19200
816+ def test_transform_numeric_ret (cols , expected , agg_func ):
817+ # GH#19200 and GH#27469
825818 df = DataFrame (
826819 {"a" : date_range ("2018-01-01" , periods = 3 ), "b" : range (3 ), "c" : range (7 , 10 )}
827820 )
828-
829- warn = FutureWarning
830- if isinstance (exp , Series ) or agg_func != "size" :
831- warn = None
832- with tm .assert_produces_warning (warn , match = "Dropping invalid columns" ):
833- result = df .groupby ("b" )[cols ].transform (agg_func )
821+ result = df .groupby ("b" )[cols ].transform (agg_func )
834822
835823 if agg_func == "rank" :
836- exp = exp .astype ("float" )
837-
838- comp_func (result , exp )
824+ expected = expected .astype ("float" )
825+ elif agg_func == "size" and cols == ["a" , "c" ]:
826+ # transform("size") returns a Series
827+ expected = expected ["a" ].rename (None )
828+ tm .assert_equal (result , expected )
839829
840830
841831def test_transform_ffill ():
@@ -1131,27 +1121,19 @@ def test_transform_agg_by_name(request, reduction_func, obj):
11311121 request .node .add_marker (
11321122 pytest .mark .xfail (reason = "TODO: g.transform('ngroup') doesn't work" )
11331123 )
1134- if func == "size" and obj .ndim == 2 : # GH#27469
1135- request .node .add_marker (
1136- pytest .mark .xfail (reason = "TODO: g.transform('size') doesn't work" )
1137- )
11381124 if func == "corrwith" and isinstance (obj , Series ): # GH#32293
11391125 request .node .add_marker (
11401126 pytest .mark .xfail (reason = "TODO: implement SeriesGroupBy.corrwith" )
11411127 )
11421128
11431129 args = {"nth" : [0 ], "quantile" : [0.5 ], "corrwith" : [obj ]}.get (func , [])
1144-
1145- warn = None
1146- if isinstance (obj , DataFrame ) and func == "size" :
1147- warn = FutureWarning
1148-
1149- with tm .assert_produces_warning (warn , match = "Dropping invalid columns" ):
1150- result = g .transform (func , * args )
1130+ result = g .transform (func , * args )
11511131
11521132 # this is the *definition* of a transformation
11531133 tm .assert_index_equal (result .index , obj .index )
1154- if hasattr (obj , "columns" ):
1134+
1135+ if func != "size" and obj .ndim == 2 :
1136+ # size returns a Series, unlike other transforms
11551137 tm .assert_index_equal (result .columns , obj .columns )
11561138
11571139 # verify that values were broadcasted across each group
0 commit comments