@@ -234,11 +234,10 @@ def test_aggregate_item_by_item(df):
234234 K = len (result .columns )
235235
236236 # GH5782
237- # odd comparisons can result here, so cast to make easy
238- exp = Series (np .array ([foo ] * K ), index = list ("BCD" ), dtype = np .float64 , name = "foo" )
237+ exp = Series (np .array ([foo ] * K ), index = list ("BCD" ), name = "foo" )
239238 tm .assert_series_equal (result .xs ("foo" ), exp )
240239
241- exp = Series (np .array ([bar ] * K ), index = list ("BCD" ), dtype = np . float64 , name = "bar" )
240+ exp = Series (np .array ([bar ] * K ), index = list ("BCD" ), name = "bar" )
242241 tm .assert_almost_equal (result .xs ("bar" ), exp )
243242
244243 def aggfun (ser ):
@@ -442,6 +441,57 @@ def test_bool_agg_dtype(op):
442441 assert is_integer_dtype (result )
443442
444443
@pytest.mark.parametrize(
    "keys, agg_index",
    [
        (["a"], Index([1], name="a")),
        (["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
    ],
)
@pytest.mark.parametrize(
    "input_dtype", ["bool", "int32", "int64", "float32", "float64"]
)
@pytest.mark.parametrize(
    "result_dtype", ["bool", "int32", "int64", "float32", "float64"]
)
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
def test_callable_result_dtype_frame(
    keys, agg_index, input_dtype, result_dtype, method
):
    # GH 21240
    # A single-row frame whose "c" column starts out as ``input_dtype``; the
    # callable handed to the groupby method casts each group to
    # ``result_dtype`` and returns its first row.
    frame = DataFrame({"a": [1], "b": [2], "c": [True]})
    frame["c"] = frame["c"].astype(input_dtype)

    grouped = frame.groupby(keys)[["c"]]
    result = getattr(grouped, method)(
        lambda grp: grp.astype(result_dtype).iloc[0]
    )

    # The expected frame is indexed 0..0 for transform and by the group
    # keys for apply/aggregate.
    if method == "transform":
        expected_index = pd.RangeIndex(0, 1)
    else:
        expected_index = agg_index

    # The expected values are the original entry cast to ``result_dtype``.
    first_value = frame["c"].iloc[0]
    expected = DataFrame({"c": [first_value]}, index=expected_index)
    expected = expected.astype(result_dtype)

    if method == "apply":
        # The expected columns are named 0 for the apply case only.
        expected.columns.names = [0]

    tm.assert_frame_equal(result, expected)
473+
474+
@pytest.mark.parametrize(
    "keys, agg_index",
    [
        (["a"], Index([1], name="a")),
        (["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
    ],
)
@pytest.mark.parametrize("input", [True, 1, 1.0])
@pytest.mark.parametrize("dtype", [bool, int, float])
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
def test_callable_result_dtype_series(keys, agg_index, input, dtype, method):
    # GH 21240
    # Single-row frame; the callable handed to the groupby method casts the
    # grouped "c" values to ``dtype`` and returns the first element.
    frame = DataFrame({"a": [1], "b": [2], "c": [input]})

    grouped = frame.groupby(keys)["c"]
    result = getattr(grouped, method)(lambda grp: grp.astype(dtype).iloc[0])

    # The expected series is indexed 0..0 for transform and by the group
    # keys for apply/aggregate.
    if method == "transform":
        expected_index = pd.RangeIndex(0, 1)
    else:
        expected_index = agg_index

    # The expected series holds the original value cast to ``dtype``.
    expected = Series([frame["c"].iloc[0]], index=expected_index, name="c")
    expected = expected.astype(dtype)

    tm.assert_series_equal(result, expected)
493+
494+
445495def test_order_aggregate_multiple_funcs ():
446496 # GH 25692
447497 df = DataFrame ({"A" : [1 , 1 , 2 , 2 ], "B" : [1 , 2 , 3 , 4 ]})
@@ -462,7 +512,9 @@ def test_uint64_type_handling(dtype, how):
462512 expected = df .groupby ("y" ).agg ({"x" : how })
463513 df .x = df .x .astype (dtype )
464514 result = df .groupby ("y" ).agg ({"x" : how })
465- result .x = result .x .astype (np .int64 )
515+ if how not in ("mean" , "median" ):
516+ # mean and median always result in floats
517+ result .x = result .x .astype (np .int64 )
466518 tm .assert_frame_equal (result , expected , check_exact = True )
467519
468520
@@ -849,7 +901,11 @@ def test_multiindex_custom_func(func):
849901 data = [[1 , 4 , 2 ], [5 , 7 , 1 ]]
850902 df = DataFrame (data , columns = MultiIndex .from_arrays ([[1 , 1 , 2 ], [3 , 4 , 3 ]]))
851903 result = df .groupby (np .array ([0 , 1 ])).agg (func )
852- expected_dict = {(1 , 3 ): {0 : 1 , 1 : 5 }, (1 , 4 ): {0 : 4 , 1 : 7 }, (2 , 3 ): {0 : 2 , 1 : 1 }}
904+ expected_dict = {
905+ (1 , 3 ): {0 : 1.0 , 1 : 5.0 },
906+ (1 , 4 ): {0 : 4.0 , 1 : 7.0 },
907+ (2 , 3 ): {0 : 2.0 , 1 : 1.0 },
908+ }
853909 expected = DataFrame (expected_dict )
854910 tm .assert_frame_equal (result , expected )
855911
0 commit comments