Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^

- Bug in :meth:`GroupBy.apply` raises ``ValueError`` when the ``by`` axis is not sorted and has duplicates and the applied ``func`` does not mutate passed in objects (:issue:`30667`)
- Bug in :meth:`DataFrameGroupBy.transform` producing an incorrect result when called with the name of a transformation function (:issue:`30918`)

Reshaping
^^^^^^^^^
Expand Down
30 changes: 14 additions & 16 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,22 +1416,20 @@ def transform(self, func, *args, **kwargs):
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)

# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)

# a reduction transform
if not isinstance(result, DataFrame):
return self._transform_general(func, *args, **kwargs)

obj = self._obj_with_exclusions

# nuisance columns
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)

return self._transform_fast(result, func)
# GH 30918
# Use _transform_fast only when we know func is an aggregation
if func in base.reduction_kernels:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)

if isinstance(result, DataFrame) and result.columns.equals(
self._obj_with_exclusions.columns
):
return self._transform_fast(result, func)

return self._transform_general(func, *args, **kwargs)

def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
"""
Expand Down
26 changes: 26 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,32 @@ def test_dispatch_transform(tsframe):
tm.assert_frame_equal(filled, expected)


def test_transform_transformation_func(transformation_func):
    """
    Check ``groupby("A").transform(<name>)`` against a per-group computation.

    The expected frame is built by calling the same-named method directly on
    the ``B`` column of each group's rows and concatenating the pieces.
    """
    # GH 30918
    not_yet_covered = ["pad", "backfill", "tshift", "corrwith", "cumcount"]
    if transformation_func in not_yet_covered:
        # These transformation functions are not yet covered in this test
        pytest.xfail("See GH 31269 and GH 31270")

    frame = DataFrame(
        {
            "A": ["foo", "foo", "foo", "foo", "bar", "bar", "baz"],
            "B": [1, 2, np.nan, 3, 3, np.nan, 4],
        }
    )

    if transformation_func == "fillna":
        # fillna is exercised with an explicit fill value on both sides

        def via_transform(grouped):
            return grouped.transform("fillna", value=0)

        def via_method(piece):
            return piece.fillna(value=0)

    else:

        def via_transform(grouped):
            return grouped.transform(transformation_func)

        def via_method(piece):
            return getattr(piece, transformation_func)()

    result = via_transform(frame.groupby("A"))

    # Rows 0-3 are "foo", rows 4-5 are "bar", row 6 is "baz"
    values = frame[["B"]]
    bounds = [(None, 4), (4, 6), (6, None)]
    pieces = [values.iloc[start:stop] for start, stop in bounds]
    expected = concat([via_method(piece) for piece in pieces])

    tm.assert_frame_equal(result, expected)


def test_transform_select_columns(df):
f = lambda x: x.mean()
result = df.groupby("A")[["C", "D"]].transform(f)
Expand Down