diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index ccd01f35f87b..ce6db1a12057 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -4232,22 +4232,11 @@ def __getitem__(self, name):
         projection=name)
 
   @frame_base.with_docs_from(DataFrameGroupBy)
-  def agg(self, fn, *args, **kwargs):
-    if _is_associative(fn):
-      return _liftable_agg(fn)(self, *args, **kwargs)
-    elif _is_liftable_with_sum(fn):
-      return _liftable_agg(fn, postagg_meth='sum')(self, *args, **kwargs)
-    elif _is_unliftable(fn):
-      return _unliftable_agg(fn)(self, *args, **kwargs)
-    elif callable(fn):
-      return DeferredDataFrame(
-          expressions.ComputedExpression(
-              'agg',
-              lambda gb: gb.agg(fn, *args, **kwargs), [self._expr],
-              requires_partition_by=partitionings.Index(),
-              preserves_partition_by=partitionings.Singleton()))
+  def agg(self, fn=None, *args, **kwargs):
+    if fn is None:
+      return _agg_with_no_function(self, *args, **kwargs)
     else:
-      raise NotImplementedError(f"GroupBy.agg(func={fn!r})")
+      return _handle_agg_function(self, fn, "agg", *args, **kwargs)
 
   @property
   def ndim(self):
@@ -4696,6 +4685,82 @@ def _check_str_or_np_builtin(agg_func, func_list):
       getattr(agg_func, '__name__', None) in func_list and
       agg_func.__module__ in ('numpy', 'builtins'))
 
+def _agg_with_no_function(gb, *args, **kwargs):
+  """
+  Applies aggregation functions to the grouped data based on the provided
+  keyword arguments and combines the results into a single DataFrame.
+
+  Args:
+    gb: The groupby instance (DeferredGroupBy).
+    *args: Additional positional arguments passed to the aggregation funcs.
+    **kwargs: Named aggregations. Each key is the label of an output column;
+      its value is an (input column, aggregation function) tuple naming
+      the grouped column to project and the function to apply to it.
+
+  Returns:
+    DeferredDataFrame: A DataFrame that contains the aggregated results of
+      all specified columns.
+
+  Raises:
+    ValueError: If no aggregation functions are provided in the `kwargs`.
+    NotImplementedError: If the aggregation function type is unsupported.
+  """
+  if not kwargs:
+    raise ValueError("No aggregation functions specified")
+
+  # Build one deferred aggregation per named output column.
+  result_columns, result_frames = [], []
+  for col_name, (input_col, agg_fn) in kwargs.items():
+    frame = _handle_agg_function(
+        gb[input_col], agg_fn, f"agg_{col_name}", *args
+    )
+    result_frames.append(frame)
+    result_columns.append(col_name)
+
+  # Concatenate the per-column results side by side, labelled by output name.
+  return DeferredDataFrame(
+      expressions.ComputedExpression(
+          "agg",
+          lambda *results: pd.concat(results, axis=1, keys=result_columns),
+          [frame._expr for frame in result_frames],
+          requires_partition_by=partitionings.Index(),
+          preserves_partition_by=partitionings.Singleton(),
+      )
+  )
+
+def _handle_agg_function(gb, agg_func, agg_name, *args, **kwargs):
+  """
+  Handles the aggregation logic based on the function type passed.
+
+  Args:
+    gb: The groupby instance (DeferredGroupBy).
+    agg_func: The aggregation function to apply.
+    agg_name: Expression name used when falling back to a generic callable.
+    *args: Additional arguments to pass to the aggregation function.
+    **kwargs: Keyword arguments to pass to the aggregation function.
+
+  Returns:
+    A DeferredDataFrame or the result of the aggregation function.
+
+  Raises:
+    NotImplementedError: If the aggregation function type is unsupported.
+  """
+  if _is_associative(agg_func):
+    return _liftable_agg(agg_func)(gb, *args, **kwargs)
+  elif _is_liftable_with_sum(agg_func):
+    return _liftable_agg(agg_func, postagg_meth='sum')(gb, *args, **kwargs)
+  elif _is_unliftable(agg_func):
+    return _unliftable_agg(agg_func)(gb, *args, **kwargs)
+  elif callable(agg_func):
+    return DeferredDataFrame(
+        expressions.ComputedExpression(
+            agg_name,
+            lambda gb_val: gb_val.agg(agg_func, *args, **kwargs),
+            [gb._expr],
+            requires_partition_by=partitionings.Index(),
+            preserves_partition_by=partitionings.Singleton()))
+  else:
+    raise NotImplementedError(f"GroupBy.agg(func={agg_func!r})")
 
 def _is_associative(agg_func):
   return _check_str_or_np_builtin(agg_func, LIFTABLE_AGGREGATIONS)
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index bfe87dab52eb..27c62d2772ad 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -2364,6 +2364,16 @@ def test_std_all_na(self):
     self._run_test(lambda s: s.agg('std'), s)
     self._run_test(lambda s: s.std(), s)
 
+  def test_df_agg_operations_on_columns(self):
+    self._run_test(
+        lambda df: df.groupby('group').agg(
+            mean_foo=('foo', lambda x: np.mean(x)),
+            median_bar=('bar', lambda x: np.median(x)),
+            sum_baz=('baz', 'sum'),
+            count_bool=('bool', 'count'),
+        ),
+        GROUPBY_DF)
+
   def test_std_mostly_na_with_ddof(self):
     df = pd.DataFrame({
         'one': [i if i % 8 == 0 else np.nan for i in range(8)],