@@ -1198,25 +1198,59 @@ def test_transform_lambda_indexing():
11981198 tm .assert_frame_equal (result , expected )
11991199
12001200
1201- @pytest .mark .parametrize (
1202- "input_df" ,
1203- [
1204- DataFrame (
1205- {
1206- "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1207- "B" : [1 , 2 , np .nan , 3 , 3 , np .nan , 4 ],
1208- }
1209- ),
1210- DataFrame (
1211- {
1212- "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1213- "B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ],
1214- }
1215- ),
1216- ],
1217- )
1218- def test_groupby_transform_fillna (input_df ):
1219- # GH 27905
1220- result = input_df .groupby ("A" ).transform (lambda x : x .fillna (x .mean ()))
1221- expected = pd .DataFrame ({"B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ]})
1222- tm .assert_frame_equal (result , expected )
1201+ def test_transform_nan_tshift_corrwith (transformation_func ):
1202+
1203+ df1 = DataFrame (
1204+ {
1205+ "A" : [121 , 121 , 121 , 121 , 231 , 231 , 676 ],
1206+ "B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ],
1207+ }
1208+ )
1209+ g1 = df1 .groupby ("A" )
1210+
1211+ if transformation_func == "corrwith" :
1212+ result = g1 .corrwith (df1 )
1213+ expected = pd .DataFrame (dict (B = [1 , np .nan , np .nan ], A = [np .nan ] * 3 ))
1214+ expected .index = pd .Index ([121 , 231 , 676 ], name = "A" )
1215+ tm .assert_frame_equal (result , expected )
1216+
1217+ if transformation_func == "fillna" :
1218+ df3 = df1 .copy ()
1219+ df3 ["B" ] = [1 , np .nan , np .nan , 3 , np .nan , 3 , 4 ]
1220+ result = df3 .groupby ("A" ).transform (lambda x : x .fillna (x .mean ()))
1221+ expected = pd .DataFrame ({"B" : [1.0 , 2.0 , 2.0 , 3.0 , 3.0 , 3.0 , 4.0 ]})
1222+ tm .assert_frame_equal (result , expected )
1223+
1224+ result = df3 .groupby ("A" ).transform (transformation_func , value = 1 )
1225+ expected = pd .DataFrame ({"B" : [1.0 , 1.0 , 1.0 , 3.0 , 1.0 , 3.0 , 4.0 ]})
1226+ tm .assert_frame_equal (result , expected )
1227+
1228+ if transformation_func == "tshift" :
1229+ df2 = df1 .copy ()
1230+ dt_periods = pd .date_range ("2013-11-03" , periods = 7 , freq = "D" )
1231+ df2 ["C" ] = dt_periods
1232+ result = df2 .set_index ("C" ).groupby ("A" ).tshift (2 , "D" )
1233+ df2 ["C" ] = dt_periods + dt_periods .freq * 2
1234+ expected = df2
1235+ tm .assert_frame_equal (
1236+ result .reset_index ().reindex (columns = ["A" , "B" , "C" ]), expected
1237+ )
1238+
1239+
1240+ def test_check_original_and_transformed_index (transformation_func ):
1241+ df = DataFrame ({"A" : [0 , 0 , 0 , 1 , 1 , 1 ], "B" : [0 , 1 , 2 , 3 , 4 , 5 ]})
1242+ g = df .groupby ("A" )
1243+
1244+ if transformation_func in [
1245+ "cummax" ,
1246+ "cummin" ,
1247+ "cumprod" ,
1248+ "cumsum" ,
1249+ "diff" ,
1250+ "ffill" ,
1251+ "pct_change" ,
1252+ "rank" ,
1253+ "shift" ,
1254+ ]:
1255+ result = g .transform (transformation_func )
1256+ tm .assert_index_equal (result .index , df .index )
0 commit comments