Reduce memory footprint of P2P shuffling #8157
Conversation
distributed/shuffle/_arrow.py
Outdated
```python
with pa.OSFile(str(path), mode="rb") as f:
    size = f.seek(0, whence=2)
    f.seek(0)
    while f.tell() < size:
        sr = pa.RecordBatchStreamReader(f)
        shard = sr.read_all()
        arrs = [pa.concat_arrays(column.chunks) for column in shard.columns]
        shard = pa.table(data=arrs, schema=schema)
        shards.append(shard)
```
By interleaving disk reads and deserialization, we reduced the size of the individual buffers that get created.
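To make the effect concrete, here is a hedged sketch contrasting the two access patterns (the file name and the standalone setup are illustrative, not taken from this PR):

```python
import pyarrow as pa

# Before: read the whole file in one go, then deserialize. This produces
# a single allocation roughly the size of the entire file.
with pa.OSFile("shards.arrow", mode="rb") as f:
    data = f.read()

# After: interleave reads and deserialization, so each allocation is
# bounded by the size of one IPC stream rather than the whole file.
with pa.OSFile("shards.arrow", mode="rb") as f:
    size = f.seek(0, whence=2)
    f.seek(0)
    while f.tell() < size:
        table = pa.RecordBatchStreamReader(f).read_all()
```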
distributed/shuffle/_arrow.py
Outdated
```python
while f.tell() < size:
    sr = pa.RecordBatchStreamReader(f)
    shard = sr.read_all()
    arrs = [pa.concat_arrays(column.chunks) for column in shard.columns]
```
From what I understand, the RecordBatchStreamReader creates one buffer per record batch. On main, this is a problem when we convert the pa.Table consisting of all those batches into a pd.DataFrame: the conversion frees buffers on a per-column basis, which effectively means that none of the buffers from any record batch are freed until we have converted the last column. To avoid this, we force a copy for each column directly after reading it with pa.concat_arrays. This way, we (should) end up with one buffer per column per batch.
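A tiny self-contained illustration of the copy that pa.concat_arrays forces (assuming only that pyarrow is installed; the data is made up):

```python
import pyarrow as pa

# A column that arrived in three record batches is backed by three buffers.
chunked = pa.chunked_array([[1, 2], [3, 4], [5, 6]])
assert len(chunked.chunks) == 3

# Concatenating the chunks materializes one contiguous array with a fresh
# buffer, so the small per-batch buffers can be freed right away instead
# of lingering until the last column has been converted to pandas.
combined = pa.concat_arrays(chunked.chunks)
assert len(combined) == 6
```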
Similarly, pa.Table.combine_chunks proceeds on a per-column basis causing a spike in temporary memory usage (see #8128).
cc @phofl in case you have some thoughts on this
Unit Test Results

See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.

21 files ±0   21 suites ±0   10h 58m 4s ⏱️ +22m 54s

For more details on these failures, see this check.

Results for commit f23c1aa. ± Comparison against base commit e350c99.

♻️ This comment has been updated with latest results.
fjetter left a comment
Did a very rough review. So far LGTM, but I'll want to test-drive this before merging. Will come back ASAP.
```diff
-def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
+def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame:
```
(disclaimer: still in early review) I once tried to move tables around instead of bytes but that messed up the event loop. We should check this before merging
The increase of the minimal pyarrow version is something we have to do more carefully. Otherwise, this will silently cause the default shuffle method to fall back to tasks unless users upgrade their pyarrow version. The very least we should do is raise a warning if pyarrow is installed but the version is too low; safer would likely be to raise hard in this case. I doubt anybody would want to use pyarrow and shuffle a dataframe, but not use p2p because the version is too old.
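For illustration, a hedged sketch of what such a check could look like (the function and constant names are hypothetical, not the actual dask implementation):

```python
import warnings

from packaging.version import parse

MINVERSION = "12.0.0"  # assumed minimum; see the diff below


def pick_default_shuffle_method() -> str:
    try:
        import pyarrow as pa
    except ModuleNotFoundError:
        return "tasks"  # pyarrow absent: falling back silently is expected
    if parse(pa.__version__) < parse(MINVERSION):
        # pyarrow present but too old: warn loudly (or raise) instead of
        # silently falling back to task-based shuffling.
        warnings.warn(
            f"P2P shuffling requires pyarrow>={MINVERSION}, found "
            f"{pa.__version__}; falling back to task-based shuffling."
        )
        return "tasks"
    return "p2p"
```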
An A/B test would obviously be nice.
Running those today |
```diff
+    Raises a ModuleNotFoundError if pyarrow is not installed or an
+    ImportError if the installed version is not recent enough.
     """
-    # First version to introduce Table.sort_by
-    minversion = "7.0.0"
+    # First version that supports concatenating extension arrays (apache/arrow#14463)
+    minversion = "12.0.0"
     try:
         import pyarrow as pa
-    except ImportError:
-        raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}")
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}")
     if parse(pa.__version__) < parse(minversion):
-        raise RuntimeError(
+        raise ImportError(
```
@fjetter: Together with dask/dask#10496, get_default_shuffle_method should raise if pyarrow is outdated and choose tasks if it's not installed.
(testing it manually)
distributed/shuffle/_arrow.py
Outdated
```python
batch_size = parse_bytes("1 MiB")
batch = []
shards = []
schema = pa.Schema.from_pandas(meta, preserve_index=True)
```
Not using pyarrow_schema_dispatch here because it doesn't support preserve_index yet.
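A small example of the preserve_index behavior referred to here (assumes pandas and pyarrow; the frame is illustrative):

```python
import pandas as pd
import pyarrow as pa

meta = pd.DataFrame({"x": pd.Series(dtype="int64")})

# preserve_index=True forces the index into the schema as a real column,
# so round-tripping shards through Arrow keeps the index intact.
schema = pa.Schema.from_pandas(meta, preserve_index=True)
print(schema.names)  # ['x', '__index_level_0__']
```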
```python
def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
    import pyarrow as pa

    batch_size = parse_bytes("1 MiB")
```
This is fragile and I don't really like it, but for now it seems to do the job. We will have to spend more time on performance optimization and understanding memory (de)allocation here to make this more robust.
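As a rough illustration of the batching this refers to, a hedged sketch (the helper name is hypothetical; the 1 MiB threshold mirrors the constant above):

```python
import pyarrow as pa
from dask.utils import parse_bytes


def batch_tables(tables, batch_size=parse_bytes("1 MiB")):
    """Concatenate consecutive shards until each batch reaches ~batch_size."""
    batches, current, current_size = [], [], 0
    for table in tables:
        current.append(table)
        current_size += table.nbytes
        if current_size >= batch_size:
            # Trade a bounded amount of copying for fewer, larger allocations.
            batches.append(pa.concat_tables(current))
            current, current_size = [], 0
    if current:
        batches.append(pa.concat_tables(current))
    return batches
```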
A/B test results (https://github.com/coiled/benchmarks/actions/runs/6120630936): Runtime performance takes a minor hit on some tests, but average and peak memory improve significantly. I'm confident that we'll get runtime down again through further performance optimization and batching on the write side.
I think the hit is ok when looking at the memory improvement.
At the very least …
I'm currently running a CI test on my fork against the dask/dask sibling branch to verify this works as expected.
There may actually be a related failure: https://github.com/fjetter/distributed/actions/runs/6160365960/job/16717162228. I have to check whether #8110 is included in this (I guess it is).
Ok, I was able to track this CancelledError down somewhat... The important message is that this is not an actual computation deadlock. The above "async instruction was cancelled" message is expected if a worker closes while a task is being executed: the state machine task is cancelled, but the thread still keeps running unnoticed. The test actually reaches a … I think the CancelledError is actually a red herring, since it is triggered by the test hitting a timeout while the shuffle plugin is closing.
I strongly suspect that test failure is unrelated, but I will spend some more time trying to hunt this down... 🤞
Found the cause of why this was blocking; see #8184. This is an unrelated fix, and we should be able to proceed here.



Closes #8015
Supersedes #8128
Blocked by dask/dask#10493
pre-commit run --all-files