17 changes: 7 additions & 10 deletions distributed/shuffle/_buffer.py
@@ -8,7 +8,7 @@
from typing import Any, Generic, TypeVar

from distributed.metrics import time
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._limiter import AbstractLimiter, NoopLimiter, ResourceLimiter
from distributed.sizeof import sizeof

logger = logging.getLogger("distributed.shuffle")
@@ -46,7 +46,7 @@ class ShardsBuffer(Generic[ShardType]):
shards: defaultdict[str, _List[ShardType]]
sizes: defaultdict[str, int]
concurrency_limit: int
memory_limiter: ResourceLimiter | None
memory_limiter: AbstractLimiter
diagnostics: dict[str, float]
max_message_size: int

@@ -74,7 +74,7 @@ def __init__(
self._exception = None
self.concurrency_limit = concurrency_limit
self._inputs_done = False
self.memory_limiter = memory_limiter
self.memory_limiter = memory_limiter or NoopLimiter()
self.diagnostics: dict[str, float] = defaultdict(float)
self._tasks = [
asyncio.create_task(self._background_task())
@@ -97,7 +97,7 @@ def heartbeat(self) -> dict[str, Any]:
"written": self.bytes_written,
"read": self.bytes_read,
"diagnostics": self.diagnostics,
"memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0,
"memory_limit": self.memory_limiter._maxvalue,
}

async def process(self, id: str, shards: list[ShardType], size: int) -> None:
@@ -119,8 +119,7 @@ async def process(self, id: str, shards: list[ShardType], size: int) -> None:
"avg_duration"
] + 0.02 * (stop - start)
finally:
if self.memory_limiter:
await self.memory_limiter.decrease(size)
await self.memory_limiter.decrease(size)
self.bytes_memory -= size

async def _process(self, id: str, shards: list[ShardType]) -> None:
@@ -198,15 +197,13 @@ async def write(self, data: dict[str, ShardType]) -> None:
self.bytes_memory += total_batch_size
self.bytes_total += total_batch_size

if self.memory_limiter:
self.memory_limiter.increase(total_batch_size)
self.memory_limiter.increase(total_batch_size)
async with self._shards_available:
for worker, shard in data.items():
self.shards[worker].append(shard)
self.sizes[worker] += sizes[worker]
self._shards_available.notify()
if self.memory_limiter:
await self.memory_limiter.wait_for_available()
await self.memory_limiter.wait_for_available()
del data
assert total_batch_size

21 changes: 16 additions & 5 deletions distributed/shuffle/_core.py
@@ -13,11 +13,17 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar

from dask import config

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import to_serialize
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._disk import (
DiskShardsBuffer,
FileShardsBuffer,
MemoryShardsBuffer,
)
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter

@@ -40,6 +46,8 @@


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
_disk_buffer: FileShardsBuffer

def __init__(
self,
id: ShuffleId,
@@ -60,10 +68,13 @@ def __init__(
self.scheduler = scheduler
self.closed = False

self._disk_buffer = DiskShardsBuffer(
directory=directory,
memory_limiter=memory_limiter_disk,
)
if config.get("distributed.shuffle.p2p.stage_in_memory", False):
Member commented:
Instead of configuring this at shuffle execution time on the worker, we should configure it at graph creation time. That would allow us to specify it on a per-shuffle basis.
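For reference, a minimal sketch of flipping the new flag as it stands, using the `distributed.shuffle.p2p.stage_in_memory` key read above; the config must be visible to the workers, so it is set before the cluster starts (a per-shuffle, graph-creation-time knob as suggested here would need a different hook and is not shown):

```python
import dask
from distributed import Client, LocalCluster

# ShuffleRun checks dask config at shuffle execution time on the worker,
# so set the flag somewhere the workers will see it.
dask.config.set({"distributed.shuffle.p2p.stage_in_memory": True})

cluster = LocalCluster(processes=False)  # in-process workers share this config
client = Client(cluster)
# Any P2P shuffle executed on these workers now stages shards in memory
# instead of writing them to disk.
```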

self._disk_buffer = MemoryShardsBuffer()
else:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
memory_limiter=memory_limiter_disk,
)

self._comm_buffer = CommShardsBuffer(
send=self.send, memory_limiter=memory_limiter_comms
173 changes: 136 additions & 37 deletions distributed/shuffle/_disk.py
@@ -3,53 +3,68 @@
import contextlib
import pathlib
import shutil
from collections import defaultdict
from io import BytesIO
from types import TracebackType
from typing import BinaryIO

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors


class DiskShardsBuffer(ShardsBuffer):
"""Accept, buffer, and write many small objects to many files

This takes in lots of small objects, writes them to a local directory, and
then reads them back when all writes are complete. It buffers these
objects in memory so that it can optimize disk access for larger writes.

**State**

- shards: dict[str, list[bytes]]

This is our in-memory buffer of data waiting to be written to files.

- sizes: dict[str, int]

The size of each list of shards. We find the largest and write data from that buffer
class FileShardsBuffer(ShardsBuffer):
"""An abstract buffering object backed by a "file"

Parameters
----------
directory : str or pathlib.Path
Where to write and read data. Ideally points to fast disk.
memory_limiter : ResourceLimiter, optional
Limiter for in-memory buffering (at most this much data)
before writes to disk occur. If the incoming data that has yet
to be processed exceeds this limit, then the buffer will block
until below the threshold. See :meth:`.write` for the
implementation of this scheme.
Resource limiter.

Notes
-----
Currently, a concurrency limit of one is hard-coded.
"""

def __init__(
self,
directory: str | pathlib.Path,
memory_limiter: ResourceLimiter | None = None,
):
def __init__(self, memory_limiter: ResourceLimiter | None = None) -> None:
super().__init__(
memory_limiter=memory_limiter,
# Disk is not able to run concurrently atm
# FileShardsBuffer is not able to run concurrently
concurrency_limit=1,
)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)

def writer(self, id: int | str) -> BinaryIO:
"""Return a file-like object for writing in append-mode.

Parameters
----------
id
The shard id (will be normalised to a string)

Returns
-------
An object implementing the BinaryIO interface.
"""
raise NotImplementedError("Abstract class can't provide this")

def reader(self, id: int | str) -> BinaryIO:
"""Return a file-like object for reading from byte-0.

Parameters
----------
id
The shard id (will be normalised to a string)

Returns
-------
An object implementing the BinaryIO interface.

Raises
------
FileNotFoundError
If no shard with requested id exists.
"""
raise NotImplementedError("Abstract class can't provide this")

async def _process(self, id: str, shards: list[bytes]) -> None:
"""Write one buffer to file
@@ -68,9 +83,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with log_errors():
# Consider boosting total_size a bit here to account for duplication
with self.time("write"):
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
with self.writer(id) as f:
for shard in shards:
f.write(shard)

@@ -82,9 +95,7 @@ def read(self, id: int | str) -> bytes:

try:
with self.time("read"):
with open(
self.directory / str(id), mode="rb", buffering=100_000_000
) as f:
with self.reader(id) as f:
data = f.read()
size = f.tell()
except FileNotFoundError:
@@ -96,6 +107,94 @@
else:
raise KeyError(id)


class _PersistentBytesIO(BytesIO):
"""A BytesIO object that does not close itself when used in a with block."""

def __enter__(self) -> _PersistentBytesIO:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass


class MemoryShardsBuffer(FileShardsBuffer):
"""Accept and buffer many small objects into memory.

This provides in-memory "file" buffering with no resource limit,
exposing the same interface as :class:`DiskShardsBuffer`.

"""

def __init__(self) -> None:
super().__init__(memory_limiter=None)
self._memory_buffers: defaultdict[str, _PersistentBytesIO] = defaultdict(
_PersistentBytesIO
)

def writer(self, id: int | str) -> BinaryIO:
buf = self._memory_buffers[str(id)]
buf.seek(0, 2)  # move to the end so repeated writes append
return buf

def reader(self, id: int | str) -> BinaryIO:
key = str(id)
if key not in self._memory_buffers:
raise FileNotFoundError(f"Shard with {id=} is unknown")
buf = self._memory_buffers[str(id)]
Member commented with a suggested change: replace
buf = self._memory_buffers[str(id)]
with
buf = self._memory_buffers.pop(str(id))
This will remove the BinaryIO object from the MemoryShardsBuffer upon reading to prevent memory duplication after unpacking. From what I understand, reading should be an exactly-once operation.

buf.seek(0)
return buf
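As a usage sketch, assuming the `write`/`flush`/`read`/`close` interface that `ShardsBuffer` provides in `_buffer.py`, a round trip through the in-memory buffer looks like this:

```python
import asyncio

from distributed.shuffle._disk import MemoryShardsBuffer


async def roundtrip() -> bytes:
    buf = MemoryShardsBuffer()
    await buf.write({"p0": b"hello "})  # buffered, written out by a background task
    await buf.write({"p0": b"world"})
    await buf.flush()                   # wait until all pending shards are written
    data = buf.read("p0")               # shards come back concatenated
    await buf.close()
    return data


assert asyncio.run(roundtrip()) == b"hello world"
```

With the `.pop()` suggestion above applied, a second `read("p0")` would fail (surfacing as a `KeyError` via the `FileNotFoundError` branch in `read`), matching the exactly-once reading described in the comment.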


class DiskShardsBuffer(FileShardsBuffer):
"""Accept, buffer, and write many small objects to many files

This takes in lots of small objects, writes them to a local directory, and
then reads them back when all writes are complete. It buffers these
objects in memory so that it can optimize disk access for larger writes.

**State**

- shards: dict[str, list[bytes]]

This is our in-memory buffer of data waiting to be written to files.

- sizes: dict[str, int]

The size of each list of shards. We find the largest and write data from that buffer

Parameters
----------
directory : str or pathlib.Path
Where to write and read data. Ideally points to fast disk.
memory_limiter : ResourceLimiter, optional
Limiter for in-memory buffering (at most this much data)
before writes to disk occur. If the incoming data that has yet
to be processed exceeds this limit, then the buffer will block
until below the threshold. See :meth:`.write` for the
implementation of this scheme.
"""

def __init__(
self,
directory: str | pathlib.Path,
memory_limiter: ResourceLimiter | None = None,
):
super().__init__(memory_limiter=memory_limiter)
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)

def writer(self, id: int | str) -> BinaryIO:
return open(self.directory / str(id), mode="ab", buffering=100_000_000)

def reader(self, id: int | str) -> BinaryIO:
return open(self.directory / str(id), mode="rb", buffering=100_000_000)

async def close(self) -> None:
await super().close()
with contextlib.suppress(FileNotFoundError):
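The disk-backed buffer follows the same pattern; a minimal sketch wiring in a `ResourceLimiter` for the buffering scheme described in its docstring (directory and limit values are illustrative):

```python
import asyncio
import pathlib
import tempfile

from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter


async def disk_roundtrip() -> bytes:
    with tempfile.TemporaryDirectory() as tmpdir:
        buf = DiskShardsBuffer(
            directory=pathlib.Path(tmpdir) / "shards",
            # writers block once ~1 MiB of shards is waiting in memory
            memory_limiter=ResourceLimiter(1024 * 1024),
        )
        await buf.write({"p0": b"x" * 2048})
        await buf.flush()
        data = buf.read("p0")
        await buf.close()  # removes the shard directory
        return data


assert len(asyncio.run(disk_roundtrip())) == 2048
```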
54 changes: 54 additions & 0 deletions distributed/shuffle/_limiter.py
@@ -1,10 +1,38 @@
from __future__ import annotations

import asyncio
import math
from typing import Protocol

from distributed.metrics import time


class AbstractLimiter(Protocol):
@property
def _maxvalue(self) -> int | float:
...

def available(self) -> int | float:
"""How far can the value be increased before blocking"""
...

def free(self) -> bool:
"""Return True if nothing has been acquired / the limiter is in a neutral state"""
...

async def wait_for_available(self) -> None:
"""Block until the counter drops below maxvalue"""
...

def increase(self, value: int) -> None:
"""Increase the internal counter by value"""
...

async def decrease(self, value: int) -> None:
"""Decrease the internal counter by value"""
...
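A small sketch of what the protocol buys: both `ResourceLimiter` and the `NoopLimiter` added below satisfy it structurally, so code such as `ShardsBuffer` can be annotated against `AbstractLimiter` alone (the helper function here is illustrative only):

```python
from distributed.shuffle._limiter import AbstractLimiter, NoopLimiter, ResourceLimiter


def headroom(limiter: AbstractLimiter) -> float:
    """How much more data may be buffered before writers should block."""
    return limiter.available()


print(headroom(ResourceLimiter(1_000_000)))  # 1000000 while nothing is acquired
print(headroom(NoopLimiter()))               # inf: a NoopLimiter never blocks
```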


class ResourceLimiter:
"""Limit an abstract resource

@@ -70,3 +98,29 @@ async def decrease(self, value: int) -> None:
self._acquired -= value
async with self._condition:
self._condition.notify_all()


# Used to simplify code in ShardsBuffer
class NoopLimiter:
"""A no-op resource limiter."""

_maxvalue = math.inf

def __repr__(self) -> str:
return f"<NoopLimiter maxvalue: {math.inf} available: {math.inf}>"

def free(self) -> bool:
return True

def available(self) -> float:
return self._maxvalue

def increase(self, value: int) -> None:
pass

async def decrease(self, value: int) -> None:
pass

async def wait_for_available(self) -> None:
"""Don't block and return immediately"""
pass