diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 55cb0bd88..2e3d8b108 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -44,7 +44,7 @@ jobs:
           # Wipe cache every 24 hours or whenever environment.yml changes. This means it
           # may take up to a day before changes to unpinned packages are picked up.
           # To force a cache refresh, change the hardcoded numerical suffix below.
-          cache-environment-key: environment-${{ steps.date.outputs.date }}-0
+          cache-environment-key: environment-${{ steps.date.outputs.date }}-1

       - name: Install dask-expr
         run: python -m pip install -e . --no-deps
diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py
index c3851210d..e162eaf8f 100644
--- a/dask_expr/_groupby.py
+++ b/dask_expr/_groupby.py
@@ -27,6 +27,7 @@
     _cumcount_aggregate,
     _determine_levels,
     _groupby_aggregate,
+    _groupby_aggregate_spec,
     _groupby_apply_funcs,
     _groupby_get_group,
     _groupby_slice_apply,
@@ -34,6 +35,7 @@
     _groupby_slice_transform,
     _head_aggregate,
     _head_chunk,
+    _non_agg_chunk,
     _normalize_spec,
     _nunique_df_chunk,
     _nunique_df_combine,
@@ -249,10 +251,10 @@ def _simplify_up(self, parent, dependents):
         return groupby_projection(self, parent, dependents)


-class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase):
-    """General groupby aggregation
+class GroupbyAggregationBase(GroupByApplyConcatApply, GroupByBase):
+    """Base class for groupby aggregation

-    This class can be used directly to perform a general
+    This class can be subclassed to perform a general
     groupby aggregation by passing in a `str`, `list` or
     `dict`-based specification using the `arg` operand.

@@ -269,10 +271,6 @@ class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase):
         Passed through to dataframe backend.
     dropna:
         Whether rows with NA values should be dropped.
-    chunk_kwargs:
-        Key-word arguments to pass to `groupby_chunk`.
-    aggregate_kwargs:
-        Key-word arguments to pass to `aggregate_chunk`.
""" _parameters = [ @@ -295,7 +293,19 @@ class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase): "shuffle_method": None, "_slice": None, } - chunk = staticmethod(_groupby_apply_funcs) + + @functools.cached_property + def _meta(self): + meta = meta_nonempty(self.frame._meta) + meta = meta.groupby( + self._by_meta, + **_as_dict("observed", self.observed), + **_as_dict("dropna", self.dropna), + ) + if self._slice is not None: + meta = meta[self._slice] + meta = meta.aggregate(self.arg) + return make_meta(meta) @functools.cached_property def spec(self): @@ -329,13 +339,121 @@ def spec(self): else: raise ValueError(f"aggregate on unknown object {self.frame._meta}") - # Median not supported yet - has_median = any(s[1] in ("median", np.median) for s in spec) - if has_median: - raise NotImplementedError("median not yet supported") + return spec + @functools.cached_property + def agg_args(self): keys = ["chunk_funcs", "aggregate_funcs", "finalizers"] - return dict(zip(keys, _build_agg_args(spec))) + return dict(zip(keys, _build_agg_args(self.spec))) + + def _simplify_down(self): + if not isinstance(self.arg, dict): + return + + # Use agg-spec information to add column projection + required_columns = ( + set(self._by_columns) + .union(self.arg.keys()) + .intersection(self.frame.columns) + ) + column_projection = [ + column for column in self.frame.columns if column in required_columns + ] + if column_projection != self.frame.columns: + return type(self)(self.frame[column_projection], *self.operands[1:]) + + +class GroupbyAggregation(GroupbyAggregationBase): + """Logical groupby aggregation class + + This class lowers itself to concrete implementations for decomposable + or holistic aggregations. + """ + + @functools.cached_property + def _is_decomposable(self): + return not any(s[1] in ("median", np.median) for s in self.spec) + + def _lower(self): + cls = ( + DecomposableGroupbyAggregation + if self._is_decomposable + else HolisticGroupbyAggregation + ) + return cls( + self.frame, + self.arg, + self.observed, + self.dropna, + self.split_every, + self.split_out, + self.sort, + self.shuffle_method, + self._slice, + *self.by, + ) + + +class HolisticGroupbyAggregation(GroupbyAggregationBase): + """Groupby aggregation for both decomposable and non-decomposable aggregates + + This class always calculates the aggregates by first collecting all the data for + the groups and then aggregating at once. 
+ """ + + chunk = staticmethod(_non_agg_chunk) + + @property + def should_shuffle(self): + return True + + @classmethod + def chunk(cls, df, *by, **kwargs): + return _non_agg_chunk(df, *by, **kwargs) + + @classmethod + def combine(cls, inputs, **kwargs): + return _groupby_aggregate_spec(_concat(inputs), **kwargs) + + @classmethod + def aggregate(cls, inputs, **kwargs): + return _groupby_aggregate_spec(_concat(inputs), **kwargs) + + @property + def chunk_kwargs(self) -> dict: + return { + "by": self._by_columns, + "key": [col for col in self.frame.columns if col not in self._by_columns], + **_as_dict("observed", self.observed), + **_as_dict("dropna", self.dropna), + } + + @property + def combine_kwargs(self) -> dict: + return { + "spec": self.arg, + "levels": _determine_levels(self.by), + **_as_dict("observed", self.observed), + **_as_dict("dropna", self.dropna), + } + + @property + def aggregate_kwargs(self) -> dict: + return { + "spec": self.arg, + "levels": _determine_levels(self.by), + **_as_dict("observed", self.observed), + **_as_dict("dropna", self.dropna), + } + + +class DecomposableGroupbyAggregation(GroupbyAggregationBase): + """Groupby aggregation for decomposable aggregates + + The results may be calculated via tree or shuffle reduction. + """ + + chunk = staticmethod(_groupby_apply_funcs) @classmethod def combine(cls, inputs, **kwargs): @@ -348,7 +466,7 @@ def aggregate(cls, inputs, **kwargs): @property def chunk_kwargs(self) -> dict: return { - "funcs": self.spec["chunk_funcs"], + "funcs": self.agg_args["chunk_funcs"], "sort": self.sort, **_as_dict("observed", self.observed), **_as_dict("dropna", self.dropna), @@ -357,7 +475,7 @@ def chunk_kwargs(self) -> dict: @property def combine_kwargs(self) -> dict: return { - "funcs": self.spec["aggregate_funcs"], + "funcs": self.agg_args["aggregate_funcs"], "level": self.levels, "sort": self.sort, **_as_dict("observed", self.observed), @@ -367,26 +485,17 @@ def combine_kwargs(self) -> dict: @property def aggregate_kwargs(self) -> dict: return { - "aggregate_funcs": self.spec["aggregate_funcs"], - "finalize_funcs": self.spec["finalizers"], + "aggregate_funcs": self.agg_args["aggregate_funcs"], + "arg": self.arg, + "columns": self._slice, + "finalize_funcs": self.agg_args["finalizers"], + "is_series": self._meta.ndim == 1, "level": self.levels, "sort": self.sort, **_as_dict("observed", self.observed), **_as_dict("dropna", self.dropna), } - def _simplify_down(self): - # Use agg-spec information to add column projection - column_projection = None - if isinstance(self.arg, dict): - column_projection = ( - set(self._by_columns) - .union(self.arg.keys()) - .intersection(self.frame.columns) - ) - if column_projection and column_projection < set(self.frame.columns): - return type(self)(self.frame[list(column_projection)], *self.operands[1:]) - class Sum(SingleAggregation): groupby_chunk = M.sum @@ -1781,27 +1890,6 @@ def __init__( obj, by=by, slice=slice, observed=observed, dropna=dropna, sort=sort ) - def aggregate(self, arg=None, split_every=8, split_out=1, **kwargs): - result = super().aggregate( - arg=arg, split_every=split_every, split_out=split_out - ) - if self._slice: - try: - result = result[self._slice] - except KeyError: - pass - - if ( - arg is not None - and not isinstance(arg, (list, dict)) - and is_dataframe_like(result._meta) - ): - result = result[result.columns[0]] - - return result - - agg = aggregate - def idxmin( self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs ): diff --git 
diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py
index 7a87d5c3c..32e47af95 100644
--- a/dask_expr/_reductions.py
+++ b/dask_expr/_reductions.py
@@ -447,6 +447,13 @@ def _divisions(self):
     def _chunk_cls_args(self):
         return []

+    @property
+    def should_shuffle(self):
+        sort = getattr(self, "sort", False)
+        return not (
+            not isinstance(self.split_out, bool) and self.split_out == 1 or sort
+        )
+
     def _lower(self):
         # Normalize functions in case not all are defined
         chunk = self.chunk
@@ -465,12 +472,11 @@ def _lower(self):
             combine = aggregate
             combine_kwargs = aggregate_kwargs

-        sort = getattr(self, "sort", False)
         split_every = getattr(self, "split_every", None)
         chunked = self._chunk_cls(
             self.frame, type(self), chunk, chunk_kwargs, *self._chunk_cls_args
         )
-        if not isinstance(self.split_out, bool) and self.split_out == 1 or sort:
+        if not self.should_shuffle:
             # Lower into TreeReduce(Chunk)
             return TreeReduce(
                 chunked,
@@ -496,7 +502,7 @@ def _lower(self):
             split_by=self.split_by,
             split_out=self.split_out,
             split_every=split_every,
-            sort=sort,
+            sort=getattr(self, "sort", False),
             shuffle_by_index=getattr(self, "shuffle_by_index", None),
             shuffle_method=getattr(self, "shuffle_method", None),
             ignore_index=getattr(self, "ignore_index", True),
diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py
index 810fcd265..d495fd904 100644
--- a/dask_expr/tests/test_groupby.py
+++ b/dask_expr/tests/test_groupby.py
@@ -238,6 +238,7 @@ def test_dataframe_aggregations_multilevel(df, pdf):
         {"x": ["sum", "mean"]},
         ["min", "mean"],
         "sum",
+        "median",
     ],
 )
 def test_groupby_agg(pdf, df, spec):
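
The new "median" case in test_groupby_agg exercises the user-visible effect of this
patch: an aggregation spec containing median is no longer rejected with
NotImplementedError but lowers to HolisticGroupbyAggregation, while fully decomposable
specs keep lowering to DecomposableGroupbyAggregation. A minimal usage sketch, not part
of the patch, assuming only the public dask_expr API (from_pandas plus groupby/agg);
the column names and data are illustrative:

    import pandas as pd
    import dask_expr as dx

    pdf = pd.DataFrame({"g": [0, 0, 1, 1, 1], "x": [1.0, 3.0, 2.0, 4.0, 6.0]})
    df = dx.from_pandas(pdf, npartitions=2)

    # Decomposable spec: lowers to DecomposableGroupbyAggregation (tree or shuffle reduce).
    print(df.groupby("g").agg({"x": ["sum", "mean"]}).compute())

    # Holistic spec: "median" now lowers to HolisticGroupbyAggregation, which collects
    # each group's raw rows onto one partition before aggregating them at once.
    print(df.groupby("g").agg({"x": "median"}).compute())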