Improve memory footprint of P2P rechunking #7897
Conversation
Unit Test Results
See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.
20 files ±0, 20 suites ±0, 13h 16m 46s ⏱️ +1h 48m 58s
For more details on these failures, see this check.
Results for commit b93e602. ± Comparison against base commit 74a1bcd.
This pull request removes 9 and adds 7 tests. Note that renamed tests count towards both.
wence-
left a comment
I think a bit of exposition on what is going on would be helpful
    from dask.array.rechunk import intersect_chunks

    ndim = len(old)
    intersections = intersect_chunks(old, new)
OK, so previously we would make the full cartesian product of intersections up front.
Yes, previously we calculated the mapping information from each individual input chunk to each individual output chunk. This results in the cartesian product of the chunks on the chunked axes.
    from dask.array.rechunk import old_to_new

    shard_indices = product(*(range(dim) for dim in sub_shape))
    _old_to_new = old_to_new(old, new)
Now we "just" do this mapping from old to new.
Now we calculate the mapping per axis, which we can piece together for each individual N-dimensional input at runtime. So the size of this mapping grows much more slowly.
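For reference, a minimal sketch of what this per-axis mapping looks like. `old_to_new` is an internal dask helper, so the exact return format sketched in the comments is an assumption and may differ between versions:

```python
# Illustrative only: the exact output format of old_to_new is an assumption.
from dask.array.rechunk import old_to_new

old = ((10, 10),)    # one axis with two old chunks of size 10
new = ((5, 5, 10),)  # rechunked into three new chunks

# Per axis, per new chunk: the (old chunk index, slice) pairs that are
# concatenated to build that new chunk, e.g. roughly
# [[[(0, slice(0, 5, None))], [(0, slice(5, 10, None))], [(1, slice(0, 10, None))]]]
print(old_to_new(old, new))
```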
distributed/shuffle/_rechunk.py (Outdated)

    for axis_id, new_axis in enumerate(_old_to_new):
        old_axis: SplitAxis = [[] for _ in old[axis_id]]
        for new_chunk_id, new_chunk in enumerate(new_axis):
OK, so this is O(n^2) which makes sense because it's a transpose-like operation.
distributed/shuffle/_rechunk.py (Outdated)

    for axis_id, new_axis in enumerate(_old_to_new):
        old_axis: SplitAxis = [[] for _ in old[axis_id]]
        for new_chunk_id, new_chunk in enumerate(new_axis):
            for new_subchunk_id, (old_chunk_id, slice) in enumerate(new_chunk):
And each new_chunk is O(1) long? Or can it be O(n) long?
Each new chunk is made up of one entry per old chunk it overlaps, so it is usually short but can be up to O(n) long in the worst case (e.g., when a single new chunk spans many old chunks).
    for new_index, new_chunk in zip(new_indices, intersections):
        sub_shape = [len({slice[dim][0] for slice in new_chunk}) for dim in range(ndim)]

    def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
        from dask.array.rechunk import old_to_new
Can you add a docstring please?
Done, does this help?
    continue
    old_chunk.sort(key=lambda subchunk: subchunk[2].start)
    axes.append(old_axis)
    return axes
So this data is just O(n_old * n_new) in size, rather than O(n_old^2 * n_new^2) if I understand how the previous slicing datastructure was created?
See the description for an explanation of the memory complexity.
Thanks.
    from itertools import product

    shards = product(
        *(axis[i] for axis, i in zip(self.split_axes, input_partition))
Is it obvious why the input partition is the same length as the split axes?
I think almost nothing about the rechunking code is obvious, so please raise questions wherever you're missing information. It took me way too long to understand what's going on in task-based rechunking.
split_axes contains the splits across all N dimensions/axes, and input_partition is an N-dimensional index.
I've renamed input_partition to partition_id in case that makes things clearer.
    rec_cat_arg[tuple(index)] = shard
    del data
    del file
    arrs = rec_cat_arg.tolist()
TODO: naively, it feels very wasteful to make objects of everything, convert to a nested list, and then concatenate. Why is it not possible to just allocate the right-sized output up front and insert directly?
We only know the true size of the output chunk when we have loaded all its shards thanks to the chunks of size np.nan along some dimension. We could do some additional math to determine the output size once we've finished reading the file. This would allow us to avoid rec_cat_arg and we could insert the elements of shards directly. I'm not sure how much of a performance problem that really is though. I'd rather leave performance optimization for a follow-up.
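To sketch the assembly step being discussed, here is a minimal illustration with invented shapes and variable names; the PR itself may use a different concatenation helper, so treat `np.block` as a stand-in:

```python
# Hedged sketch: place each shard at its N-dimensional position inside an
# object array, then concatenate the nested list into the final output chunk.
import numpy as np

subshape = (2, 2)  # number of shards per axis for this output chunk (invented)
rec_cat_arg = np.empty(subshape, dtype="O")
shards = {
    (0, 0): np.ones((3, 4)),
    (0, 1): np.ones((3, 1)),
    (1, 0): np.ones((2, 4)),
    (1, 1): np.ones((2, 1)),
}
for index, shard in shards.items():
    rec_cat_arg[index] = shard

# np.block concatenates the nested list of shards into the final (5, 5) chunk.
chunk = np.block(rec_cat_arg.tolist())
assert chunk.shape == (5, 5)
```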
I'll write up an in-depth explanation of the changes made here.
    (id, pickle.dumps((id.shard_index, data[nslice])))

    from itertools import product

    shards = product(
Since we only compute the splits per axis, we create the full cartesian product here again to determine each of the n-dimensional shards we have to slice from the input and send to the output.
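A minimal sketch of that recombination step, simplified and with invented values rather than the PR's actual code:

```python
# Hedged sketch: recombining the per-axis splits of one input partition into
# N-dimensional shards via the cartesian product.
from itertools import product

# Pretend each axis contributes (output_chunk_index, slice) pairs for this
# input partition; a 2-D example with two splits along axis 0, one along axis 1.
splits_axis0 = [(0, slice(0, 5)), (1, slice(5, 10))]
splits_axis1 = [(2, slice(0, 7))]

for shard in product(splits_axis0, splits_axis1):
    # shard has one entry per axis, e.g. ((0, slice(0, 5)), (2, slice(0, 7))):
    # the output partition it targets is (0, 2) and the N-D slice of the input
    # partition is (slice(0, 5), slice(0, 7)).
    print(shard)
```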
@wence-: I've added a sketch of the changes to the memory complexity to the description.
    pytest.param(
        da.ones(shape=(1000, 10), chunks=(5, 10)),
        (None, 5),
        marks=pytest.mark.skip(reason="distributed#7757"),
    ),
    pytest.param(
        da.ones(shape=(1000, 10), chunks=(5, 10)),
        {1: 5},
        marks=pytest.mark.skip(reason="distributed#7757"),
    ),
    pytest.param(
        da.ones(shape=(1000, 10), chunks=(5, 10)),
        (None, (5, 5)),
        marks=pytest.mark.skip(reason="distributed#7757"),
    ),
It looks like these test cases still fail on CI, I'm investigating.
Locally, these take ~7s, roughly double the time of task-based shuffling. Given the shape of the rechunk it's not surprising that task-based performs better here. I think we may want to leave these cases out or adjust their size.
Yes, if you have some ideas for clarifying the naming, I'd appreciate it. We also have chunks as partitions/n-dimensional arrays and the chunks per axis (which describe the sizes of the chunks along that axis). EDIT: Now partitions are partitions, and shards are sub-partitions created by slicing them. That is, a partition can be re-created by concatenating its shards.
    slicing = defaultdict(list)
    SplitChunk: TypeAlias = list[Split]
    SplitAxis: TypeAlias = list[SplitChunk]
    SplitAxes: TypeAlias = list[SplitAxis]
I don't have any better proposal right now but the difference between Axis and Axes doesn't help readability
SplitAxes only pops up once so I guess this is not a big deal
    (id, pickle.dumps((id.shard_index, data[nslice])))

    from itertools import product

    shards = product(
nit: I believe you are using splits and shards interchangeably. If so, I'd prefer sticking to one
Good catch, I only introduced splits very recently because of the chunks/chunks naming issue, so I may have missed some cases.
Does ndsplits help convey the concept that they're not truly 1-d splits nor actual shards of data?
    [Split(0, 0, slice(0, 20, None))],
    [Split(0, 1, slice(0, 20, None))],
    [Split(0, 2, slice(0, 18, None)), Split(1, 0, slice(18, 20, None))],
    [Split(1, 1, slice(0, 2, None)), Split(2, 0, slice(2, 20, None))],
    [Split(2, 1, slice(0, 2, None)), Split(3, 0, slice(2, 20, None))],
off-topic: something like {avg|max|median}(len(...)) might be interesting metrics for heuristics that try to distinguish P2P vs tasks for single stage rechunks (if there is actually a use case for this)
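A hedged sketch of that idea, purely illustrative and not part of this PR; `split_axes` is assumed to be the per-axis structure built above:

```python
# Summary statistics over how many splits each input chunk produces, as a
# possible input to a P2P-vs-tasks heuristic (illustrative sketch only).
from statistics import mean, median

def split_count_stats(split_axes):
    """split_axes: per axis, per old chunk, the list of splits it produces."""
    counts = [len(old_chunk) for axis in split_axes for old_chunk in axis]
    return max(counts), mean(counts), median(counts)
```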
I only skimmed the algorithm but from what I can tell everything's in order. I would appreciate a more thorough look from @wence-. Thanks for the thorough explanation! Regarding benchmarks, results look great! I'd be interested to see what happens with
I am out for the rest of the week and will look on Monday.
It finally runs with P2P. Its average (peak) memory is around 1/3 (1/2) of task-based shuffling, but it's still about 3x slower. I haven't done any profiling on this, so there could be any number of reasons for the runtime difference. https://github.com/coiled/benchmarks/actions/runs/5267999800
Thanks. We should look into this eventually, but that's out of scope for this PR.
wence-
left a comment
Thanks Hendrik, I have a few questions about whether or not there is sparsity to exploit in the chunk mappings.
    #: Index of the new output chunk to which this split belongs.
    chunk_index: int

    #: Index of the split within the list of splits that are concatenated
    #: to create the new chunk.
    split_index: int

    #: Slice of the input chunk.
    slice: slice
OK, so the idea is that you're treating everything by just looking at the 1-D "cuts" of the axes. So on an axis, each output chunk consists of some number of slices of input chunks.
How does this data structure record "sparseness" in the input chunks? AIUI, only a contiguous range of input chunks can correspond to a given output chunk, so most slices would be empty.
This is recorded implicitly in SplitChunk. The SplitChunk at position j contains all slices belonging to input chunk j. The SplitChunk at position j in SplitAxis contains a Split for output chunk i IFF input chunk j contributes a non-empty slice to i.
(There is some special casing going on if chunk i is of length 0, in that case there exists one Split of some input chunk. That split contains an empty slice.)
I can walk you through some of the test cases, those should illustrate the data structure adequately.
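For example, a hand-worked sketch based on the description above (not copied from the PR's tests; `Split` below stands in for the dataclass defined in this PR):

```python
from typing import NamedTuple

class Split(NamedTuple):
    chunk_index: int   # index of the new chunk along this axis
    split_index: int   # position within the list of splits concatenated into that chunk
    slice: slice       # slice of the old chunk along this axis

# Old chunks (10, 10) rechunked to (5, 5, 10) along a single axis:
split_axes = [
    [   # axis 0, indexed by old chunk
        # old chunk 0 is split between new chunks 0 and 1
        [Split(0, 0, slice(0, 5)), Split(1, 0, slice(5, 10))],
        # old chunk 1 becomes new chunk 2 in one piece
        [Split(2, 0, slice(0, 10))],
    ],
]
```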
    _old_to_new = old_to_new(old, new)

    axes = []
    for axis_index, new_axis in enumerate(_old_to_new):
So I think old_to_new produces "dense" output, in that every output chunk is hypothetically made up of some piece of every input chunk (so if there are N input chunks and M output chunks along an axis, this would be O(N * M))?
This only produces data for overlaps, so the data structure should be O(M), with M being the number of overlaps between input and output chunks.
Ah ok, so this is sparse.
    for axis_index, new_axis in enumerate(_old_to_new):
        old_axis: SplitAxis = [[] for _ in old[axis_index]]
        for new_chunk_index, new_chunk in enumerate(new_axis):
            for split_index, (old_chunk_index, slice) in enumerate(new_chunk):
If we don't want to change old_to_new we could prune the slices here with:
    if slice.start == slice.stop:
        continue
    old_axis[old_chunk_index].append(
        Split(new_chunk_index, split_index, slice)
    )
OK, so this is how we maintain the relation between the input and output chunks. This is a "scatter-like" data structure. Each old_chunk_index scatters some slice of itself to every new_chunk_index.
    continue
    old_chunk.sort(key=lambda subchunk: subchunk[2].start)
    axes.append(old_axis)
    return axes
Thanks.
    Split(new_chunk_index, split_index, slice)
    )
    for old_chunk in old_axis:
        old_chunk.sort(key=lambda split: split.slice.start)
Does old_to_new not produce slices in sorted order?
Good point, will have to check again.
    ndsplits = product(
        *(axis[i] for axis, i in zip(self.split_axes, partition_id))
    )
So here, for each nD input piece, we construct the set of pD output chunks it scatters to.
Again, I think IIUC that this is a dense representation, but it is actually mostly sparse (since most of the input boxes will not contribute to most of the output boxes).
FWIW, the dimensionality does not change, so both input and output are n-dimensional. See other comments wrt. sparseness.
    )

    for ndsplit in ndsplits:
        chunk_index, shard_index, ndslice = zip(*ndsplit)
So ndsplit is a list of Split tuples, and this transposes, so we get a tuple that identifies the output chunk, the "shard" (still not quite sure here), and an nD slice of the input data (corresponding to the input chunk).
But this seems to presume that the indexing of the output chunks is nD like the indexing of the input chunks. But I don't think that is true?
Why do you think it's not true?
Consider the case of a 2-dim array with size (10, 10) and a single input partition of size (10, 10). This input partition has the 2-dim index (0, 0). Any way of rechunking this to smaller pieces would also result in a 2-dim index.
I thought that rechunking also allowed reshaping, but that was wrong, I believe?
Indeed, rechunking does not allow reshaping. Admittedly, it looks like it's not documented anywhere but in some tests. We validate that chunks adhere to the shape here:
https://github.com/dask/dask/blob/499f4055da707fa76d06a2b79f408f124eee4723/dask/array/core.py#L3126-L3129
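To make the zip(*ndsplit) transpose discussed above concrete, a small illustration: the `Split` fields mirror this PR's dataclass, but the concrete values are invented.

```python
# Hedged illustration of the transpose step, not taken from the PR's code.
from typing import NamedTuple

class Split(NamedTuple):
    chunk_index: int
    split_index: int
    slice: slice

# One ndsplit: one Split per axis of a 2-D array.
ndsplit = (Split(0, 1, slice(0, 5)), Split(2, 0, slice(0, 7)))

chunk_index, shard_index, ndslice = zip(*ndsplit)
print(chunk_index)  # (0, 2)  -> N-D index of the target output partition
print(shard_index)  # (1, 0)  -> N-D position of this shard within that partition
print(ndslice)      # (slice(0, 5, None), slice(0, 7, None)) -> N-D slice of the input
```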
    for ndsplit in ndsplits:
        chunk_index, shard_index, ndslice = zip(*ndsplit)
        id = ArrayRechunkShardID(chunk_index, shard_index)
Again, I suspect pruning empty slices will be beneficial here:
    if any(s.start == s.stop for s in ndslice):
        continue
    Returns
    -------
    SplitAxes
        Splits along each axis that determine how to slice the input chunks to create
I think Splits here is a technical term (meaning the Split object)?
Terminology is hard:
By split, I mean a split object, which is essentially a slice with its index in the output. Do you have suggestions on how to make this clearer?
"list of :class:Splits ..." ? Which should linkify things?
We should explicitly mention the sparsity of the output here.
@wence-: Thanks for reviewing. Concerning your questions, a high-bandwidth conversation might help us sort these out. Let me know if you'd like to have one.
Thanks, I think I have... |
wence-
left a comment
Looks good. As discussed I'll open a followup PR adding some module-level documentation (which may result in internal renamings, but hopefully no substantive changes).
fjetter
left a comment
Thank you @hendrikmakait and @wence- !
Closes #7757
Preamble:
For the sake of this explanation, I will diverge from the common naming used in dask.array: An n-dimensional Dask array like the one illustrated in the image below is partitioned into many small n-dimensional NumPy arrays, each of which we will refer to as a partition. The partitions are created by chunking the array along each axis, where a single chunk determines the size of its corresponding partitions along this axis. In the image, we have 5 chunks along the horizontal axis, 4 chunks along the vertical axis, and 20 partitions.
With $P :=$ number of partitions of the array, $C_n :=$ number of chunks along axis $n$, and $N :=$ number of dimensions, it follows that

$$P = \prod_{n=1}^{N} C_n.$$
Content
The big change in this PR is the intermediate result we store on the shuffle run to determine how to slice the input partitions in order to create the output partitions from the resulting shards. With this PR, we no longer calculate, for each input partition, how to slice it into shards that belong to a given output partition. Instead, we only store, for each individual axis of the n-dimensional array, how we need to split the chunks along that axis.
With $P_{old} :=$ number of old partitions and $P_{new} :=$ number of new partitions, this means that with the previous algorithm we stored a result of size $\mathcal{O}(P_{old} \cdot P_{new})$.
With $C_{old, n} :=$ number of old chunks along axis $n$ and $C_{new, n} :=$ number of new chunks along axis $n$, it follows that

$$\mathcal{O}(P_{old} \cdot P_{new}) = \mathcal{O}\left(\prod_{n=1}^{N} C_{old, n} \cdot \prod_{n=1}^{N} C_{new, n}\right),$$

where $C_{old} := \sum_{n=1}^{N} C_{old, n}$ is the number of old chunks along all axes and $C_{new} := \sum_{n=1}^{N} C_{new, n}$ is the number of new chunks along all axes.
For the new approach, there are only $C_{old, n} + C_{new, n}$ splits per axis $n$. Hence, we only store a result of size

$$\mathcal{O}\left(\sum_{n=1}^{N} \left(C_{old, n} + C_{new, n}\right)\right) = \mathcal{O}(C_{old} + C_{new}).$$
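As a rough illustration (numbers invented for this example, not taken from any benchmark): rechunking a 2-D array from $100 \times 100$ chunks to $50 \times 50$ chunks gives $P_{old} = 10{,}000$ and $P_{new} = 2{,}500$, so the previous approach stores on the order of

$$P_{old} \cdot P_{new} = 25{,}000{,}000$$

mapping entries, whereas the new approach stores only

$$\sum_{n=1}^{2} \left(C_{old, n} + C_{new, n}\right) = (100 + 50) + (100 + 50) = 300$$

splits.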
pre-commit run --all-files

cc @wence-