12 changes: 12 additions & 0 deletions distributed/distributed-schema.yaml
@@ -644,6 +644,18 @@ properties:
Should be used for variables that must be set before
process startup, interpreter startup, or imports.

shuffle:
type: object
description: |
Low-level settings for controlling p2p shuffles
properties:
output_max_buffer_size:
type: [string, integer, 'null']
description: |
Maximum size of the in-memory output buffer for p2p
shuffles before a worker writes output to disk. If
``None``, a default of one quarter of the worker's
memory limit is used.
client:
type: object
description: |
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
@@ -193,6 +193,9 @@ distributed:
MKL_NUM_THREADS: 1
OPENBLAS_NUM_THREADS: 1

shuffle:
output_max_buffer_size: null # Size of shuffle output memory buffer

client:
heartbeat: 5s # Interval between client heartbeats
scheduler-info-interval: 2s # Interval between scheduler-info updates
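For context, the new key behaves like any other Dask config value and can also be overridden at runtime; a minimal sketch, where the "256 MiB" value is purely illustrative:

import dask

# Accepts a byte string, an integer byte count, or None; None (the
# default) resolves to a quarter of the worker's memory limit.
dask.config.set({"distributed.shuffle.output_max_buffer_size": "256 MiB"})
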
58 changes: 45 additions & 13 deletions distributed/shuffle/_disk.py
@@ -3,6 +3,7 @@
import contextlib
import pathlib
import shutil
from collections import defaultdict

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
@@ -36,12 +37,19 @@ class DiskShardsBuffer(ShardsBuffer):
to be processed exceeds this limit, then the buffer will block
until below the threshold. See :meth:`.write` for the
implementation of this scheme.
max_in_memory_buffer_size : int, optional
Size of the in-memory buffer to use before flushing to disk. If
configured, incoming shards are first kept in memory rather than
immediately written to disk. This can speed up shuffles that fit
entirely in memory.
"""

def __init__(
self,
directory: str | pathlib.Path,
memory_limiter: ResourceLimiter | None = None,
*,
max_in_memory_buffer_size: int = 0,
):
super().__init__(
memory_limiter=memory_limiter,
@@ -50,6 +58,9 @@ def __init__(
)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)
self._in_memory = 0
self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list)
self.max_in_memory_buffer_size = max_in_memory_buffer_size

async def _process(self, id: str, shards: list[bytes]) -> None:
"""Write one buffer to file
@@ -68,31 +79,52 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with log_errors():
# Consider boosting total_size a bit here to account for duplication
with self.time("write"):
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
for shard in shards:
f.write(shard)

def read(self, id: int | str) -> bytes:
"""Read a complete file back into memory"""
if not self.max_in_memory_buffer_size:
# Fast path if we're always hitting the disk
self._write(id, shards)
else:
while shards:
if self._in_memory < self.max_in_memory_buffer_size:
self._memory_buf[id].append(newdata := shards.pop())
self._in_memory += len(newdata)
else:
# Flush old data
# This could be offloaded to a background
# task at the cost of going further over
# the soft memory limit.
for k, v in self._memory_buf.items():
self._write(k, v)
self._memory_buf.clear()
self._in_memory = 0

def _write(self, id: str, shards: list[bytes]) -> None:
with open(self.directory / str(id), mode="ab", buffering=100_000_000) as f:
for s in shards:
f.write(s)

def read(self, id: str) -> bytes:
"""Read a complete file back into memory, concatting with any
in memory parts"""
self.raise_on_exception()
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

data = self._memory_buf.pop(id, [])
try:
with self.time("read"):
with open(
self.directory / str(id), mode="rb", buffering=100_000_000
) as f:
data = f.read()
size = f.tell()
data.append(f.read())
except FileNotFoundError:
raise KeyError(id)
if not data:
# Neither disk nor in memory
raise KeyError(id)

if data:
self.bytes_read += size
return data
buf = b"".join(data)
self.bytes_read += len(buf)
return buf
else:
raise KeyError(id)

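To make the buffering scheme above easier to follow in isolation, here is a minimal, self-contained model of the same idea, stripped of the async ShardsBuffer machinery, timing metrics, and memory limiter. The class name TinyDiskBuffer is made up for illustration, and the error handling is simplified (a missing file is treated as "no on-disk part" rather than re-raised):

import pathlib
from collections import defaultdict


class TinyDiskBuffer:
    """Simplified model of DiskShardsBuffer's optional in-memory buffering."""

    def __init__(self, directory: str, max_in_memory: int = 0) -> None:
        self.directory = pathlib.Path(directory)
        self.directory.mkdir(exist_ok=True)
        self.max_in_memory = max_in_memory
        self._in_memory = 0
        self._memory_buf: defaultdict[str, list[bytes]] = defaultdict(list)

    def process(self, id: str, shards: list[bytes]) -> None:
        if not self.max_in_memory:
            # Fast path: no buffering configured, always hit the disk.
            self._write(id, shards)
            return
        while shards:
            if self._in_memory < self.max_in_memory:
                shard = shards.pop()
                self._memory_buf[id].append(shard)
                self._in_memory += len(shard)
            else:
                # Over the threshold: flush every buffered id to disk.
                for k, v in self._memory_buf.items():
                    self._write(k, v)
                self._memory_buf.clear()
                self._in_memory = 0

    def _write(self, id: str, shards: list[bytes]) -> None:
        with open(self.directory / str(id), mode="ab") as f:
            for shard in shards:
                f.write(shard)

    def read(self, id: str) -> bytes:
        # Concatenate any in-memory parts with the on-disk part (if any).
        data = self._memory_buf.pop(id, [])
        try:
            data.append((self.directory / str(id)).read_bytes())
        except FileNotFoundError:
            pass
        if not data:
            raise KeyError(id)
        return b"".join(data)

With max_in_memory=0 the sketch degenerates to the previous always-on-disk behaviour, which is why the fast path in _process above preserves the old semantics exactly.
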
22 changes: 22 additions & 0 deletions distributed/shuffle/_worker_extension.py
@@ -15,6 +15,7 @@

import toolz

from dask import config
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
@@ -68,6 +69,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None,
):
self.id = id
self.run_id = run_id
@@ -78,9 +80,23 @@
self.scheduler = scheduler
self.closed = False

buffer_size: int | str | None = config.get(
"distributed.shuffle.output_max_buffer_size"
)
if buffer_size is None:
if worker_memory_limit is None:
# No configuration and no known worker memory limit
# Safe default is "no in-memory buffering"
buffer_size = 0
else:
buffer_size = worker_memory_limit // 4
else:
buffer_size = parse_bytes(buffer_size)

self._disk_buffer = DiskShardsBuffer(
directory=directory,
memory_limiter=memory_limiter_disk,
max_in_memory_buffer_size=buffer_size,
)

self._comm_buffer = CommShardsBuffer(
@@ -281,6 +297,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None = None,
):
from dask.array.rechunk import _old_to_new

@@ -295,6 +312,7 @@ def __init__(
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
worker_memory_limit=worker_memory_limit,
)
self.old = old
self.new = new
@@ -436,6 +454,7 @@ def __init__(
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
worker_memory_limit: int | None = None,
):
import pandas as pd

@@ -450,6 +469,7 @@ def __init__(
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
memory_limiter_disk=memory_limiter_disk,
worker_memory_limit=worker_memory_limit,
)
self.column = column
self.schema = schema
@@ -818,6 +838,7 @@ async def _(
scheduler=self.worker.scheduler,
memory_limiter_disk=self.memory_limiter_disk,
memory_limiter_comms=self.memory_limiter_comms,
worker_memory_limit=self.worker.memory_manager.memory_limit,
)
elif result["type"] == ShuffleType.ARRAY_RECHUNK:
shuffle = ArrayRechunkRun(
@@ -837,6 +858,7 @@
scheduler=self.worker.scheduler,
memory_limiter_disk=self.memory_limiter_disk,
memory_limiter_comms=self.memory_limiter_comms,
worker_memory_limit=self.worker.memory_manager.memory_limit,
)
else: # pragma: no cover
raise TypeError(result["type"])
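
The buffer-size defaulting added to ShuffleRun.__init__ above can be summarised as a small standalone function; this is a sketch only, and the name resolve_output_buffer_size is invented for illustration:

from dask.utils import parse_bytes


def resolve_output_buffer_size(
    configured: int | str | None, worker_memory_limit: int | None
) -> int:
    # Mirrors the defaulting logic in ShuffleRun.__init__ (sketch only).
    if configured is None:
        if worker_memory_limit is None:
            # No configuration and no known worker memory limit:
            # the safe default is no in-memory buffering.
            return 0
        # Default to a quarter of the worker's memory limit.
        return worker_memory_limit // 4
    # parse_bytes handles both byte strings ("1 GiB") and plain ints.
    return parse_bytes(configured)

For example, a worker with an 8 GiB memory limit and no explicit configuration buffers up to 2 GiB in memory, while a worker without a known memory limit buffers nothing.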