Skip to content
31 changes: 30 additions & 1 deletion distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import math
from collections import defaultdict
from itertools import product
from itertools import compress, product
from typing import TYPE_CHECKING, NamedTuple

import dask
Expand Down Expand Up @@ -72,6 +73,34 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array:
# Special case for empty array, as the algorithm below does not behave correctly
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)

old_chunks = x.chunks
new_chunks = chunks

def is_unknown(dim: ChunkedAxis) -> bool:
return any(math.isnan(chunk) for chunk in dim)

old_is_unknown = [is_unknown(dim) for dim in old_chunks]
new_is_unknown = [is_unknown(dim) for dim in new_chunks]

if old_is_unknown != new_is_unknown or any(
new != old for new, old in compress(zip(old_chunks, new_chunks), old_is_unknown)
):
raise ValueError(
"Chunks must be unchanging along dimensions with missing values.\n\n"
"A possible solution:\n x.compute_chunk_sizes()"
)

old_known = [dim for dim, unknown in zip(old_chunks, old_is_unknown) if not unknown]
new_known = [dim for dim, unknown in zip(new_chunks, new_is_unknown) if not unknown]

old_sizes = [sum(o) for o in old_known]
new_sizes = [sum(n) for n in new_known]

if old_sizes != new_sizes:
raise ValueError(
f"Cannot change dimensions from {old_sizes!r} to {new_sizes!r}"
)

Comment on lines +76 to +103
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is currently copied from dask.array.rechunk.old_to_new. This should be cleaned up into a follow-up PR that moves validation logic into a helper.

dsk: dict = {}
token = tokenize(x, chunks)
_barrier_key = barrier_key(ShuffleId(token))
Expand Down
12 changes: 10 additions & 2 deletions distributed/shuffle/_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __init__(
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
):
from dask.array.rechunk import _old_to_new
from dask.array.rechunk import old_to_new

super().__init__(
id=id,
Expand All @@ -315,6 +315,14 @@ def __init__(
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
)
from dask.array.core import normalize_chunks

# We rely on a canonical `np.nan` in `dask.array.rechunk.old_to_new`
# that passes an implicit identity check when testing for list equality.
# This does not work with (de)serialization, so we have to normalize the chunks
# here again to canonicalize `nan`s.
old = normalize_chunks(old)
new = normalize_chunks(new)
self.old = old
self.new = new
partitions_of = defaultdict(list)
Expand All @@ -323,7 +331,7 @@ def __init__(
self.partitions_of = dict(partitions_of)
self.worker_for = worker_for
self._slicing = rechunk_slicing(old, new)
self._old_to_new = _old_to_new(old, new)
self._old_to_new = old_to_new(old, new)

async def _receive(self, data: list[tuple[ArrayRechunkShardID, bytes]]) -> None:
self.raise_if_closed()
Expand Down
170 changes: 141 additions & 29 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import math
import random
import warnings

Expand Down Expand Up @@ -83,7 +84,9 @@ async def test_lowlevel_rechunk(

ind_chunks = [[(i, x) for i, x in enumerate(dim)] for dim in old]
ind_chunks = [list(zip(x, y)) for x, y in product(*ind_chunks)]
old_chunks = {idx: np.random.random(chunk) for idx, chunk in ind_chunks}
old_chunks = {
idx: np.random.default_rng().random(chunk) for idx, chunk in ind_chunks
}

workers = list("abcdefghijklmn")[:n_workers]

Expand Down Expand Up @@ -161,7 +164,7 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.uniform(0, 1, 30)
a = np.random.default_rng().uniform(0, 1, 30)
x = da.from_array(a, chunks=((10,) * 3,))
new = ((6,) * 5,)
config = {"array.rechunk.method": config_value} if config_value is not None else {}
Expand All @@ -185,7 +188,7 @@ async def test_rechunk_2d(c, s, *ws):
--------
dask.array.tests.test_rechunk.test_rechunk_2d
"""
a = np.random.uniform(0, 1, 300).reshape((10, 30))
a = np.random.default_rng().uniform(0, 1, 300).reshape((10, 30))
x = da.from_array(a, chunks=((1, 2, 3, 4), (5,) * 6))
new = ((5, 5), (15,) * 2)
x2 = rechunk(x, chunks=new, method="p2p")
Expand All @@ -202,7 +205,7 @@ async def test_rechunk_4d(c, s, *ws):
dask.array.tests.test_rechunk.test_rechunk_4d
"""
old = ((5, 5),) * 4
a = np.random.uniform(0, 1, 10000).reshape((10,) * 4)
a = np.random.default_rng().uniform(0, 1, 10000).reshape((10,) * 4)
x = da.from_array(a, chunks=old)
new = (
(10,),
Expand All @@ -225,7 +228,7 @@ async def test_rechunk_with_single_output_chunk_raises(c, s, *ws):
dask.array.tests.test_rechunk.test_rechunk_4d
"""
old = ((5, 5),) * 4
a = np.random.uniform(0, 1, 10000).reshape((10,) * 4)
a = np.random.default_rng().uniform(0, 1, 10000).reshape((10,) * 4)
x = da.from_array(a, chunks=old)
new = ((10,),) * 4
x2 = rechunk(x, chunks=new, method="p2p")
Expand All @@ -244,7 +247,7 @@ async def test_rechunk_expand(c, s, *ws):
--------
dask.array.tests.test_rechunk.test_rechunk_expand
"""
a = np.random.uniform(0, 1, 100).reshape((10, 10))
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(5, 5))
y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
assert np.all(await c.compute(y) == a)
Expand All @@ -258,7 +261,7 @@ async def test_rechunk_expand2(c, s, *ws):
dask.array.tests.test_rechunk.test_rechunk_expand2
"""
(a, b) = (3, 2)
orig = np.random.uniform(0, 1, a**b).reshape((a,) * b)
orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
for off, off2 in product(range(1, a - 1), range(1, a - 1)):
old = ((a - off, off),) * b
x = da.from_array(orig, chunks=old)
Expand All @@ -280,7 +283,7 @@ async def test_rechunk_method(c, s, *ws):
"""
old = ((5, 2, 3),) * 4
new = ((3, 3, 3, 1),) * 4
a = np.random.uniform(0, 1, 10000).reshape((10,) * 4)
a = np.random.default_rng().uniform(0, 1, 10000).reshape((10,) * 4)
x = da.from_array(a, chunks=old)
x2 = x.rechunk(chunks=new, method="p2p")
assert x2.chunks == new
Expand All @@ -298,7 +301,7 @@ async def test_rechunk_blockshape(c, s, *ws):
new_shape, new_chunks = (10, 10), (4, 3)
new_blockdims = normalize_chunks(new_chunks, new_shape)
old_chunks = ((4, 4, 2), (3, 3, 3, 1))
a = np.random.uniform(0, 1, 100).reshape((10, 10))
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=old_chunks)
check1 = rechunk(x, chunks=new_chunks, method="p2p")
assert check1.chunks == new_blockdims
Expand Down Expand Up @@ -463,6 +466,56 @@ async def test_rechunk_same(c, s, *ws):
assert x is y


@gen_cluster(client=True)
async def test_rechunk_same_fully_unknown(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_same_fully_unknown
"""
dd = pytest.importorskip("dask.dataframe")
x = da.ones(shape=(10, 10), chunks=(5, 10))
y = dd.from_array(x).values
new_chunks = ((np.nan, np.nan), (10,))
assert y.chunks == new_chunks
result = y.rechunk(new_chunks, method="p2p")
assert y is result


@gen_cluster(client=True)
async def test_rechunk_same_fully_unknown_floats(c, s, *ws):
"""Similar to test_rechunk_same_fully_unknown but testing the behavior if
``float("nan")`` is used instead of the recommended ``np.nan``

See Also
--------
dask.array.tests.test_rechunk.test_rechunk_same_fully_unknown_floats
"""
dd = pytest.importorskip("dask.dataframe")
x = da.ones(shape=(10, 10), chunks=(5, 10))
y = dd.from_array(x).values
new_chunks = ((float("nan"), float("nan")), (10,))
result = y.rechunk(new_chunks, method="p2p")
assert y is result


@gen_cluster(client=True)
async def test_rechunk_same_partially_unknown(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_same_partially_unknown
"""
dd = pytest.importorskip("dask.dataframe")
x = da.ones(shape=(10, 10), chunks=(5, 10))
y = dd.from_array(x).values
z = da.concatenate([x, y])
new_chunks = ((5, 5, np.nan, np.nan), (10,))
assert z.chunks == new_chunks
result = z.rechunk(new_chunks, method="p2p")
assert z is result


@gen_cluster(client=True)
async def test_rechunk_with_zero_placeholders(c, s, *ws):
"""
Expand Down Expand Up @@ -513,7 +566,7 @@ async def test_rechunk_unknown_from_pandas(c, s, *ws):
dd = pytest.importorskip("dask.dataframe")
pd = pytest.importorskip("pandas")

arr = np.random.randn(50, 10)
arr = np.random.default_rng().standard_normal((50, 10))
x = dd.from_pandas(pd.DataFrame(arr), 2).values
result = x.rechunk((None, (5, 5)), method="p2p")
assert np.isnan(x.chunks[0]).all()
Expand Down Expand Up @@ -559,11 +612,11 @@ async def test_rechunk_unknown_from_array(c, s, *ws):
],
)
@gen_cluster(client=True)
async def test_rechunk_unknown(c, s, *ws, x, chunks):
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_unknown
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
Expand All @@ -574,28 +627,79 @@ async def test_rechunk_unknown(c, s, *ws, x, chunks):
assert_eq(await c.compute(result), await c.compute(expected))


@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
(None, 5),
marks=pytest.mark.skip(reason="distributed#7757"),
),
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
{1: 5},
marks=pytest.mark.skip(reason="distributed#7757"),
),
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
(None, (5, 5)),
marks=pytest.mark.skip(reason="distributed#7757"),
),
Comment on lines +636 to +650
Copy link
Member Author

@hendrikmakait hendrikmakait May 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These parametrizations are currently failing on CI due to horrible performance. My suspicion is that the performance problems are related to (but not necessarily exclusively caused by) #7757

(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_partially_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_partially_unknown_dimension
"""
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
z = da.concatenate([x, y])
xx = da.concatenate([x, x])
result = z.rechunk(chunks, method="p2p")
expected = xx.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
assert_eq(await c.compute(result), await c.compute(expected))


@pytest.mark.parametrize(
"new_chunks",
[
((np.nan, np.nan), (5, 5)),
((math.nan, math.nan), (5, 5)),
((float("nan"), float("nan")), (5, 5)),
],
)
@gen_cluster(client=True)
async def test_rechunk_unknown_explicit(c, s, *ws):
async def test_rechunk_with_fully_unknown_dimension_explicit(c, s, *ws, new_chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_unknown_explicit
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension_explicit
"""
dd = pytest.importorskip("dask.dataframe")
x = da.ones(shape=(10, 10), chunks=(5, 2))
y = dd.from_array(x).values
result = y.rechunk(((float("nan"), float("nan")), (5, 5)), method="p2p")
result = y.rechunk(new_chunks, method="p2p")
expected = x.rechunk((None, (5, 5)), method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
assert_eq(await c.compute(result), await c.compute(expected))


def assert_chunks_match(left, right):
for x, y in zip(left, right):
if np.isnan(x).any():
assert np.isnan(x).all()
else:
assert x == y
for ldim, rdim in zip(left, right):
assert all(np.isnan(l) or l == r for l, r in zip(ldim, rdim))


@gen_cluster(client=True)
Expand All @@ -607,9 +711,17 @@ async def test_rechunk_unknown_raises(c, s, *ws):
"""
dd = pytest.importorskip("dask.dataframe")

x = dd.from_array(da.ones(shape=(10, 10), chunks=(5, 5))).values
with pytest.raises(ValueError):
x.rechunk((None, (5, 5, 5)), method="p2p")
x = da.ones(shape=(10, 10), chunks=(5, 5))
y = dd.from_array(x).values
with pytest.raises(ValueError, match="Chunks do not add"):
y.rechunk((None, (5, 5, 5)), method="p2p")

with pytest.raises(ValueError, match="Chunks must be unchanging"):
y.rechunk(((5, 5), (5, 5)), method="p2p")

with pytest.raises(ValueError, match="Chunks must be unchanging"):
z = da.concatenate([x, y])
z.rechunk(((5, 3, 2, np.nan, np.nan), (5, 5)), method="p2p")


@gen_cluster(client=True)
Expand Down Expand Up @@ -880,8 +992,8 @@ def test_rechunk_slicing_nan():
--------
dask.array.tests.test_rechunk.test_intersect_nan
"""
old_chunks = ((float("nan"), float("nan")), (8,))
new_chunks = ((float("nan"), float("nan")), (4, 4))
old_chunks = ((np.nan, np.nan), (8,))
new_chunks = ((np.nan, np.nan), (4, 4))
result = rechunk_slicing(old_chunks, new_chunks)
expected = {
(0, 0): [
Expand All @@ -908,8 +1020,8 @@ def test_rechunk_slicing_nan_single():
--------
dask.array.tests.test_rechunk.test_intersect_nan_single
"""
old_chunks = ((float("nan"),), (10,))
new_chunks = ((float("nan"),), (5, 5))
old_chunks = ((np.nan,), (10,))
new_chunks = ((np.nan,), (5, 5))

result = rechunk_slicing(old_chunks, new_chunks)
expected = {
Expand All @@ -927,8 +1039,8 @@ def test_rechunk_slicing_nan_long():
--------
dask.array.tests.test_rechunk.test_intersect_nan_long
"""
old_chunks = (tuple([float("nan")] * 4), (10,))
new_chunks = (tuple([float("nan")] * 4), (5, 5))
old_chunks = (tuple([np.nan] * 4), (10,))
new_chunks = (tuple([np.nan] * 4), (5, 5))
result = rechunk_slicing(old_chunks, new_chunks)
expected = {
(0, 0): [
Expand Down