Merged
16 changes: 9 additions & 7 deletions dask_expr/_shuffle.py
@@ -907,19 +907,21 @@ class SortIndexBlockwise(Blockwise):
     _is_length_preserving = True


+def sort_function(self, *args, **kwargs):
+    sort_func = kwargs.pop("sort_function")
+    sort_kwargs = kwargs.pop("sort_kwargs")
+    return sort_func(*args, **kwargs, **sort_kwargs)
+
+
 class SortValuesBlockwise(Blockwise):
     _projection_passthrough = False
     _parameters = ["frame", "sort_function", "sort_kwargs"]
+    operation = sort_function
+    _keyword_only = ["sort_function", "sort_kwargs"]
     _is_length_preserving = True

-    def operation(self, *args, **kwargs):
-        sort_func = kwargs.pop("sort_function")
-        sort_kwargs = kwargs.pop("sort_kwargs")
-        return sort_func(*args, **kwargs, **sort_kwargs)
-
     @functools.cached_property
     def _meta(self):
         return self.frame._meta


 class SetIndexBlockwise(Blockwise):
     _parameters = ["frame", "other", "drop", "new_divisions"]
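Why hoist operation to module level? Per the review discussion below, the point is function serialisation: a module-level function pickles by reference (module plus qualified name), while a bound method has to drag its instance along. A minimal, generic illustration — the names module_level and Expr are invented for this sketch, not taken from dask-expr:

import pickle

def module_level(part):
    # A top-level function pickles by reference: only its module and
    # qualified name go into the payload.
    return part

class Expr:
    def operation(self, part):
        return part

assert pickle.loads(pickle.dumps(module_level)) is module_level

# A bound method pickles the whole instance plus the method name, so
# it round-trips only when the instance itself is picklable.
expr = Expr()
restored = pickle.loads(pickle.dumps(expr.operation))
assert restored(1) == 1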
14 changes: 14 additions & 0 deletions dask_expr/tests/test_distributed.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import pytest
+from distributed import Client, LocalCluster

 from dask_expr import from_pandas
 from dask_expr.tests._util import _backend_library
@@ -86,3 +87,16 @@ async def test_merge_p2p_shuffle(c, s, a, b):
     lib.testing.assert_frame_equal(
         x.reset_index(drop=True), df_left.merge(df_right)[["b", "c"]]
     )
+
+
+def test_sort_values():
+    with LocalCluster() as cluster:
Member:
Any way to leverage gen_cluster here?

Collaborator (Author):
No, it's async.

Member:
Can you explain more? Also, do we really want to allow a test to use all the available cores on our machine like this?

Member:
You maybe want these fixtures from distributed.utils_test:

@pytest.fixture
def cluster_fixture(loop):
    with cluster() as (scheduler, workers):
        yield (scheduler, workers)


@pytest.fixture
def s(cluster_fixture):
    scheduler, workers = cluster_fixture
    return scheduler


@pytest.fixture
def a(cluster_fixture):
    scheduler, workers = cluster_fixture
    return workers[0]


@pytest.fixture
def b(cluster_fixture):
    scheduler, workers = cluster_fixture
    return workers[1]


@pytest.fixture
def client(loop, cluster_fixture):
    scheduler, workers = cluster_fixture
    with Client(scheduler["address"], loop=loop) as client:
        yield client

They handle cleanup and various sanity checks.

Alternatively, if you want things to be very fast, you might also want with LocalCluster(processes=False).
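For illustration, a hedged sketch of that faster in-process variant — the n_workers and threads_per_worker values are arbitrary choices for this sketch, not from the PR:

from distributed import Client, LocalCluster

# processes=False runs the scheduler and workers as threads inside the
# test process, so startup is quick and the test does not fan out
# across every core on the machine.
with LocalCluster(processes=False, n_workers=2, threads_per_worker=1) as cluster:
    with Client(cluster) as client:
        assert client.submit(lambda x: x + 1, 1).result() == 2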

Collaborator (Author):
The gen_cluster decorator creates an async cluster, and async tests will fail for sort_values/set_index because of intermediate computes. That's not a dask-expr problem; dask/dask has the same issue.

Sure, we can restrict the number of workers.
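For context, a gen_cluster test looks roughly like the sketch below (the test name is invented): the body is a coroutine driven on the cluster's own event loop, which is why the blocking intermediate computes inside sort_values/set_index cannot run there.

from distributed.utils_test import gen_cluster

# gen_cluster spins up an in-loop scheduler and workers; with
# client=True the test receives (client, scheduler, worker, worker).
@gen_cluster(client=True)
async def test_async_example(c, s, a, b):
    # Everything here must be awaited; a synchronous .compute() would
    # block the very event loop the cluster is running on.
    future = c.submit(sum, [1, 2, 3])
    assert await future == 6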

Collaborator (Author):
I think I prefer the fast version; this is only supposed to test that the function serialisation works as expected.
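If picklability alone is the concern, a hypothetical direct check (not part of this PR) could assert the round-trip without starting a cluster at all:

import pickle

from dask_expr._shuffle import sort_function

# A module-level function deserializes back to the same object.
assert pickle.loads(pickle.dumps(sort_function)) is sort_function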

Member:
> Sure, we can restrict the number of workers

Sorry for mentioning that - I care less about that detail (for now).

Collaborator (Author):
See dask/distributed#8167. This one has the same issue in dask/dask.

+        with Client(cluster) as client:  # noqa: F841
+            pdf = lib.DataFrame({"a": [5] + list(range(100)), "b": 2})
+            df = from_pandas(pdf, npartitions=10)
+
+            out = df.sort_values(by="a").compute()
+            lib.testing.assert_frame_equal(
+                out.reset_index(drop=True),
+                pdf.sort_values(by="a", ignore_index=True),
+            )