
Conversation

@hendrikmakait
Member

@hendrikmakait hendrikmakait commented Jun 9, 2023

Closes #7757

Preamble:

For the sake of this explanation, I will diverge from the common naming used in dask.array:

An n-dimensional Dask array like the one illustrated in the image below

[image: n-dimensional Dask array partitioned into a 4 × 5 grid of chunks]

is partitioned into many small n-dimensional NumPy arrays, each of which we will refer to as a partition. The partitions are created by chunking the array along each axis; each chunk determines the size of the corresponding partitions along that axis. In the image, we have 5 chunks along the horizontal axis, 4 chunks along the vertical axis, and 20 partitions.

With $P :=$ number of partitions of the array, and $C_n :=$ number of chunks along axis n, and $N :=$ number of dimensions, it follows that

$$P = \prod_{n=1}^N C_n$$
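As a quick sanity check of this relation (a minimal sketch; the 20 × 25 shape and (5, 5) chunking are made-up numbers matching the 4 × 5 grid in the image above):

import math

import dask.array as da

# A 20 x 25 array cut into 5 x 5 blocks: 4 chunks along axis 0, 5 along axis 1.
x = da.ones((20, 25), chunks=(5, 5))

chunks_per_axis = x.numblocks  # (4, 5), i.e. the C_n from the formula
assert x.npartitions == math.prod(chunks_per_axis) == 20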

Content

The big change in this PR is the intermediate result we store on the shuffle run to determine how to slice the input partitions so that the output partitions can be created from the resulting shards. With this PR, we no longer calculate, for each individual input partition, how to slice it into the shards that belong to each output partition. Instead, we only store, for each individual axis of the n-dimensional array, how to split the chunks along that axis.

With $P_{old} :=$ number of old partitions and $P_{new} :=$ number of new partitions, this means that the previous algorithm stored a result of size $\mathcal{O}(P_{old} * P_{new})$.

With $C_{old, n} :=$ number of old chunks along axis n, $C_{new, n} :=$ number of new chunks along axis n,

it follows that

$$ P_{old} * P_{new} = \prod_{n=1}^N (C_{old, n} * C_{new, n}) \in \mathcal{O}\left((C_{old} * C_{new})^N\right)$$

where $C_{old} :=$ number of old chunks along all axes, $C_{new} :=$ number of new chunks along all axes

For the new approach, there are only $C_{old, n} + C_{new, n}$ splits per axis n. Hence, we only store a result of size

$$ \sum_{n=1}^N (C_{old, n} + C_{new, n}) = C_{old} + C_{new} \in \mathcal{O}(C_{old} + C_{new}) \subseteq \mathcal{O}(P_{old} + P_{new})$$
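To make the difference concrete, a back-of-the-envelope comparison for a hypothetical 3-dimensional rechunk with 100 old and 100 new chunks per axis (made-up numbers; both sizes are worst-case upper bounds):

import math

n_dims = 3
c_old = [100] * n_dims  # old chunks per axis
c_new = [100] * n_dims  # new chunks per axis

p_old = math.prod(c_old)  # 1_000_000 input partitions
p_new = math.prod(c_new)  # 1_000_000 output partitions

old_result_size = p_old * p_new                             # ~1e12 entries
new_result_size = sum(o + n for o, n in zip(c_old, c_new))  # 600 entries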

  • Tests added / passed
  • Passes pre-commit run --all-files

cc @wence-

@github-actions
Contributor

github-actions bot commented Jun 9, 2023

Unit Test Results

See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.

     20 files  ±0        20 suites  ±0    13h 16m 46s ⏱️ +1h 48m 58s
  3 677 tests  -2      3 567 ✔️ +1       103 💤 -5        7 +2
 35 743 runs  +163    34 066 ✔️ +255   1 655 💤 -109     22 +17

For more details on these failures, see this check.

Results for commit b93e602. ± Comparison against base commit 74a1bcd.

This pull request removes 9 and adds 7 tests. Note that renamed tests count towards both.
distributed.protocol.tests.test_numpy
distributed.shuffle.tests.test_rechunk
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_1
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_2
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_chunks_with_nonzero
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_chunks_with_zero
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_nan
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_nan_long
distributed.shuffle.tests.test_rechunk ‑ test_rechunk_slicing_nan_single
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_1
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_2
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_nan
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_nan_long
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_nan_single
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_with_nonzero
distributed.shuffle.tests.test_rechunk ‑ test_split_axes_with_zero

♻️ This comment has been updated with latest results.

Contributor

@wence- wence- left a comment

I think a bit of exposition on what is going on would be helpful

from dask.array.rechunk import intersect_chunks

ndim = len(old)
intersections = intersect_chunks(old, new)
Contributor

OK, so previously we would make the full cartesian product of intersections up front.

Member Author

Yes, previously we calculated the mapping information from each individual input chunk to each individual output chunk. This results in the cartesian product of the chunks across the chunked axes.

from dask.array.rechunk import old_to_new

shard_indices = product(*(range(dim) for dim in sub_shape))
_old_to_new = old_to_new(old, new)
Contributor

Now we "just" do this mapping from old to new.

Member Author

Now we calculate the mapping per axis, which we can piece together for each individual N-dimensional input at runtime. So the stored intermediate grows much more slowly.
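For illustration, here is roughly what old_to_new returns for a tiny 1-D rechunk (this assumes my reading of dask's code is correct: one list per axis, with one entry per new chunk listing (old_chunk_index, slice) pairs):

from dask.array.rechunk import old_to_new

old = ((4, 4),)  # one axis, two old chunks of size 4
new = ((2, 6),)  # same axis, rechunked into sizes 2 and 6

print(old_to_new(old, new))
# Roughly: [[[(0, slice(0, 2, None))],
#            [(0, slice(2, 4, None)), (1, slice(0, 4, None))]]]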

Comment on lines 169 to 171
for axis_id, new_axis in enumerate(_old_to_new):
old_axis: SplitAxis = [[] for _ in old[axis_id]]
for new_chunk_id, new_chunk in enumerate(new_axis):
Contributor

OK, so this is O(n^2) which makes sense because it's a transpose-like operation.

for axis_id, new_axis in enumerate(_old_to_new):
old_axis: SplitAxis = [[] for _ in old[axis_id]]
for new_chunk_id, new_chunk in enumerate(new_axis):
for new_subchunk_id, (old_chunk_id, slice) in enumerate(new_chunk):
Contributor

And each new_chunk is O(1) long? Or can it be O(n) long?

Member Author

Each new chunk is $O(C_{old, n})$, that is, we could have a new chunk along axis n that contains all old chunks along that axis.
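A degenerate example of that worst case (again assuming the old_to_new output format sketched above):

from dask.array.rechunk import old_to_new

# A single new chunk spans all three old chunks along the axis, so its entry
# lists one slice per old chunk.
print(old_to_new(((2, 2, 2),), ((6,),)))
# Roughly: [[[(0, slice(0, 2, None)), (1, slice(0, 2, None)), (2, slice(0, 2, None))]]]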

for new_index, new_chunk in zip(new_indices, intersections):
sub_shape = [len({slice[dim][0] for slice in new_chunk}) for dim in range(ndim)]
def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
from dask.array.rechunk import old_to_new
Contributor

Can you add a docstring please?

Member Author

Done, does this help?

continue
old_chunk.sort(key=lambda subchunk: subchunk[2].start)
axes.append(old_axis)
return axes
Contributor

So this data is just O(n_old * n_new) in size, rather than O(n_old^2 * n_new^2), if I understand how the previous slicing data structure was created?

Member Author

See the description for an explanation of the memory complexity.

Contributor

Thanks.

from itertools import product

shards = product(
*(axis[i] for axis, i in zip(self.split_axes, input_partition))
Contributor

Is it obvious why the input partition is the same length as the split axes?

Member Author

I think almost nothing about the rechunking code is obvious, so please raise wherever you're missing information. It took me way too long to understand what's going on in task-based rechunking.

split_axes contains the splits across all N dimensions/axes, and input_partition is an N-dimensional index.

Member Author

I've renamed input_partition to partition_id in case that makes things clearer.

rec_cat_arg[tuple(index)] = shard
del data
del file
arrs = rec_cat_arg.tolist()
Contributor

TODO: it feels naively like it is very wasteful to make objects of everything, convert to a nested list and then concatenate. Why is it not possible to just allocate the right size of thing up front and insert directly?

Member Author

We only know the true size of the output chunk once we have loaded all of its shards, because chunk sizes may be np.nan along some dimensions. We could do some additional math to determine the output size once we've finished reading the file. This would allow us to avoid rec_cat_arg, and we could insert the elements of shards directly. I'm not sure how much of a performance problem this really is, though; I'd rather leave performance optimization for a follow-up.
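For context, a rough sketch of the assembly step under discussion (the real code builds a NumPy object array, calls .tolist(), and concatenates via dask's helper; np.block and the shard shapes below are only an illustration of the same idea):

import numpy as np

# Hypothetical output chunk assembled from a 2 x 2 grid of received shards.
shards = {
    (0, 0): np.zeros((3, 4)), (0, 1): np.zeros((3, 2)),
    (1, 0): np.zeros((5, 4)), (1, 1): np.zeros((5, 2)),
}

# Nested-list layout mirroring rec_cat_arg.tolist(): the innermost lists are
# concatenated along the last axis, then outward.
nested = [[shards[0, 0], shards[0, 1]],
          [shards[1, 0], shards[1, 1]]]

out = np.block(nested)
assert out.shape == (8, 6)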

@hendrikmakait
Member Author

I think a bit of exposition on what is going on would be helpful

I'll write up an in-depth explanation of the changes made here.

(id, pickle.dumps((id.shard_index, data[nslice])))
from itertools import product

shards = product(
Member Author

Since we only compute the splits per axis, we create the full cartesian product here again to determine each of the n-dimensional shards we have to slice from the input and send to the output.
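A self-contained sketch of that step (Split here is a stand-in NamedTuple mirroring the dataclass added in this PR, and the split values are made up):

from itertools import product
from typing import NamedTuple

class Split(NamedTuple):  # stand-in for the Split dataclass from this PR
    chunk_index: int      # index of the new chunk along this axis
    split_index: int      # position of this piece within the new chunk
    slice: slice          # slice of the old chunk along this axis

# Hypothetical splits of old chunk 0 along axis 0 and of old chunk 1 along axis 1.
axis0_splits = [Split(0, 0, slice(0, 3)), Split(1, 0, slice(3, 5))]
axis1_splits = [Split(2, 1, slice(0, 4))]

# For the input partition with index (0, 1), the cartesian product enumerates
# every n-dimensional shard it contributes: here 2 * 1 = 2 shards.
ndsplits = list(product(axis0_splits, axis1_splits))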

@hendrikmakait
Member Author

@wence-: I've added a sketch of the changes to the memory complexity to the description.

Comment on lines 636 to 647
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
(None, 5),
marks=pytest.mark.skip(reason="distributed#7757"),
),
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
{1: 5},
marks=pytest.mark.skip(reason="distributed#7757"),
),
pytest.param(
da.ones(shape=(1000, 10), chunks=(5, 10)),
(None, (5, 5)),
marks=pytest.mark.skip(reason="distributed#7757"),
),
Member Author

It looks like these test cases still fail on CI, I'm investigating.

Member Author

Locally, these take ~7s, roughly double the time of task-based shuffling. Given the shape of the rechunk it's not surprising that task-based performs better here. I think we may want to leave these cases out or adjust their size.

@hendrikmakait
Member Author

hendrikmakait commented Jun 14, 2023

chunks are partitions and shards are the sliced partitions we are submitting?

Yes, if you have ideas for clarifying the naming, I'd appreciate it. We also have chunks as partitions/n-dimensional arrays and the chunks per axis (which describe the sizes of the chunks along that axis).

EDIT: Now partitions are partitions and shards are sub-partitions of partitions created by slicing them. That is, a partition can be re-created by concatenating its shards.

@hendrikmakait
Member Author

I've run an A/B test and the results show significant improvements in average and peak memory:

[plots: average_memory, peak_memory]

as well as mild performance improvements on runtime for 2 tests:

[plot: runtime]

slicing = defaultdict(list)
SplitChunk: TypeAlias = list[Split]
SplitAxis: TypeAlias = list[SplitChunk]
SplitAxes: TypeAlias = list[SplitAxis]
Member

I don't have any better proposal right now but the difference between Axis and Axes doesn't help readability

Member

SplitAxes only pops up once so I guess this is not a big deal

(id, pickle.dumps((id.shard_index, data[nslice])))
from itertools import product

shards = product(
Member

nit: I believe you are using splits and shards interchangeably. If so, I'd prefer sticking to one

Member Author

Good catch, I only introduced splits very recently (in response to the chunks/chunks naming issue), so I may have missed some cases.

Member Author

Does ndsplits help convey the concept that they're not truly 1-d splits nor actual shards of data?

Comment on lines +968 to +972
[Split(0, 0, slice(0, 20, None))],
[Split(0, 1, slice(0, 20, None))],
[Split(0, 2, slice(0, 18, None)), Split(1, 0, slice(18, 20, None))],
[Split(1, 1, slice(0, 2, None)), Split(2, 0, slice(2, 20, None))],
[Split(2, 1, slice(0, 2, None)), Split(3, 0, slice(2, 20, None))],
Member

off-topic: something like {avg|max|median}(len(...)) might be interesting metrics for heuristics that try to distinguish P2P vs tasks for single stage rechunks (if there is actually a use case for this)

@hendrikmakait hendrikmakait requested a review from wence- June 14, 2023 11:39
@hendrikmakait hendrikmakait marked this pull request as ready for review June 14, 2023 11:39
@fjetter
Member

fjetter commented Jun 14, 2023

I only skimmed the algorithm but from what I can tell everything's in order. I would appreciate a more thorough look from @wence-

Thanks for the thorough explanation!

Regarding benchmarks, results look great! I'd be interested to see what happens with test_rechunk_out_of_memory now. This is out of scope for this PR but likely worth looking at briefly.

@hendrikmakait hendrikmakait added the needs review Needs review from a contributor. label Jun 14, 2023
@wence-
Contributor

wence- commented Jun 14, 2023

I would appreciate a more thorough look from @wence-

I am out for the rest of the week and will look on Monday.

@hendrikmakait
Member Author

hendrikmakait commented Jun 14, 2023

I'd be interested to see what happens with test_rechunk_out_of_memory now. This is out of scope for this PR but likely worth looking at briefly.

It finally runs with P2P. Its average (peak) memory is around 1/3 (1/2) of task-based shuffling, but it's still about 3x slower. I haven't done any profiling on this, so there's a myriad of reasons for the runtime difference.

https://github.com/coiled/benchmarks/actions/runs/5267999800

@fjetter
Member

fjetter commented Jun 15, 2023

It finally runs with P2P. Its average (peak) memory is around 1/3 (1/2) of task-based shuffling, but it's still about 3x slower. I haven't done any profiling on this, so there's a myriad of reasons for the runtime difference.

Thanks. We should look into this eventually but that's out of scope for this

Contributor

@wence- wence- left a comment

Thanks Hendrik, I have a few questions about whether or not there is sparsity to exploit in the chunk mappings.

Comment on lines +136 to +144
#: Index of the new output chunk to which this split belongs.
chunk_index: int

#: Index of the split within the list of splits that are concatenated
#: to create the new chunk.
split_index: int

#: Slice of the input chunk.
slice: slice
Contributor

OK, so the idea is that you're treating everything by just looking at the 1-D "cuts" of the axes.

So on an axis, each output chunk consists of some number of slices of input chunks.

An output chunk $C_{new,i}$ consists of $\bigoplus_{j=0}^{N} C_{old,j}[S^i_j]$ where $\bigoplus$ is concatenation, $N$ is the number of input chunks, and $S^i_j$ is the slice of the $j$-th input chunk that corresponds to the $i$-th output chunk.

How does this data structure record "sparseness" in the input chunks? AIUI, it is only a contiguous range of input chunks that can correspond to an output chunk, so most slices $S^i_j$ will be empty. We record the slice, but not which input chunk we're taking from.

Member Author

This is recorded implicitly in SplitChunk. The SplitChunk at position j contains all slices belonging to input chunk j. The SplitChunk at position j in SplitAxis contains a Split for output chunk i IFF input chunk j contributes a non-empty slice to i.

(There is some special casing going on if chunk i is of length 0, in that case there exists one Split of some input chunk. That split contains an empty slice.)

Member Author

I can walk you through some of the test cases, those should illustrate the data structure adequately.

_old_to_new = old_to_new(old, new)

axes = []
for axis_index, new_axis in enumerate(_old_to_new):
Contributor

So I think old_to_new produces "dense" output, in that every output chunk is hypothetically made up of some piece of every input chunk (so if there are $N$ input chunks and $P$ output chunks this is $N*P$ big). However, most of the time, the input piece is empty (indicated by an empty slice), so this data structure only needs to be $O(P)$ large I think.

Member Author

This only produces data for overlaps, so the data structure should be $\mathcal{O}(N + P) = \mathcal{O}(M)$ with $M$ being the number of overlaps between input and output chunks.

Contributor

Ah ok, so this is sparse.

for axis_index, new_axis in enumerate(_old_to_new):
old_axis: SplitAxis = [[] for _ in old[axis_index]]
for new_chunk_index, new_chunk in enumerate(new_axis):
for split_index, (old_chunk_index, slice) in enumerate(new_chunk):
Contributor

If we don't want to change old_to_new we could prune the slices here with:

if slice.start == slice.stop:
    continue

Comment on lines +177 to +179
old_axis[old_chunk_index].append(
Split(new_chunk_index, split_index, slice)
)
Contributor

OK, so this is how we maintain the relation between the input and output chunks. This is a "scatter-like" data structure. Each old_chunk_index scatters some slice of itself to every new_chunk_index.
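A condensed sketch of that scatter/transpose construction (assuming the old_to_new output format discussed above; the real split_axes in this PR additionally handles NaN-sized and zero-length chunks):

from dask.array.rechunk import old_to_new

def split_axes_sketch(old, new):
    """Transpose old_to_new's per-new-chunk view into a per-old-chunk view."""
    axes = []
    for axis_index, new_axis in enumerate(old_to_new(old, new)):
        old_axis = [[] for _ in old[axis_index]]  # one bucket per old chunk
        for new_chunk_index, new_chunk in enumerate(new_axis):
            for split_index, (old_chunk_index, slc) in enumerate(new_chunk):
                old_axis[old_chunk_index].append((new_chunk_index, split_index, slc))
        for bucket in old_axis:
            bucket.sort(key=lambda split: split[2].start)  # order pieces by position
        axes.append(old_axis)
    return axes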

continue
old_chunk.sort(key=lambda subchunk: subchunk[2].start)
axes.append(old_axis)
return axes
Contributor

Thanks.

Split(new_chunk_index, split_index, slice)
)
for old_chunk in old_axis:
old_chunk.sort(key=lambda split: split.slice.start)
Contributor

Does old_to_new not produce slices in sorted order?

Member Author

Good point, will have to check again.

Comment on lines +364 to +366
ndsplits = product(
*(axis[i] for axis, i in zip(self.split_axes, partition_id))
)
Contributor

So here, for each nD input piece, we construct the set of pD output chunks it scatters to.

Contributor

Again, I think IIUC that this is a dense representation, but it is actually mostly sparse (since most of the input boxes will not contribute to most of the output boxes).

Member Author

FWIW, the dimensionality does not change, so both input and output are n-dimensional. See other comments wrt. sparseness.

)

for ndsplit in ndsplits:
chunk_index, shard_index, ndslice = zip(*ndsplit)
Contributor

So ndsplit is a list of Split tuples, and this transposes, so we get a tuple that identifies the output chunk, the "shard" (still not quite sure here), and an nD slice of the input data (corresponding to the input chunk).

But this seems to presume that the indexing of the output chunks is nD like the indexing of the input chunks. But I don't think that is true?

Member Author

Why do you think it's not true?

Consider the case of a 2-dim array with size (10, 10) and a single input partition of size (10, 10). This input partition has the 2-dim index (0, 0). Any way of rechunking this to smaller pieces would also result in a 2-dim index.
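A quick check of that (the shapes are arbitrary):

import dask.array as da

x = da.ones((10, 10), chunks=(10, 10))
assert x.numblocks == (1, 1)  # a single 2-dim input partition with index (0, 0)

y = x.rechunk((5, 5))
assert y.numblocks == (2, 2)  # output partitions are still indexed in 2 dimensions
assert x.ndim == y.ndim == 2  # rechunking preserves dimensionality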

Contributor

I thought that rechunking also allowed reshaping, but that was wrong, I believe?

Member Author

Indeed, rechunking does not allow reshaping. Admittedly, it looks like it's not documented anywhere but in some tests. We validate that chunks adhere to the shape here:
https://github.com/dask/dask/blob/499f4055da707fa76d06a2b79f408f124eee4723/dask/array/core.py#L3126-L3129


for ndsplit in ndsplits:
chunk_index, shard_index, ndslice = zip(*ndsplit)
id = ArrayRechunkShardID(chunk_index, shard_index)
Contributor

Again, I suspect pruning empty slices will be beneficial here:

if any(s.start == s.stop for s in ndslice):
    continue

Returns
-------
SplitAxes
Splits along each axis that determine how to slice the input chunks to create
Contributor

I think Splits here is a technical term (meaning the Split object)?

Member Author

Terminology is hard:

By split, I mean a split object, which is essentially a slice with its index in the output. Do you have suggestions on how to make this clearer?

Contributor

"list of :class:Splits ..." ? Which should linkify things?

Member Author

We should explicitly mention the sparsity of the output here.

@hendrikmakait
Member Author

@wence-: Thanks for reviewing. Concerning your questions, a high-bandwidth conversation might help us sort these out. Let me know if you'd like to have one.

@wence-
Contributor

wence- commented Jun 19, 2023

@wence-: Thanks for reviewing. Concerning your questions, a high-bandwidth conversation might help us sort these out. Let me know if you'd like to have one.

Thanks, I think I have...

Contributor

@wence- wence- left a comment

Looks good. As discussed I'll open a followup PR adding some module-level documentation (which may result in internal renamings, but hopefully no substantive changes).

Member

@fjetter fjetter left a comment

Thank you @hendrikmakait and @wence- !


Labels

needs review Needs review from a contributor. shuffle


Development

Successfully merging this pull request may close these issues.

Calculate slicing of P2P rechunking on the fly
