diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index e329d54dcce..127500720e2 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -8,7 +8,7 @@ from typing import Any, Generic, TypeVar from distributed.metrics import time -from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._limiter import AbstractLimiter, NoopLimiter, ResourceLimiter from distributed.sizeof import sizeof logger = logging.getLogger("distributed.shuffle") @@ -46,7 +46,7 @@ class ShardsBuffer(Generic[ShardType]): shards: defaultdict[str, _List[ShardType]] sizes: defaultdict[str, int] concurrency_limit: int - memory_limiter: ResourceLimiter | None + memory_limiter: AbstractLimiter diagnostics: dict[str, float] max_message_size: int @@ -74,7 +74,7 @@ def __init__( self._exception = None self.concurrency_limit = concurrency_limit self._inputs_done = False - self.memory_limiter = memory_limiter + self.memory_limiter = memory_limiter or NoopLimiter() self.diagnostics: dict[str, float] = defaultdict(float) self._tasks = [ asyncio.create_task(self._background_task()) @@ -97,7 +97,7 @@ def heartbeat(self) -> dict[str, Any]: "written": self.bytes_written, "read": self.bytes_read, "diagnostics": self.diagnostics, - "memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0, + "memory_limit": self.memory_limiter._maxvalue, } async def process(self, id: str, shards: list[ShardType], size: int) -> None: @@ -119,8 +119,7 @@ async def process(self, id: str, shards: list[ShardType], size: int) -> None: "avg_duration" ] + 0.02 * (stop - start) finally: - if self.memory_limiter: - await self.memory_limiter.decrease(size) + await self.memory_limiter.decrease(size) self.bytes_memory -= size async def _process(self, id: str, shards: list[ShardType]) -> None: @@ -198,15 +197,13 @@ async def write(self, data: dict[str, ShardType]) -> None: self.bytes_memory += total_batch_size self.bytes_total += total_batch_size - 
if self.memory_limiter: - self.memory_limiter.increase(total_batch_size) + self.memory_limiter.increase(total_batch_size) async with self._shards_available: for worker, shard in data.items(): self.shards[worker].append(shard) self.sizes[worker] += sizes[worker] self._shards_available.notify() - if self.memory_limiter: - await self.memory_limiter.wait_for_available() + await self.memory_limiter.wait_for_available() del data assert total_batch_size diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index e7dd72c73cf..d0820249dc6 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -13,11 +13,17 @@ from functools import partial from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar +from dask import config + from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule from distributed.protocol import to_serialize from distributed.shuffle._comms import CommShardsBuffer -from distributed.shuffle._disk import DiskShardsBuffer +from distributed.shuffle._disk import ( + DiskShardsBuffer, + FileShardsBuffer, + MemoryShardsBuffer, +) from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter @@ -40,6 +46,8 @@ class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): + _disk_buffer: FileShardsBuffer + def __init__( self, id: ShuffleId, @@ -60,10 +68,13 @@ def __init__( self.scheduler = scheduler self.closed = False - self._disk_buffer = DiskShardsBuffer( - directory=directory, - memory_limiter=memory_limiter_disk, - ) + if config.get("distributed.shuffle.p2p.stage_in_memory", False): + self._disk_buffer = MemoryShardsBuffer() + else: + self._disk_buffer = DiskShardsBuffer( + directory=directory, + memory_limiter=memory_limiter_disk, + ) self._comm_buffer = CommShardsBuffer( send=self.send, memory_limiter=memory_limiter_comms diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 
2b3dc37beed..dfb3a5afe98 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -3,53 +3,68 @@ import contextlib import pathlib import shutil +from collections import defaultdict +from io import BytesIO +from types import TracebackType +from typing import BinaryIO from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import log_errors -class DiskShardsBuffer(ShardsBuffer): - """Accept, buffer, and write many small objects to many files - - This takes in lots of small objects, writes them to a local directory, and - then reads them back when all writes are complete. It buffers these - objects in memory so that it can optimize disk access for larger writes. - - **State** - - - shards: dict[str, list[bytes]] - - This is our in-memory buffer of data waiting to be written to files. - - - sizes: dict[str, int] - - The size of each list of shards. We find the largest and write data from that buffer +class FileShardsBuffer(ShardsBuffer): + """An abstract buffering object backed by a "file" Parameters ---------- - directory : str or pathlib.Path - Where to write and read data. Ideally points to fast disk. memory_limiter : ResourceLimiter, optional - Limiter for in-memory buffering (at most this much data) - before writes to disk occur. If the incoming data that has yet - to be processed exceeds this limit, then the buffer will block - until below the threshold. See :meth:`.write` for the - implementation of this scheme. + Resource limiter. + + Notes + ----- + Currently, a concurrency limit of one is hard-coded. 
""" - def __init__( - self, - directory: str | pathlib.Path, - memory_limiter: ResourceLimiter | None = None, - ): + def __init__(self, memory_limiter: ResourceLimiter | None = None) -> None: super().__init__( memory_limiter=memory_limiter, - # Disk is not able to run concurrently atm + # FileShardsBuffer not able to run concurrently concurrency_limit=1, ) - self.directory = pathlib.Path(directory) - self.directory.mkdir(exist_ok=True) + + def writer(self, id: int | str) -> BinaryIO: + """Return a file-like object for writing in append-mode. + + Parameters + ---------- + id + The shard id (will normalised to a string) + + Returns + ------- + An object implementing the BinaryIO interface. + """ + raise NotImplementedError("Abstract class can't provide this") + + def reader(self, id: int | str) -> BinaryIO: + """Return a file-like object for reading from byte-0. + + Parameters + ---------- + id + The shard id (will be normalised to a string) + + Returns + ------- + An object implementing the BinaryIO interface. + + Raises + ------ + FileNotFoundError + If no shard with requested id exists. 
+ """ + raise NotImplementedError("Abstract class can't provide this") async def _process(self, id: str, shards: list[bytes]) -> None: """Write one buffer to file @@ -68,9 +83,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None: with log_errors(): # Consider boosting total_size a bit here to account for duplication with self.time("write"): - with open( - self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: + with self.writer(id) as f: for shard in shards: f.write(shard) @@ -82,9 +95,7 @@ def read(self, id: int | str) -> bytes: try: with self.time("read"): - with open( - self.directory / str(id), mode="rb", buffering=100_000_000 - ) as f: + with self.reader(id) as f: data = f.read() size = f.tell() except FileNotFoundError: @@ -96,6 +107,94 @@ def read(self, id: int | str) -> bytes: else: raise KeyError(id) + +class _PersistentBytesIO(BytesIO): + """A BytesIO object that does not close itself when used in a with block.""" + + def __enter__(self) -> _PersistentBytesIO: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + pass + + +class MemoryShardsBuffer(FileShardsBuffer): + """Accept and buffer many small objects into memory. + + This implements in-memory "file" buffering with no resource limit + with the same interface as :class:`DiskShardsBuffer`. 
+ + """ + + def __init__(self) -> None: + super().__init__(memory_limiter=None) + self._memory_buffers: defaultdict[str, _PersistentBytesIO] = defaultdict( + _PersistentBytesIO + ) + + def writer(self, id: int | str) -> BinaryIO: + buf = self._memory_buffers[str(id)] + buf.seek(buf.tell()) + return buf + + def reader(self, id: int | str) -> BinaryIO: + key = str(id) + if key not in self._memory_buffers: + raise FileNotFoundError(f"Shard with {id=} is unknown") + buf = self._memory_buffers[str(id)] + buf.seek(0) + return buf + + +class DiskShardsBuffer(FileShardsBuffer): + """Accept, buffer, and write many small objects to many files + + This takes in lots of small objects, writes them to a local directory, and + then reads them back when all writes are complete. It buffers these + objects in memory so that it can optimize disk access for larger writes. + + **State** + + - shards: dict[str, list[bytes]] + + This is our in-memory buffer of data waiting to be written to files. + + - sizes: dict[str, int] + + The size of each list of shards. We find the largest and write data from that buffer + + Parameters + ---------- + directory : str or pathlib.Path + Where to write and read data. Ideally points to fast disk. + memory_limiter : ResourceLimiter, optional + Limiter for in-memory buffering (at most this much data) + before writes to disk occur. If the incoming data that has yet + to be processed exceeds this limit, then the buffer will block + until below the threshold. See :meth:`.write` for the + implementation of this scheme. 
+ """ + + def __init__( + self, + directory: str | pathlib.Path, + memory_limiter: ResourceLimiter | None = None, + ): + super().__init__(memory_limiter=memory_limiter) + self.directory = pathlib.Path(directory) + self.directory.mkdir(exist_ok=True) + + def writer(self, id: int | str) -> BinaryIO: + return open(self.directory / str(id), mode="ab", buffering=100_000_000) + + def reader(self, id: int | str) -> BinaryIO: + return open(self.directory / str(id), mode="rb", buffering=100_000_000) + async def close(self) -> None: await super().close() with contextlib.suppress(FileNotFoundError): diff --git a/distributed/shuffle/_limiter.py b/distributed/shuffle/_limiter.py index f3591b53f7f..b6550a1551c 100644 --- a/distributed/shuffle/_limiter.py +++ b/distributed/shuffle/_limiter.py @@ -1,10 +1,38 @@ from __future__ import annotations import asyncio +import math +from typing import Protocol from distributed.metrics import time +class AbstractLimiter(Protocol): + @property + def _maxvalue(self) -> int | float: + ... + + def available(self) -> int | float: + """How far can the value be increased before blocking""" + ... + + def free(self) -> bool: + """Return True if nothing has been acquired / the limiter is in a neutral state""" + ... + + async def wait_for_available(self) -> None: + """Block until the counter drops below maxvalue""" + ... + + def increase(self, value: int) -> None: + """Increase the internal counter by value""" + ... + + async def decrease(self, value: int) -> None: + """Decrease the internal counter by value""" + ... 
+ + + class ResourceLimiter: """Limit an abstract resource @@ -70,3 +98,29 @@ async def decrease(self, value: int) -> None: self._acquired -= value async with self._condition: self._condition.notify_all() + + +# Used to simplify code in shardsbuffer +class NoopLimiter: + """A no-op resource limiter.""" + + _maxvalue = math.inf + + def __repr__(self) -> str: + return f"<NoopLimiter>" + + def free(self) -> bool: + return True + + def available(self) -> float: + return self._maxvalue + + def increase(self, value: int) -> None: + pass + + async def decrease(self, value: int) -> None: + pass + + async def wait_for_available(self) -> None: + """Don't block and return immediately""" + pass