diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 3b52ae7044c..1f40567848c 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -644,6 +644,18 @@ properties: Should be used for variables that must be set before process startup, interpreter startup, or imports. + shuffle: + type: object + description: | + Low-level settings for control of p2p shuffle + properties: + output_max_buffer_size: + type: [string, integer, 'null'] + description: | + Maximum size of the in-memory output buffer for p2p + shuffles before a worker writes output to disk. If + ``None`` then a default of one quarter of the worker's + total memory is used. client: type: object description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index a186b2d39e6..eda8face526 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -193,6 +193,9 @@ distributed: MKL_NUM_THREADS: 1 OPENBLAS_NUM_THREADS: 1 + shuffle: + output_max_buffer_size: null # Size of shuffle output memory buffer + client: heartbeat: 5s # Interval between client heartbeats scheduler-info-interval: 2s # Interval between scheduler-info updates diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 2b3dc37beed..e9065f4ef31 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -3,6 +3,7 @@ import contextlib import pathlib import shutil +from collections import defaultdict from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter @@ -36,12 +37,19 @@ class DiskShardsBuffer(ShardsBuffer): to be processed exceeds this limit, then the buffer will block until below the threshold. See :meth:`.write` for the implementation of this scheme. + max_in_memory_buffer_size : int, optional + Size of in-memory buffer to use before flushing to disk. If + configured, incoming shards will first be moved to memory + rather than immediately written to disk. This can provide for + speedups when an entire shuffle fits in memory. """ def __init__( self, directory: str | pathlib.Path, memory_limiter: ResourceLimiter | None = None, + *, + max_in_memory_buffer_size: int = 0, ): super().__init__( memory_limiter=memory_limiter, @@ -50,6 +58,9 @@ def __init__( ) self.directory = pathlib.Path(directory) self.directory.mkdir(exist_ok=True) + self._in_memory = 0 + self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list) + self.max_in_memory_buffer_size = max_in_memory_buffer_size async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -68,31 +79,52 @@ async def _process(self, id: str, shards: list[bytes]) -> None: with log_errors(): # Consider boosting total_size a bit here to account for duplication with self.time("write"): - with open( - self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: - for shard in shards: - f.write(shard) - - def read(self, id: int | str) -> bytes: - """Read a complete file back into memory""" + if not self.max_in_memory_buffer_size: + # Fast path if we're always hitting the disk + self._write(id, shards) + else: + while shards: + if self._in_memory < self.max_in_memory_buffer_size: + self._memory_buf[id].append(newdata := shards.pop()) + self._in_memory += len(newdata) + else: + # Flush old data + # This could be offloaded to a background + # task at the cost of going further over + # the soft memory limit. 
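+                            # Flushing writes every buffered id to disk (not
+                            # just the id currently being processed) and
+                            # resets the counter so the remaining shards can
+                            # again be buffered in memory.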
+ for k, v in self._memory_buf.items(): + self._write(k, v) + self._memory_buf.clear() + self._in_memory = 0 + + def _write(self, id: str, shards: list[bytes]) -> None: + with open(self.directory / str(id), mode="ab", buffering=100_000_000) as f: + for s in shards: + f.write(s) + + def read(self, id: str) -> bytes: + """Read a complete file back into memory, concatting with any + in memory parts""" self.raise_on_exception() if not self._inputs_done: raise RuntimeError("Tried to read from file before done.") + data = self._memory_buf.pop(id, []) try: with self.time("read"): with open( self.directory / str(id), mode="rb", buffering=100_000_000 ) as f: - data = f.read() - size = f.tell() + data.append(f.read()) except FileNotFoundError: - raise KeyError(id) + if not data: + # Neither disk nor in memory + raise KeyError(id) if data: - self.bytes_read += size - return data + buf = b"".join(data) + self.bytes_read += len(buf) + return buf else: raise KeyError(id) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 751d69e1f57..7e1c6ff4e81 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -15,6 +15,7 @@ import toolz +from dask import config from dask.utils import parse_bytes from distributed.core import PooledRPCCall @@ -68,6 +69,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None, ): self.id = id self.run_id = run_id @@ -78,9 +80,23 @@ def __init__( self.scheduler = scheduler self.closed = False + buffer_size: int | str | None = config.get( + "distributed.shuffle.output_max_buffer_size" + ) + if buffer_size is None: + if worker_memory_limit is None: + # No configuration and no known worker memory limit + # Safe default is "no in-memory buffering" + buffer_size = 0 + else: + buffer_size = worker_memory_limit // 4 + else: + buffer_size = parse_bytes(buffer_size) + self._disk_buffer = DiskShardsBuffer( directory=directory, memory_limiter=memory_limiter_disk, + max_in_memory_buffer_size=buffer_size, ) self._comm_buffer = CommShardsBuffer( @@ -281,6 +297,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None = None, ): from dask.array.rechunk import _old_to_new @@ -295,6 +312,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + worker_memory_limit=worker_memory_limit, ) self.old = old self.new = new @@ -436,6 +454,7 @@ def __init__( scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, memory_limiter_comms: ResourceLimiter, + worker_memory_limit: int | None = None, ): import pandas as pd @@ -450,6 +469,7 @@ def __init__( scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, + worker_memory_limit=worker_memory_limit, ) self.column = column self.schema = schema @@ -818,6 +838,7 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, + worker_memory_limit=self.worker.memory_manager.memory_limit, ) elif result["type"] == ShuffleType.ARRAY_RECHUNK: shuffle = ArrayRechunkRun( @@ -837,6 +858,7 @@ async def _( scheduler=self.worker.scheduler, memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, + 
worker_memory_limit=self.worker.memory_manager.memory_limit, ) else: # pragma: no cover raise TypeError(result["type"]) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 02604f06382..9a6d62b1771 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -74,6 +74,14 @@ async def clean_scheduler( assert not extension.heartbeats +@pytest.fixture( + params=["0B", "128B", "1GiB"], + ids=lambda x: f"{x} output buffer", +) +def disk_buffer_size(request): + return request.param + + @pytest.mark.skipif( pa is not None, reason="We don't have a CI job that is installing a very old pyarrow version", @@ -91,7 +99,11 @@ async def test_minimal_version(c, s, a, b): @gen_cluster(client=True) -async def test_basic_integration(c, s, a, b): +async def test_basic_integration(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -122,7 +134,11 @@ def test_raise_on_fuse_optimization(): @gen_cluster(client=True) -async def test_concurrent(c, s, a, b): +async def test_concurrent(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -143,6 +159,8 @@ async def test_concurrent(c, s, a, b): @gen_cluster(client=True) async def test_bad_disk(c, s, a, b): + # This only fails when there is no output memory buffering for the shuffle + await c.run(dask.config.set, {"distributed.shuffle.output_max_buffer_size": 0}) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -221,7 +239,11 @@ async def wait_until_new_shuffle_is_initialized( @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_worker_during_transfer(c, s, a, b): +async def test_closed_worker_during_transfer(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-03-01", @@ -244,7 +266,11 @@ async def test_closed_worker_during_transfer(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_during_transfer(c, s, a): +async def test_crashed_worker_during_transfer(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=1) as n: killed_worker_address = n.worker_address df = dask.datasets.timeseries( @@ -270,7 +296,12 @@ async def test_crashed_worker_during_transfer(c, s, a): # TODO: Deduplicate instead of failing: distributed#7324 @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_input_only_worker_during_transfer(c, s, a, b): +async def test_closed_input_only_worker_during_transfer(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) + def mock_get_worker_for_range_sharding( output_partition: int, workers: list[str], npartitions: int ) -> str: @@ -338,7 +369,13 @@ def mock_mock_get_worker_for_range_sharding( @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)] * 3) -async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): +async def test_closed_bystanding_worker_during_shuffle( + c, s, w1, w2, 
w3, disk_buffer_size +): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) with dask.annotate(workers=[w1.address, w2.address], allow_other_workers=False): df = dask.datasets.timeseries( start="2000-01-01", @@ -378,7 +415,11 @@ async def inputs_done(self) -> None: BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_worker_during_barrier(c, s, a, b): +async def test_closed_worker_during_barrier(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -422,7 +463,11 @@ async def test_closed_worker_during_barrier(c, s, a, b): BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_closed_other_worker_during_barrier(c, s, a, b): +async def test_closed_other_worker_during_barrier(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -469,7 +514,11 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): BlockedInputsDoneShuffle, ) @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_other_worker_during_barrier(c, s, a): +async def test_crashed_other_worker_during_barrier(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=1) as n: df = dask.datasets.timeseries( start="2000-01-01", @@ -521,7 +570,11 @@ async def test_closed_worker_during_unpack(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_during_unpack(c, s, a): +async def test_crashed_worker_during_unpack(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) async with Nanny(s.address, nthreads=2) as n: killed_worker_address = n.worker_address df = dask.datasets.timeseries( @@ -545,7 +598,11 @@ async def test_crashed_worker_during_unpack(c, s, a): @gen_cluster(client=True) -async def test_heartbeat(c, s, a, b): +async def test_heartbeat(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) await a.heartbeat() await clean_scheduler(s) df = dask.datasets.timeseries( @@ -738,7 +795,11 @@ def __init__(self, value: int) -> None: @gen_cluster(client=True) -async def test_head(c, s, a, b): +async def test_head(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) a_files = list(os.walk(a.local_directory)) b_files = list(os.walk(b.local_directory)) @@ -770,7 +831,11 @@ def test_split_by_worker(): @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_clean_after_forgotten_early(c, s, a, b): +async def test_clean_after_forgotten_early(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-03-01", @@ -788,7 +853,11 @@ async def test_clean_after_forgotten_early(c, s, a, b): @gen_cluster(client=True) -async def test_tail(c, s, a, b): +async def test_tail(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, 
+ {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -813,13 +882,19 @@ async def test_tail(c, s, a, b): @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True) -async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): +async def test_repeat_shuffle_instance( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): """Tests repeating the same instance of a shuffle-based task graph. See Also -------- test_repeat_shuffle_operation """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -842,7 +917,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @gen_cluster(client=True) -async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): +async def test_repeat_shuffle_operation( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): """Tests repeating the same shuffle operation using two distinct instances of the task graph. @@ -850,6 +927,10 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): -------- test_repeat_shuffle_instance """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -870,7 +951,11 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_crashed_worker_after_shuffle(c, s, a): +async def test_crashed_worker_after_shuffle(c, s, a, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) in_event = Event() block_event = Event() @@ -931,7 +1016,11 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a): @gen_cluster(client=True, nthreads=[("", 1)] * 3) -async def test_closed_worker_between_repeats(c, s, w1, w2, w3): +async def test_closed_worker_between_repeats(c, s, w1, w2, w3, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -961,7 +1050,11 @@ async def test_closed_worker_between_repeats(c, s, w1, w2, w3): @gen_cluster(client=True) -async def test_new_worker(c, s, a, b): +async def test_new_worker(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-20", @@ -984,7 +1077,11 @@ async def test_new_worker(c, s, a, b): @gen_cluster(client=True) -async def test_multi(c, s, a, b): +async def test_multi(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) left = dask.datasets.timeseries( start="2000-01-01", end="2000-01-20", @@ -1010,7 +1107,11 @@ async def test_multi(c, s, a, b): @gen_cluster(client=True) -async def test_restrictions(c, s, a, b): +async def test_restrictions(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1034,7 +1135,11 @@ async def test_restrictions(c, s, a, 
b): @gen_cluster(client=True) -async def test_delete_some_results(c, s, a, b): +async def test_delete_some_results(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1055,7 +1160,11 @@ async def test_delete_some_results(c, s, a, b): @gen_cluster(client=True) -async def test_add_some_results(c, s, a, b): +async def test_add_some_results(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1081,7 +1190,11 @@ async def test_add_some_results(c, s, a, b): @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_clean_after_close(c, s, a, b): +async def test_clean_after_close(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2001-01-01", @@ -1414,7 +1527,13 @@ async def shuffle_receive(self, *args: Any, **kwargs: Any) -> None: {"shuffle": BlockedShuffleReceiveShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): +async def test_deduplicate_stale_transfer( + c, s, a, b, wait_until_forgotten, disk_buffer_size +): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1468,7 +1587,11 @@ async def _barrier(self, *args: Any, **kwargs: Any) -> int: {"shuffle": BlockedBarrierShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): +async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1514,7 +1637,7 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): {"shuffle": BlockedBarrierShuffleWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_shuffle_run_consistency(c, s, a): +async def test_shuffle_run_consistency(c, s, a, disk_buffer_size): """This test checks the correct creation of shuffle run IDs through the scheduler as well as the correct handling through the workers. @@ -1525,6 +1648,10 @@ async def test_shuffle_run_consistency(c, s, a): The P2P implementation relies on the correctness of this behavior, but it is an implementation detail that users should not rely upon. """ + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) worker_ext = a.extensions["shuffle"] scheduler_ext = s.extensions["shuffle"] @@ -1616,7 +1743,11 @@ def shuffle_fail(self, *args: Any, **kwargs: Any) -> None: {"shuffle": BlockedShuffleAccessAndFailWorkerExtension}, ) @gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_replace_stale_shuffle(c, s, a, b): +async def test_replace_stale_shuffle(c, s, a, b, disk_buffer_size): + await c.run( + dask.config.set, + {"distributed.shuffle.output_max_buffer_size": disk_buffer_size}, + ) ext_A = a.extensions["shuffle"] ext_B = b.extensions["shuffle"]
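
Not part of the patch: a minimal sketch of how the new
``distributed.shuffle.output_max_buffer_size`` option might be exercised from
user code. The option is read on each worker when a shuffle run is created, so
it must be present in the workers' configuration, either at worker startup or
pushed at runtime with ``client.run`` as the tests above do. The scheduler
address and the ``"2 GiB"`` value below are placeholders. If the option is left
at ``null``, the buffer defaults to one quarter of the worker's memory limit,
or to 0 (no in-memory buffering) when the worker has no memory limit.

import dask
import dask.datasets
from distributed import Client

client = Client("tcp://scheduler:8786")  # placeholder address

# Push the setting into every worker's config; a value of "0B" disables the
# in-memory buffer so shuffle output is always written straight to disk.
client.run(dask.config.set, {"distributed.shuffle.output_max_buffer_size": "2 GiB"})

# Any P2P shuffle started after this point picks up the new buffer size.
df = dask.datasets.timeseries(start="2000-01-01", end="2000-01-10")
shuffled = df.shuffle("x", shuffle="p2p")
result = shuffled.persist()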