Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 80 additions & 15 deletions sdks/python/apache_beam/dataframe/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -4232,22 +4232,11 @@ def __getitem__(self, name):
projection=name)

@frame_base.with_docs_from(DataFrameGroupBy)
def agg(self, fn, *args, **kwargs):
if _is_associative(fn):
return _liftable_agg(fn)(self, *args, **kwargs)
elif _is_liftable_with_sum(fn):
return _liftable_agg(fn, postagg_meth='sum')(self, *args, **kwargs)
elif _is_unliftable(fn):
return _unliftable_agg(fn)(self, *args, **kwargs)
elif callable(fn):
return DeferredDataFrame(
expressions.ComputedExpression(
'agg',
lambda gb: gb.agg(fn, *args, **kwargs), [self._expr],
requires_partition_by=partitionings.Index(),
preserves_partition_by=partitionings.Singleton()))
def agg(self, fn=None, *args, **kwargs):
  # With no positional function the call is a named aggregation, i.e.
  # agg(out_col=(in_col, func), ...), which is handled by
  # _agg_with_no_function.
  if fn is None:
    return _agg_with_no_function(self, *args, **kwargs)
  # Otherwise dispatch on the kind of `fn`. _handle_agg_function raises
  # NotImplementedError itself for unsupported values, so no separate
  # fallback is needed here.
  return _handle_agg_function(self, fn, "agg", *args, **kwargs)

@property
def ndim(self):
Expand Down Expand Up @@ -4696,6 +4685,82 @@ def _check_str_or_np_builtin(agg_func, func_list):
getattr(agg_func, '__name__', None) in func_list
and agg_func.__module__ in ('numpy', 'builtins'))

def _agg_with_no_function(gb, *args, **kwargs):
"""
Applies aggregation functions to the grouped data based on the provided
keyword arguments and combines the results into a single DataFrame.

Args:
gb: The groupby instance (DeferredGroupBy).
*args: Additional positional arguments passed to the aggregation funcs.
**kwargs: A dictionary where each key is the column name to aggregate,
the value is a tuple containing the input column name and
the aggregation function to apply.

Returns:
DeferredDataFrame: A DataFrame that contains the aggregated results of
all specified columns.

Raises:
ValueError: If no aggregation functions are provided in the `kwargs`.
NotImplementedError: If the aggregation function type is unsupported.
"""
if not kwargs:
raise ValueError("No aggregation functions specified")

# Handle dictionary-like input for aggregation.
result_columns, result_frames = [], []
for col_name, (input_col, agg_fn) in kwargs.items():
frame = _handle_agg_function(
gb[input_col], agg_fn, f"agg_{col_name}", *args
)
result_frames.append(frame)
result_columns.append(col_name)

# Combine all the resulting DeferredDataFrames into a single DataFrame.
return DeferredDataFrame(
expressions.ComputedExpression(
"agg",
lambda *results: pd.concat(results, axis=1, keys=result_columns),
[frame._expr for frame in result_frames],
requires_partition_by=partitionings.Index(),
preserves_partition_by=partitionings.Singleton(),
)
)

def _handle_agg_function(gb, agg_func, agg_name, *args, **kwargs):
  """Dispatch a single aggregation to the matching implementation strategy.

  The checks run in a fixed order: ``_is_associative``, then
  ``_is_liftable_with_sum``, then ``_is_unliftable``, and finally a plain
  ``callable`` test. The first one that matches decides how the aggregation
  is built.

  Args:
    gb: The groupby instance (DeferredGroupBy).
    agg_func: The aggregation to apply.
    agg_name: The name given to the expression built for a generic callable.
    *args: Additional arguments to pass to the aggregation function.
    **kwargs: Keyword arguments to pass to the aggregation function.

  Returns:
    The result of the selected aggregation helper, or a DeferredDataFrame
    wrapping ``agg_func`` when it is a generic callable.

  Raises:
    NotImplementedError: If the aggregation function type is unsupported.
  """
  if _is_associative(agg_func):
    lifted = _liftable_agg(agg_func)
    return lifted(gb, *args, **kwargs)

  if _is_liftable_with_sum(agg_func):
    lifted_with_sum = _liftable_agg(agg_func, postagg_meth='sum')
    return lifted_with_sum(gb, *args, **kwargs)

  if _is_unliftable(agg_func):
    unlifted = _unliftable_agg(agg_func)
    return unlifted(gb, *args, **kwargs)

  if not callable(agg_func):
    raise NotImplementedError(f"GroupBy.agg(func={agg_func!r})")

  # Generic callable: hand it to the underlying groupby's own agg().
  expr = expressions.ComputedExpression(
      agg_name,
      lambda gb_val: gb_val.agg(agg_func, *args, **kwargs),
      [gb._expr],
      requires_partition_by=partitionings.Index(),
      preserves_partition_by=partitionings.Singleton())
  return DeferredDataFrame(expr)

def _is_associative(agg_func):
  """Return whether ``agg_func`` is one of ``LIFTABLE_AGGREGATIONS``.

  Membership is decided by ``_check_str_or_np_builtin``.
  """
  is_liftable = _check_str_or_np_builtin(agg_func, LIFTABLE_AGGREGATIONS)
  return is_liftable
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,16 @@ def test_std_all_na(self):
self._run_test(lambda s: s.agg('std'), s)
self._run_test(lambda s: s.std(), s)

def test_df_agg_operations_on_columns(self):
  """Exercise groupby.agg with keyword (column, aggregation) pairs."""
  def named_agg(df):
    grouped = df.groupby('group')
    # The lambdas are deliberate: they wrap the numpy functions in plain
    # callables instead of passing np.mean / np.median directly.
    return grouped.agg(
        mean_foo=('foo', lambda x: np.mean(x)),
        median_bar=('bar', lambda x: np.median(x)),
        sum_baz=('baz', 'sum'),
        count_bool=('bool', 'count'),
    )

  self._run_test(named_agg, GROUPBY_DF)

def test_std_mostly_na_with_ddof(self):
df = pd.DataFrame({
'one': [i if i % 8 == 0 else np.nan for i in range(8)],
Expand Down
Loading