From 7cc67c26a98c3906fc0f81344a603fbc4472533f Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 17 Mar 2023 12:24:07 +0000 Subject: [PATCH 01/10] p2p: Add in-memory buffering for output DiskShardsBuffer It is likely that at least some part (possibly all) of the output of the shuffle will fit in memory. In this circumstance, we don't need to necessarily write output shards to disk only to read them in later. To enable this, provide (optional) in-memory buffering on the output DiskShardsBuffer. While the total output size is less than some limit, don't bother hitting the disk, but rather just hold in memory references. Once too much data has arrived, block to flush these buffers to disk. When reading, we now might have some shards in memory, so we concatenate these with the on-disk buffers. --- distributed/shuffle/_disk.py | 64 +++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 2b3dc37beed..0b4a2505734 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -3,6 +3,9 @@ import contextlib import pathlib import shutil +from collections import defaultdict + +from dask.utils import parse_bytes from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter @@ -38,6 +41,9 @@ class DiskShardsBuffer(ShardsBuffer): implementation of this scheme. """ + # TODO: Configure based on worker memory limits (possibly dynamically?) + max_buffer_size = parse_bytes("1 GiB") + def __init__( self, directory: str | pathlib.Path, @@ -50,6 +56,8 @@ def __init__( ) self.directory = pathlib.Path(directory) self.directory.mkdir(exist_ok=True) + self._in_memory = 0 + self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list) async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -65,38 +73,62 @@ async def _process(self, id: str, shards: list[bytes]) -> None: dropping the write into communicate above. """ + # Normalisation for safety + id = str(id) with log_errors(): # Consider boosting total_size a bit here to account for duplication with self.time("write"): - with open( - self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: - for shard in shards: - f.write(shard) + if not self.max_buffer_size: + # Fast path if we're always hitting the disk + self._write(id, shards) + else: + while shards: + if self._in_memory < self.max_buffer_size: + self._memory_buf[id].append(newdata := shards.pop()) + self._in_memory += len(newdata) + else: + # Flush old data + # This could be offloaded to a background + # task at the cost of going further over + # the soft memory limit. + for id, bufs in self._memory_buf.items(): + self._write(id, bufs) + self._memory_buf.clear() + self._in_memory = 0 + + def _write(self, id: str, shards: list[bytes]) -> None: + with open(self.directory / str(id), mode="ab", buffering=100_000_000) as f: + for s in shards: + f.write(s) + + async def close(self) -> None: + await super().close() + with contextlib.suppress(FileNotFoundError): + shutil.rmtree(self.directory) def read(self, id: int | str) -> bytes: - """Read a complete file back into memory""" + """Read a complete file back into memory, concatting with any + in memory parts""" self.raise_on_exception() if not self._inputs_done: raise RuntimeError("Tried to read from file before done.") + id = str(id) + data = self._memory_buf[id] try: with self.time("read"): with open( self.directory / str(id), mode="rb", buffering=100_000_000 ) as f: - data = f.read() - size = f.tell() + data.append(f.read()) except FileNotFoundError: - raise KeyError(id) + if not data: + # Neither disk nor in memory + raise KeyError(id) if data: - self.bytes_read += size - return data + buf = b"".join(data) + self.bytes_read += len(buf) + return buf else: raise KeyError(id) - - async def close(self) -> None: - await super().close() - with contextlib.suppress(FileNotFoundError): - shutil.rmtree(self.directory) From edf8fe5e1695f43190afe810a97febf0e627e1a8 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 20 Mar 2023 15:13:51 +0000 Subject: [PATCH 02/10] Don't overwrite of local variables Fixes somes shards going "missing". --- distributed/shuffle/_disk.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 0b4a2505734..b40ca231647 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -91,8 +91,8 @@ async def _process(self, id: str, shards: list[bytes]) -> None: # This could be offloaded to a background # task at the cost of going further over # the soft memory limit. - for id, bufs in self._memory_buf.items(): - self._write(id, bufs) + for k, v in self._memory_buf.items(): + self._write(k, v) self._memory_buf.clear() self._in_memory = 0 @@ -114,7 +114,7 @@ def read(self, id: int | str) -> bytes: raise RuntimeError("Tried to read from file before done.") id = str(id) - data = self._memory_buf[id] + data = self._memory_buf.pop(id, []) try: with self.time("read"): with open( From 188a0829045b4646a0ac3a5500a53c8548de76a1 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 20 Mar 2023 15:18:22 +0000 Subject: [PATCH 03/10] Use an in-memory buffer proportional to the worker memory limit When buffering disk output, don't hardcode the memory limit, but rather allow up to a quarter of the worker's memory to be used. --- distributed/shuffle/_disk.py | 17 ++++++++++------- distributed/shuffle/_worker_extension.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index b40ca231647..598448824bb 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -5,8 +5,6 @@ import shutil from collections import defaultdict -from dask.utils import parse_bytes - from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import log_errors @@ -39,15 +37,19 @@ class DiskShardsBuffer(ShardsBuffer): to be processed exceeds this limit, then the buffer will block until below the threshold. See :meth:`.write` for the implementation of this scheme. + max_in_memory_buffer_size : int, optional + Size of in-memory buffer to use before flushing to disk. If + configured, incoming shards will first be moved to memory + rather than immediately written to disk. This can provide for + speedups when an entire shuffle fits in memory. """ - # TODO: Configure based on worker memory limits (possibly dynamically?) - max_buffer_size = parse_bytes("1 GiB") - def __init__( self, directory: str | pathlib.Path, memory_limiter: ResourceLimiter | None = None, + *, + max_in_memory_buffer_size: int = 0, ): super().__init__( memory_limiter=memory_limiter, @@ -58,6 +60,7 @@ def __init__( self.directory.mkdir(exist_ok=True) self._in_memory = 0 self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list) + self.max_in_memory_buffer_size = max_in_memory_buffer_size async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -78,12 +81,12 @@ async def _process(self, id: str, shards: list[bytes]) -> None: with log_errors(): # Consider boosting total_size a bit here to account for duplication with self.time("write"): - if not self.max_buffer_size: + if not self.max_in_memory_buffer_size: # Fast path if we're always hitting the disk self._write(id, shards) else: while shards: - if self._in_memory < self.max_buffer_size: + if self._in_memory < self.max_in_memory_buffer_size: self._memory_buf[id].append(newdata := shards.pop()) self._in_memory += len(newdata) else: diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 751d69e1f57..84d0bcbaf4b 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -68,6 +68,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None, ): self.id = id self.run_id = run_id @@ -81,6 +82,10 @@ def __init__( self._disk_buffer = DiskShardsBuffer( directory=directory, memory_limiter=memory_limiter_disk, + # If not given, then zero corresponds to no in-memory buffering (eager writes to disk) + max_in_memory_buffer_size=worker_memory_limit // 4 + if worker_memory_limit + else 0, ) self._comm_buffer = CommShardsBuffer( @@ -281,6 +286,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None, ): from dask.array.rechunk import _old_to_new @@ -295,6 +301,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + worker_memory_limit=worker_memory_limit, ) self.old = old self.new = new @@ -436,6 +443,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None, ): import pandas as pd @@ -450,6 +458,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + worker_memory_limit=worker_memory_limit, ) self.column = column self.schema = schema @@ -818,6 +827,7 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, + worker_memory_limit=self.worker.memory_manager.memory_limit, ) elif result["type"] == ShuffleType.ARRAY_RECHUNK: shuffle = ArrayRechunkRun( @@ -837,6 +847,7 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, + worker_memory_limit=self.worker.memory_manager.memory_limit, ) else: # pragma: no cover raise TypeError(result["type"]) From e9e2c146d62eb1c913890a58de6b4f39141e6433 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 20 Mar 2023 16:15:42 +0000 Subject: [PATCH 04/10] worker_memory_limit is optional when building shuffles --- distributed/shuffle/_worker_extension.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 84d0bcbaf4b..2b44734e806 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -286,7 +286,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, - worker_memory_limit: int | None, + worker_memory_limit: int | None = None, ): from dask.array.rechunk import _old_to_new @@ -443,7 +443,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, - worker_memory_limit: int | None, + worker_memory_limit: int | None = None, ): import pandas as pd From 81286d09467a2ca4de422d0d737650254181d196 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 17:42:20 +0000 Subject: [PATCH 05/10] Config option for size of in-memory buffer for disk buffer --- distributed/distributed-schema.yaml | 12 ++++++++++++ distributed/distributed.yaml | 3 +++ distributed/shuffle/_worker_extension.py | 19 +++++++++++++++---- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 3b52ae7044c..1f40567848c 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -644,6 +644,18 @@ properties: Should be used for variables that must be set before process startup, interpreter startup, or imports. + shuffle: + type: object + description: | + Low-level settings for control of p2p shuffle + properties: + output_max_buffer_size: + type: [string, integer, 'null'] + description: | + Maximum size of the in-memory output buffer for p2p + shuffles before a worker writes output to disk. If + ``None`` then a default of one quarter of the worker's + total memory is used. client: type: object description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index a186b2d39e6..eda8face526 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -193,6 +193,9 @@ distributed: MKL_NUM_THREADS: 1 OPENBLAS_NUM_THREADS: 1 + shuffle: + output_max_buffer_size: null # Size of shuffle output memory buffer + client: heartbeat: 5s # Interval between client heartbeats scheduler-info-interval: 2s # Interval between scheduler-info updates diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 2b44734e806..7e1c6ff4e81 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -15,6 +15,7 @@ import toolz +from dask import config from dask.utils import parse_bytes from distributed.core import PooledRPCCall @@ -79,13 +80,23 @@ def __init__( self.scheduler = scheduler self.closed = False + buffer_size: int | str | None = config.get( + "distributed.shuffle.output_max_buffer_size" + ) + if buffer_size is None: + if worker_memory_limit is None: + # No configuration and no known worker memory limit + # Safe default is "no in-memory buffering" + buffer_size = 0 + else: + buffer_size = worker_memory_limit // 4 + else: + buffer_size = parse_bytes(buffer_size) + self._disk_buffer = DiskShardsBuffer( directory=directory, memory_limiter=memory_limiter_disk, - # If not given, then zero corresponds to no in-memory buffering (eager writes to disk) - max_in_memory_buffer_size=worker_memory_limit // 4 - if worker_memory_limit - else 0, + max_in_memory_buffer_size=buffer_size, ) self._comm_buffer = CommShardsBuffer( From 893db039c665b783db276c9c422a308d1887bc68 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 17:44:30 +0000 Subject: [PATCH 06/10] Move close back where it was --- distributed/shuffle/_disk.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 598448824bb..68045bd2263 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -104,11 +104,6 @@ def _write(self, id: str, shards: list[bytes]) -> None: for s in shards: f.write(s) - async def close(self) -> None: - await super().close() - with contextlib.suppress(FileNotFoundError): - shutil.rmtree(self.directory) - def read(self, id: int | str) -> bytes: """Read a complete file back into memory, concatting with any in memory parts""" @@ -135,3 +130,8 @@ def read(self, id: int | str) -> bytes: return buf else: raise KeyError(id) + + async def close(self) -> None: + await super().close() + with contextlib.suppress(FileNotFoundError): + shutil.rmtree(self.directory) From 9a408d92328d9cc73c729ca5079a0e97b442692a Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 17:55:39 +0000 Subject: [PATCH 07/10] Bad disk test runs with no in-memory buffering --- distributed/shuffle/tests/test_shuffle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 02604f06382..1b82bb27812 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -143,6 +143,7 @@ async def test_concurrent(c, s, a, b): @gen_cluster(client=True) async def test_bad_disk(c, s, a, b): + await c.run(dask.config.set, {"distributed.shuffle.output_max_buffer_size": 0}) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", From 4cdd62216a9f12c316678cce2d2b4121e55af96d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 18:06:14 +0000 Subject: [PATCH 08/10] Run some tests with in-memory buffering on and off --- distributed/shuffle/tests/test_shuffle.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 1b82bb27812..f8e2ade91d2 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -739,7 +739,19 @@ def __init__(self, value: int) -> None: @gen_cluster(client=True) -async def test_head(c, s, a, b): +@pytest.mark.parametrize( + "disk_buffer_size", + [ + 0, # No in-memory buffering + 128, # Small enough to hit disk + "1GiB", # Won't hit disk + ], +) +async def test_head(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) a_files = list(os.walk(a.local_directory)) b_files = list(os.walk(b.local_directory)) From 32a0844e4f0e4fe220e1123023668e0b39687821 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 18:14:26 +0000 Subject: [PATCH 09/10] No more str | int --- distributed/shuffle/_disk.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 68045bd2263..e9065f4ef31 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -76,8 +76,6 @@ async def _process(self, id: str, shards: list[bytes]) -> None: dropping the write into communicate above. """ - # Normalisation for safety - id = str(id) with log_errors(): # Consider boosting total_size a bit here to account for duplication with self.time("write"): @@ -104,14 +102,13 @@ def _write(self, id: str, shards: list[bytes]) -> None: for s in shards: f.write(s) - def read(self, id: int | str) -> bytes: + def read(self, id: str) -> bytes: """Read a complete file back into memory, concatting with any in memory parts""" self.raise_on_exception() if not self._inputs_done: raise RuntimeError("Tried to read from file before done.") - id = str(id) data = self._memory_buf.pop(id, []) try: with self.time("read"): From 2e287c09df9694d71385c6ac2981f0c520dc7ed6 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 22 Mar 2023 18:33:22 +0000 Subject: [PATCH 10/10] Run shuffle tests with parameterized disk buffer size --- distributed/shuffle/tests/test_shuffle.py | 188 ++++++++++++++++++---- 1 file changed, 153 insertions(+), 35 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index f8e2ade91d2..9a6d62b1771 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -74,6 +74,14 @@ async def clean_scheduler( assert not extension.heartbeats +@pytest.fixture( + params=["0B", "128B", "1GiB"], + ids=lambda x: f"{x} output buffer", +) +def disk_buffer_size(request): + return request.param + + @pytest.mark.skipif( pa is not None, reason="We don't have a CI job that is installing a very old pyarrow version", @@ -91,7 +99,11 @@ async def test_minimal_version(c, s, a, b): @gen_cluster(client=True) -async def test_basic_integration(c, s, a, b): +async def test_basic_integration(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -122,7 +134,11 @@ def test_raise_on_fuse_optimization(): @gen_cluster(client=True) -async def test_concurrent(c, s, a, b): +async def test_concurrent(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -143,6 +159,7 @@ async def test_concurrent(c, s, a, b): @gen_cluster(client=True) async def test_bad_disk(c, s, a, b): + # This only fails when there is no output memory buffering for the shuffle await c.run(dask.config.set, {"distributed.shuffle.output_max_buffer_size": 0}) df = dask.datasets.timeseries( start="2000-01-01", @@ -222,7 +239,11 @@ async def wait_until_new_shuffle_is_initialized( @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_worker_during_transfer(c, s, a, b): +async def test_closed_worker_during_transfer(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-03-01", @@ -245,7 +266,11 @@ async def test_closed_worker_during_transfer(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_during_transfer(c, s, a): +async def test_crashed_worker_during_transfer(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=1) as n: killed_worker_address = n.worker_address df = dask.datasets.timeseries( @@ -271,7 +296,12 @@ async def test_crashed_worker_during_transfer(c, s, a): # TODO: Deduplicate instead of failing: distributed#7324 @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_input_only_worker_during_transfer(c, s, a, b): +async def test_closed_input_only_worker_during_transfer(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) + def mock_get_worker_for_range_sharding( output_partition: int, workers: list[str], npartitions: int ) -> str: @@ -339,7 +369,13 @@ def mock_mock_get_worker_for_range_sharding( @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)] * 3) -async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): +async def test_closed_bystanding_worker_during_shuffle( + c, s, w1, w2, w3, disk_buffer_size +): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) with dask.annotate(workers=[w1.address, w2.address], allow_other_workers=False): df = dask.datasets.timeseries( start="2000-01-01", @@ -379,7 +415,11 @@ async def inputs_done(self) -> None: BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_worker_during_barrier(c, s, a, b): +async def test_closed_worker_during_barrier(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -423,7 +463,11 @@ async def test_closed_worker_during_barrier(c, s, a, b): BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_other_worker_during_barrier(c, s, a, b): +async def test_closed_other_worker_during_barrier(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -470,7 +514,11 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_other_worker_during_barrier(c, s, a): +async def test_crashed_other_worker_during_barrier(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=1) as n: df = dask.datasets.timeseries( start="2000-01-01", @@ -522,7 +570,11 @@ async def test_closed_worker_during_unpack(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_during_unpack(c, s, a): +async def test_crashed_worker_during_unpack(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=2) as n: killed_worker_address = n.worker_address df = dask.datasets.timeseries( @@ -546,7 +598,11 @@ async def test_crashed_worker_during_unpack(c, s, a): @gen_cluster(client=True) -async def test_heartbeat(c, s, a, b): +async def test_heartbeat(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) await a.heartbeat() await clean_scheduler(s) df = dask.datasets.timeseries( @@ -739,14 +795,6 @@ def __init__(self, value: int) -> None: @gen_cluster(client=True) -@pytest.mark.parametrize( - "disk_buffer_size", - [ - 0, # No in-memory buffering - 128, # Small enough to hit disk - "1GiB", # Won't hit disk - ], -) async def test_head(c, s, a, b, disk_buffer_size): await c.run( dask.config.set, @@ -783,7 +831,11 @@ def test_split_by_worker(): @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_clean_after_forgotten_early(c, s, a, b): +async def test_clean_after_forgotten_early(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-03-01", @@ -801,7 +853,11 @@ async def test_clean_after_forgotten_early(c, s, a, b): @gen_cluster(client=True) -async def test_tail(c, s, a, b): +async def test_tail(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -826,13 +882,19 @@ async def test_tail(c, s, a, b): @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True) -async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): +async def test_repeat_shuffle_instance( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): """Tests repeating the same instance of a shuffle-based task graph. See Also -------- test_repeat_shuffle_operation """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -855,7 +917,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True) -async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): +async def test_repeat_shuffle_operation( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): """Tests repeating the same shuffle operation using two distinct instances of the task graph. @@ -863,6 +927,10 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): -------- test_repeat_shuffle_instance """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -883,7 +951,11 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_after_shuffle(c, s, a): +async def test_crashed_worker_after_shuffle(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) in_event = Event() block_event = Event() @@ -944,7 +1016,11 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a): @gen_cluster(client=True, nthreads=[("", 1)] * 3) -async def test_closed_worker_between_repeats(c, s, w1, w2, w3): +async def test_closed_worker_between_repeats(c, s, w1, w2, w3, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -974,7 +1050,11 @@ async def test_closed_worker_between_repeats(c, s, w1, w2, w3): @gen_cluster(client=True) -async def test_new_worker(c, s, a, b): +async def test_new_worker(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-20", @@ -997,7 +1077,11 @@ async def test_new_worker(c, s, a, b): @gen_cluster(client=True) -async def test_multi(c, s, a, b): +async def test_multi(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) left = dask.datasets.timeseries( start="2000-01-01", end="2000-01-20", @@ -1023,7 +1107,11 @@ async def test_multi(c, s, a, b): @gen_cluster(client=True) -async def test_restrictions(c, s, a, b): +async def test_restrictions(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1047,7 +1135,11 @@ async def test_restrictions(c, s, a, b): @gen_cluster(client=True) -async def test_delete_some_results(c, s, a, b): +async def test_delete_some_results(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1068,7 +1160,11 @@ async def test_delete_some_results(c, s, a, b): @gen_cluster(client=True) -async def test_add_some_results(c, s, a, b): +async def test_add_some_results(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1094,7 +1190,11 @@ async def test_add_some_results(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_clean_after_close(c, s, a, b): +async def test_clean_after_close(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2001-01-01", @@ -1427,7 +1527,13 @@ async def shuffle_receive(self, *args: Any, **kwargs: Any) -> None: {"shuffle": BlockedShuffleReceiveShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): +async def test_deduplicate_stale_transfer( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1481,7 +1587,11 @@ async def _barrier(self, *args: Any, **kwargs: Any) -> int: {"shuffle": BlockedBarrierShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): +async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1527,7 +1637,7 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): {"shuffle": BlockedBarrierShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_shuffle_run_consistency(c, s, a): +async def test_shuffle_run_consistency(c, s, a, disk_buffer_size): """This test checks the correct creation of shuffle run IDs through the scheduler as well as the correct handling through the workers. @@ -1538,6 +1648,10 @@ async def test_shuffle_run_consistency(c, s, a): The P2P implementation relies on the correctness of this behavior, but it is an implementation detail that users should not rely upon. """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) worker_ext = a.extensions["shuffle"] scheduler_ext = s.extensions["shuffle"] @@ -1629,7 +1743,11 @@ def shuffle_fail(self, *args: Any, **kwargs: Any) -> None: {"shuffle": BlockedShuffleAccessAndFailWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_replace_stale_shuffle(c, s, a, b): +async def test_replace_stale_shuffle(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) ext_A = a.extensions["shuffle"] ext_B = b.extensions["shuffle"]