From 7ecb9f70f050387ddd889c3401535c2da3ffe834 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 2 Nov 2023 21:08:20 +0100
Subject: [PATCH] P2P shuffle: pickle tiny buffers into monolithic bytes
 objects

---
 distributed/protocol/tests/test_utils_test.py |  8 ++--
 distributed/protocol/utils_test.py            | 12 +++---
 distributed/shuffle/_buffer.py                | 21 +++++-----
 distributed/shuffle/_core.py                  | 41 ++++++++++++++++---
 distributed/shuffle/_rechunk.py               |  1 -
 distributed/shuffle/_shuffle.py               |  3 --
 distributed/shuffle/_worker_plugin.py         |  2 +-
 distributed/shuffle/tests/test_core.py        | 29 +++++++++++++
 distributed/shuffle/tests/test_rechunk.py     |  7 +++-
 distributed/shuffle/tests/test_shuffle.py     |  2 +-
 10 files changed, 94 insertions(+), 32 deletions(-)
 create mode 100644 distributed/shuffle/tests/test_core.py

diff --git a/distributed/protocol/tests/test_utils_test.py b/distributed/protocol/tests/test_utils_test.py
index c3861eea74c..d275f3af4c8 100644
--- a/distributed/protocol/tests/test_utils_test.py
+++ b/distributed/protocol/tests/test_utils_test.py
@@ -21,6 +21,8 @@ def test_get_host_array():
     a = np.frombuffer(buf[1:], dtype="u1")
     assert get_host_array(a) is buf.obj
 
-    a = np.frombuffer(bytearray(3), dtype="u1")
-    with pytest.raises(TypeError):
-        get_host_array(a)
+    for buf in (b"123", bytearray(b"123")):
+        a = np.frombuffer(buf, dtype="u1")
+        assert get_host_array(a) is buf
+        a = np.frombuffer(memoryview(buf), dtype="u1")
+        assert get_host_array(a) is buf
diff --git a/distributed/protocol/utils_test.py b/distributed/protocol/utils_test.py
index d72c98df372..2ea3b68414a 100644
--- a/distributed/protocol/utils_test.py
+++ b/distributed/protocol/utils_test.py
@@ -6,7 +6,7 @@
 import numpy
 
 
-def get_host_array(a: numpy.ndarray) -> numpy.ndarray:
+def get_host_array(a: numpy.ndarray) -> numpy.ndarray | bytes | bytearray:
     """Given a numpy array, find the underlying memory allocated by either
     distributed.protocol.utils.host_array or internally by numpy
     """
@@ -22,9 +22,7 @@ def get_host_array(a: numpy.ndarray) -> numpy.ndarray:
             o = o.base
         else:
             return o
-    else:
-        # distributed.comm.utils.host_array() uses numpy.empty()
-        raise TypeError(
-            "Array uses a buffer allocated neither internally nor by host_array: "
-            f"{type(o)}"
-        )
+    elif isinstance(o, (bytes, bytearray)):
+        return o
+    else:  # pragma: nocover
+        raise TypeError(f"Unexpected numpy buffer: {o!r}")
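
NOTE: The two protocol changes above are groundwork for the shuffle change
below: numpy shards can now arrive embedded in a pickled blob, and arrays
rebuilt by pickle.loads() are backed by the plain bytes object that the
unpickler produced, rather than by a buffer from host_array(). A minimal
sketch of that mechanism, assuming current NumPy pickling behaviour (the
exact .base chain is an implementation detail, not part of this patch):

    import pickle

    import numpy as np

    a = np.arange(8, dtype="u1")
    restored = pickle.loads(pickle.dumps(a))

    o = restored
    while isinstance(o, np.ndarray):
        o = o.base  # peel numpy views down to the raw buffer
    if isinstance(o, memoryview):
        o = o.obj   # a memoryview remembers the object it wraps
    assert isinstance(o, (bytes, bytearray))

This is the same walk that get_host_array() performs, which is why bytes and
bytearray are now accepted as legitimate host memory instead of raising
TypeError.
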
diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py
index bfacac2790e..e4a44bc843e 100644
--- a/distributed/shuffle/_buffer.py
+++ b/distributed/shuffle/_buffer.py
@@ -5,21 +5,22 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterator, Sized
-from typing import Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
 from distributed.metrics import time
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.sizeof import sizeof
 
 logger = logging.getLogger("distributed.shuffle")
 
+if TYPE_CHECKING:
+    # TODO import from collections.abc (requires Python >=3.12)
+    from typing_extensions import Buffer
+else:
+    Buffer = Sized
 
-ShardType = TypeVar("ShardType", bound=Sized)
-T = TypeVar("T")
-
+ShardType = TypeVar("ShardType", bound=Buffer)
-class _List(list[T]):
-    # This ensures that the distributed.protocol will not iterate over this collection
-    pass
+T = TypeVar("T")
 
 
 class ShardsBuffer(Generic[ShardType]):
@@ -43,7 +44,7 @@ class ShardsBuffer(Generic[ShardType]):
     Flushing will not raise an exception.
     To ensure that the buffer finished successfully, please call `ShardsBuffer.raise_on_exception`
     """
-    shards: defaultdict[str, _List[ShardType]]
+    shards: defaultdict[str, list[ShardType]]
     sizes: defaultdict[str, int]
     sizes_detail: defaultdict[str, list[int]]
     concurrency_limit: int
@@ -70,7 +71,7 @@ def __init__(
         max_message_size: int = -1,
     ) -> None:
         self._accepts_input = True
-        self.shards = defaultdict(_List)
+        self.shards = defaultdict(list)
        self.sizes = defaultdict(int)
        self.sizes_detail = defaultdict(list)
        self._exception = None
@@ -146,7 +147,7 @@ def _continue() -> bool:
             part_id = max(self.sizes, key=self.sizes.__getitem__)
             if self.max_message_size > 0:
                 size = 0
-                shards: _List[ShardType] = _List()
+                shards = []
                 while size < self.max_message_size:
                     try:
                         shard = self.shards[part_id].pop()
diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index 7e15a1a5709..ac62d18c393 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -4,19 +4,21 @@
 import asyncio
 import contextlib
 import itertools
+import pickle
 import time
 from collections import defaultdict
-from collections.abc import Callable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
 
 from tornado.ioloop import IOLoop
 
 import dask.config
+from dask.core import flatten
 from dask.typing import Key
 from dask.utils import parse_timedelta
 
@@ -140,7 +142,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int:
         return self.run_id
 
     async def _send(
-        self, address: str, shards: list[tuple[_T_partition_id, Any]]
+        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
     ) -> None:
         self.raise_if_closed()
         return await self.rpc(address).shuffle_receive(
@@ -159,8 +161,19 @@ async def send(
         retry_delay_max = parse_timedelta(
             dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
         )
+
+        if _mean_shard_size(shards) < 65536:
+            # Don't send buffers individually over the tcp comms.
+            # Instead, merge everything into an opaque bytes blob, send it all at once,
+            # and unpickle it on the other side.
+            # Performance tests informing the size threshold:
+            # https://github.com/dask/distributed/pull/8318
+            shards_or_bytes: list | bytes = pickle.dumps(shards)
+        else:
+            shards_or_bytes = shards
+
         return await retry(
-            partial(self._send, address, shards),
+            partial(self._send, address, shards_or_bytes),
             count=retry_count,
             delay_min=retry_delay_min,
             delay_max=retry_delay_max,
@@ -239,7 +252,10 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
         self.raise_if_closed()
         return self._disk_buffer.read("_".join(str(i) for i in id))
 
-    async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
+    async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None:
+        if isinstance(data, bytes):
+            # Unpack opaque blob. See send()
+            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
         await self._receive(data)
 
     async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
@@ -422,3 +438,18 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]:
         raise Reschedule()
     except Exception as e:
         raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e
+
+
+def _mean_shard_size(shards: Iterable) -> int:
+    """Return estimated mean size in bytes of each shard"""
+    size = 0
+    count = 0
+    for shard in flatten(shards, container=(tuple, list)):
+        if not isinstance(shard, int):
+            # This also asserts that shard is a Buffer and that we didn't forget
+            # a container or metadata type above
+            size += memoryview(shard).nbytes
+            count += 1
+            if count == 10:
+                break
+    return size // count if count else 0
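
NOTE: The send()/receive() hunks above form a symmetric wire protocol:
send() folds collections of tiny buffers into one pickle blob, and
receive() detects the blob by type and unfolds it. A standalone sketch of
the idea (pack/unpack are hypothetical names, not part of the patch):

    from __future__ import annotations

    import pickle

    from distributed.shuffle._core import _mean_shard_size

    THRESHOLD = 65536  # bytes; see the benchmarks in dask/distributed#8318

    def pack(shards: list) -> list | bytes:
        """Sender side: one monolithic frame for many tiny shards."""
        if _mean_shard_size(shards) < THRESHOLD:
            return pickle.dumps(shards)
        return shards  # large buffers keep the existing zero-copy path

    def unpack(data: list | bytes) -> list:
        """Receiver side: a bytes payload can only be the opaque blob,
        because an unpacked payload is always a list of tuples."""
        return pickle.loads(data) if isinstance(data, bytes) else data

The asymmetry is deliberate: per-buffer framing overhead in the comms
dominates for small payloads, while pickling large buffers would cost an
extra memory copy for no benefit.
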
diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py
index 1f47eb3c5ed..bf576f84745 100644
--- a/distributed/shuffle/_rechunk.py
+++ b/distributed/shuffle/_rechunk.py
@@ -273,7 +273,6 @@ def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray:
     for sublist in shards:
         for index, shard in sublist:
             indexed[index] = shard
-    del shards
 
     subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
     assert len(indexed) == np.prod(subshape)
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 03d990a7f30..063079df052 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -455,9 +455,6 @@ def __init__(
         self.partitions_of = dict(partitions_of)
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
 
-    async def receive(self, data: list[tuple[int, bytes]]) -> None:
-        await self._receive(data)
-
     async def _receive(self, data: list[tuple[int, bytes]]) -> None:
         self.raise_if_closed()
 
diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index 8828058e698..838137b5f1b 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -299,7 +299,7 @@ async def shuffle_receive(
         self,
         shuffle_id: ShuffleId,
         run_id: int,
-        data: list[tuple[int, Any]],
+        data: list[tuple[int, Any]] | bytes,
     ) -> None:
         """
         Handler: Receive an incoming shard of data from a peer worker.
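
NOTE: _mean_shard_size() measures buffers with memoryview(shard).nbytes
rather than len(shard) or dask's sizeof(). A short illustration of why,
using NumPy only for the demonstration:

    import numpy as np

    a = np.zeros(10, dtype="u8")
    assert len(a) == 10                # elements, not bytes
    assert memoryview(a).nbytes == 80  # what actually goes over the wire

    try:
        memoryview({1: 2})  # not a buffer: fails loudly instead of guessing
    except TypeError:
        pass

Failing with TypeError on non-buffer objects is the guard referred to by
the inline comment in _mean_shard_size(), and the new test_core.py below
exercises exactly these cases.
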
diff --git a/distributed/shuffle/tests/test_core.py b/distributed/shuffle/tests/test_core.py
new file mode 100644
index 00000000000..deb9d2a0bbb
--- /dev/null
+++ b/distributed/shuffle/tests/test_core.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import pytest
+
+from distributed.shuffle._core import _mean_shard_size
+
+
+def test_mean_shard_size():
+    assert _mean_shard_size([]) == 0
+    assert _mean_shard_size([b""]) == 0
+    assert _mean_shard_size([b"123", b"45678"]) == 4
+    # Don't fully iterate over large collections
+    assert _mean_shard_size([b"12" * n for n in range(1000)]) == 9
+    # Support any Buffer object
+    assert _mean_shard_size([b"12", bytearray(b"1234"), memoryview(b"123456")]) == 4
+    # Recursion into lists or tuples; ignore int
+    assert _mean_shard_size([(1, 2, [3, b"123456"])]) == 6
+    # Don't blindly call sizeof() on unexpected objects
+    with pytest.raises(TypeError):
+        _mean_shard_size([1.2])
+    with pytest.raises(TypeError):
+        _mean_shard_size([{1: 2}])
+
+
+def test_mean_shard_size_numpy():
+    """Test that _mean_shard_size doesn't call len() on multi-byte data types"""
+    np = pytest.importorskip("numpy")
+    assert _mean_shard_size([np.zeros(10, dtype="u1")]) == 10
+    assert _mean_shard_size([np.zeros(10, dtype="u8")]) == 80
diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py
index c31f122f9fe..d12a5474727 100644
--- a/distributed/shuffle/tests/test_rechunk.py
+++ b/distributed/shuffle/tests/test_rechunk.py
@@ -147,7 +147,12 @@ async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk)
             total_bytes_recvd += metrics["disk"]["total"]
             total_bytes_recvd_shuffle += s.total_recvd
 
-    assert total_bytes_recvd_shuffle == total_bytes_sent
+    # Allow for some uncertainty due to slight differences in measuring
+    assert (
+        total_bytes_sent * 0.95
+        < total_bytes_recvd_shuffle
+        < total_bytes_sent * 1.05
+    )
 
     all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
     for ix, worker in worker_for_mapping.items():
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index b3ababdea21..85e22f75ee5 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -1816,7 +1816,7 @@ async def test_error_receive(tmp_path, loop_in_thread):
         partitions_for_worker[w].append(part)
 
     class ErrorReceive(DataFrameShuffleRun):
-        async def receive(self, data: list[tuple[int, bytes]]) -> None:
+        async def _receive(self, data: list[tuple[int, bytes]]) -> None:
             raise RuntimeError("Error during receive")
 
     with DataFrameShuffleTestPool() as local_shuffle_pool:
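
NOTE: The large-collection case in test_mean_shard_size() doubles as a check
of the sampling cut-off: only the first ten shards are inspected, with sizes
0, 2, 4, ..., 18 bytes, so the estimate is 90 // 10 == 9 no matter what the
remaining 990 entries look like. A one-line restatement of that arithmetic:

    sizes = [len(b"12" * n) for n in range(10)]  # the sampled prefix
    assert sum(sizes) // len(sizes) == 9

This keeps the cost of the size heuristic in send() constant per message,
even when a message contains thousands of shards.
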