diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py
index 3f9ef8765..5bb4748b0 100644
--- a/dask_expr/_shuffle.py
+++ b/dask_expr/_shuffle.py
@@ -907,19 +907,26 @@ class SortIndexBlockwise(Blockwise):
     _is_length_preserving = True
 
 
-def sort_function(self, *args, **kwargs):
-    sort_func = kwargs.pop("sort_function")
-    sort_kwargs = kwargs.pop("sort_kwargs")
-    return sort_func(*args, **kwargs, **sort_kwargs)
-
-
 class SortValuesBlockwise(Blockwise):
     _projection_passthrough = False
     _parameters = ["frame", "sort_function", "sort_kwargs"]
-    operation = sort_function
     _keyword_only = ["sort_function", "sort_kwargs"]
     _is_length_preserving = True
 
+    def operation(self, *args, **kwargs):
+        """Call the supplied ``sort_function`` on the positional arguments."""
+        # ``sort_function`` and ``sort_kwargs`` are listed in ``_keyword_only``
+        # and are read from ``kwargs`` here; any remaining keyword arguments
+        # are forwarded to the sort function together with ``sort_kwargs``.
+        sort_func = kwargs.pop("sort_function")
+        sort_kwargs = kwargs.pop("sort_kwargs")
+        return sort_func(*args, **kwargs, **sort_kwargs)
+
+    @functools.cached_property
+    def _meta(self):
+        """Return the metadata of ``frame`` unchanged."""
+        return self.frame._meta
+
 
 class SetIndexBlockwise(Blockwise):
     _parameters = ["frame", "other", "drop", "new_divisions"]
diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py
index 6b07a4e7b..d0c0156bb 100644
--- a/dask_expr/tests/test_distributed.py
+++ b/dask_expr/tests/test_distributed.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import pytest
+from distributed import Client, LocalCluster
 
 from dask_expr import from_pandas
 from dask_expr.tests._util import _backend_library
@@ -86,3 +87,17 @@ async def test_merge_p2p_shuffle(c, s, a, b):
     lib.testing.assert_frame_equal(
         x.reset_index(drop=True), df_left.merge(df_right)[["b", "c"]]
     )
+
+
+def test_sort_values():
+    """Sorting on a ``LocalCluster`` matches the backend library's result."""
+    with LocalCluster() as cluster:
+        with Client(cluster) as client:  # noqa: F841
+            pdf = lib.DataFrame({"a": [5] + list(range(100)), "b": 2})
+            df = from_pandas(pdf, npartitions=10)
+
+            out = df.sort_values(by="a").compute()
+            lib.testing.assert_frame_equal(
+                out.reset_index(drop=True),
+                pdf.sort_values(by="a", ignore_index=True),
+            )