29 changes: 5 additions & 24 deletions distributed/shuffle/_arrow.py
@@ -3,11 +3,11 @@
from io import BytesIO
from typing import TYPE_CHECKING

import pyarrow as pa
from packaging.version import parse

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa


def check_dtype_support(meta_input: pd.DataFrame) -> None:
@@ -70,27 +70,8 @@ def default_types_mapper(pyarrow_dtype: pa.DataType) -> object:
df = table.to_pandas(self_destruct=True, types_mapper=default_types_mapper)
return df.astype(meta.dtypes, copy=False)

from dask.sizeof import sizeof

def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
"""Convert a list of arrow buffers and a schema to an Arrow Table"""
import pyarrow as pa

return pa.concat_tables(deserialize_table(buffer) for buffer in data)


def serialize_table(table: pa.Table) -> bytes:
import io

import pyarrow as pa

stream = io.BytesIO()
with pa.ipc.new_stream(stream, table.schema) as writer:
writer.write_table(table)
return stream.getvalue()


def deserialize_table(buffer: bytes) -> pa.Table:
import pyarrow as pa

with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
return reader.read_all()
@sizeof.register(pa.Buffer)
def sizeof_pa_buffer(obj: pa.Buffer) -> int:
return obj.size
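
A note on the registration above: `dask.sizeof.sizeof` is a type-dispatch registry, so once `pa.Buffer` is registered, any shard held as an Arrow buffer is measured by its exact byte size instead of the generic fallback. A minimal sketch of the effect, assuming the registration from this diff is in effect:

```python
# Minimal check of the dispatch above; assumes @sizeof.register(pa.Buffer)
# from this diff has been applied.
import pyarrow as pa
from dask.sizeof import sizeof

buf = pa.py_buffer(b"x" * 1024)  # a 1 KiB Arrow buffer
assert sizeof(buf) == buf.size == 1024
```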
5 changes: 5 additions & 0 deletions distributed/shuffle/_buffer.py
@@ -16,6 +16,11 @@
ShardType = TypeVar("ShardType", bound=Sized)
T = TypeVar("T")

import pyarrow as pa
@sizeof.register(pa.Table)
def pa_tab_sizeof(obj):
# FIXME: this is pretty expensive
return obj.nbytes

class _List(list[T]):
# This ensures that the distributed.protocol will not iterate over this collection
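For the `pa.Table` registration above, `Table.nbytes` traverses every column chunk and sums the sizes of its buffers, which is presumably what the FIXME means by expensive for heavily chunked tables. A small sketch, again assuming the registration is applied:

```python
# Sketch only: sizeof(table) now reports Arrow's own accounting via nbytes.
import pyarrow as pa
from dask.sizeof import sizeof

table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
assert sizeof(table) == table.nbytes
```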
6 changes: 5 additions & 1 deletion distributed/shuffle/_comms.py
@@ -57,10 +57,14 @@ def __init__(
memory_limiter: ResourceLimiter | None = None,
concurrency_limit: int = 10,
):
import dask

super().__init__(
memory_limiter=memory_limiter,
concurrency_limit=concurrency_limit,
max_message_size=CommShardsBuffer.max_message_size,
max_message_size=parse_bytes(
dask.config.get("shuffle.comm.max_message_size", default="2 MiB")
),
)
self.send = send

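The change above replaces the `CommShardsBuffer.max_message_size` class attribute with a value read from dask config and parsed with `parse_bytes`. A sketch of how the knob would be tuned; note that the `shuffle.comm.max_message_size` key exists only as introduced by this diff and is not an established dask setting:

```python
import dask
from dask.utils import parse_bytes

# Override the "2 MiB" default for one context.
with dask.config.set({"shuffle.comm.max_message_size": "8 MiB"}):
    limit = parse_bytes(
        dask.config.get("shuffle.comm.max_message_size", default="2 MiB")
    )
    assert limit == 8 * 2**20  # parse_bytes("8 MiB") == 8388608
```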
17 changes: 6 additions & 11 deletions distributed/shuffle/_disk.py
@@ -8,6 +8,7 @@
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors

from dask.utils import parse_bytes

class DiskShardsBuffer(ShardsBuffer):
"""Accept, buffer, and write many small objects to many files
@@ -40,14 +41,16 @@ class DiskShardsBuffer(ShardsBuffer):

def __init__(
self,
write,
directory: str | pathlib.Path,
memory_limiter: ResourceLimiter | None = None,
):
import dask
super().__init__(
memory_limiter=memory_limiter,
# Disk is not able to run concurrently atm
concurrency_limit=1,
concurrency_limit=10,
)
self.__write = write
self.directory = pathlib.Path(directory)
self.directory.mkdir(exist_ok=True)

@@ -64,15 +67,7 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
future then we should consider simplifying this considerably and
dropping the write into communicate above.
"""

with log_errors():
# Consider boosting total_size a bit here to account for duplication
with self.time("write"):
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
for shard in shards:
f.write(shard)
await self.__write(path=self.directory / str(id), shards=shards)

def read(self, id: int | str) -> bytes:
"""Read a complete file back into memory"""
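With this change `DiskShardsBuffer` no longer writes raw bytes itself; `_process` awaits an injected `write` callable (the worker extension below passes `write=self.write_to_disk`). A hypothetical minimal callback that would satisfy the new parameter, mirroring the removed inline append:

```python
import pathlib

# Illustrative only; the real caller supplies ShuffleRun.write_to_disk instead.
async def append_shards(path: pathlib.Path, shards: list[bytes]) -> None:
    with open(path, mode="ab") as f:
        for shard in shards:
            f.write(shard)

# buffer = DiskShardsBuffer(write=append_shards, directory="/tmp/shuffle")
```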
41 changes: 30 additions & 11 deletions distributed/shuffle/_worker_extension.py
@@ -13,19 +13,17 @@
from io import BytesIO
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import pyarrow as pa
import toolz

import dask
from dask.context import thread_state
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import to_serialize
from distributed.shuffle._arrow import (
convert_partition,
list_of_buffers_to_table,
serialize_table,
)
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._arrow import convert_partition
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
@@ -80,6 +78,7 @@ def __init__(
self.closed = False

self._disk_buffer = DiskShardsBuffer(
write=self.write_to_disk,
directory=directory,
memory_limiter=memory_limiter_disk,
)
@@ -120,12 +119,29 @@ async def barrier(self) -> None:
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id)

# FIXME: The level of indirection is way too high
async def write_to_disk(self, path, shards):
def _():
with log_errors():
tab = pa.concat_tables(shards)
with self.time("write"):
# TODO: compression is trivial with this assuming the reader
# also uses pyarrow to open the file
with pa.output_stream(path) as fd:
with pa.ipc.new_stream(fd, schema=tab.schema) as stream:
stream.write_table(tab)

# TODO: CPU instrumentation is off with this
return await self.offload(_)

async def send(
self, address: str, shards: list[tuple[T_transfer_shard_id, bytes]]
) -> None:
class mylist(list):
...
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
data=ToPickle(mylist(shards)),
shuffle_id=self.id,
run_id=self.run_id,
)
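
`write_to_disk` above concatenates the shard tables and streams them to disk with Arrow IPC instead of writing pre-serialized bytes. A standalone sketch of the same round trip (paths and tables are made up for illustration), showing how a pyarrow-based reader would get the data back:

```python
import pyarrow as pa

shards = [pa.table({"x": [1, 2]}), pa.table({"x": [3, 4]})]
tab = pa.concat_tables(shards)

# Write: one IPC stream per output file, as in write_to_disk above.
with pa.output_stream("/tmp/shard-0.arrow") as fd:
    with pa.ipc.new_stream(fd, schema=tab.schema) as stream:
        stream.write_table(tab)

# Read it back the same way the comment in write_to_disk anticipates.
with pa.input_stream("/tmp/shard-0.arrow") as fd:
    with pa.ipc.open_stream(fd) as reader:
        assert reader.read_all().equals(tab)
```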
@@ -503,11 +519,14 @@ async def _receive(self, data: list[tuple[int, bytes]]) -> None:
raise

def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, list[bytes]]:
table = list_of_buffers_to_table(data)
table = pa.concat_tables(data)
groups = split_by_partition(table, self.column)
assert len(table) == sum(map(len, groups.values()))
del data
return {(k,): [serialize_table(v)] for k, v in groups.items()}
res = {(k,): [v] for k, v in groups.items()}
class mydict(dict):
pass
return mydict(res)

async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int:
self.raise_if_closed()
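
The ad-hoc `mylist`/`mydict` subclasses appear to exist so that downstream machinery does not special-case the built-in containers; in `send` above the list is additionally wrapped in `ToPickle` so that `distributed.protocol` pickles the payload as one opaque blob rather than traversing each shard (the `_List` comment in `_buffer.py` describes the same trick). A sketch of that send-side pattern, with illustrative names:

```python
from distributed.protocol.serialize import ToPickle

class _OpaqueList(list):
    """Plain subclass so the protocol does not special-case the built-in list."""

# Wrapped like this, the shards travel as a single pickle frame:
payload = ToPickle(_OpaqueList([(0, b"shard-bytes"), (1, b"more-bytes")]))
# e.g. rpc(address).shuffle_receive(data=payload, ...)
```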
@@ -520,7 +539,7 @@ def _() -> dict[str, list[tuple[int, bytes]]]:
self.column,
self.worker_for,
)
out = {k: [(partition_id, serialize_table(t))] for k, t in out.items()}
out = {k: [(partition_id, t)] for k, t in out.items()}
return out

out = await self.offload(_)
@@ -585,7 +604,7 @@ def __init__(self, worker: Worker) -> None:
self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB"))
self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB"))
self.closed = False
self._executor = ThreadPoolExecutor(self.worker.state.nthreads)
self._executor = ThreadPoolExecutor(self.worker.state.nthreads * 2)

def __str__(self) -> str:
return f"ShuffleWorkerExtension on {self.worker.address}"
9 changes: 4 additions & 5 deletions distributed/shuffle/tests/test_shuffle.py
@@ -39,7 +39,6 @@
ShuffleRun,
ShuffleWorkerExtension,
convert_partition,
list_of_buffers_to_table,
split_by_partition,
split_by_worker,
)
@@ -1215,10 +1214,10 @@ def new_shuffle(

# 36 parametrizations
# Runtime each ~0.1s
@pytest.mark.parametrize("n_workers", [1, 10])
@pytest.mark.parametrize("n_input_partitions", [1, 2, 10])
@pytest.mark.parametrize("npartitions", [1, 20])
@pytest.mark.parametrize("barrier_first_worker", [True, False])
@pytest.mark.parametrize("n_workers", [10])
@pytest.mark.parametrize("n_input_partitions", [5000])
@pytest.mark.parametrize("npartitions", [5000])
@pytest.mark.parametrize("barrier_first_worker", [True])
@gen_test()
async def test_basic_lowlevel_shuffle(
tmp_path,
1 change: 0 additions & 1 deletion pyproject.toml
@@ -108,7 +108,6 @@ addopts = '''
-p no:asyncio
-p no:legacypath'''
filterwarnings = [
"error",
'''ignore:Please use `dok_matrix` from the `scipy\.sparse` namespace, the `scipy\.sparse\.dok` namespace is deprecated.:DeprecationWarning''',
'''ignore:elementwise comparison failed. this will raise an error in the future:DeprecationWarning''',
'''ignore:unclosed <socket\.socket.*:ResourceWarning''',