Merged: changes from all commits (26 commits)
77 changes: 47 additions & 30 deletions distributed/shuffle/_arrow.py

@@ -6,62 +6,61 @@
 import pyarrow as pa


-def dump_batch(batch: pa.Buffer, file: BinaryIO, schema: pa.Schema) -> None:
+def dump_shards(shards: list[pa.Table], file: BinaryIO) -> None:
     """
-    Dump a batch to file, if we're the first, also write the schema
+    Write multiple shard tables to the file

-    This function is with respect to the open file object
+    Note: This function appends to the file and dumps each table as an individual stream.
+    This results in multiple end-of-stream signals in the file.

     See Also
     --------
-    load_arrow
+    load_partition
     """
-    if file.tell() == 0:
-        file.write(schema.serialize())
-    file.write(batch)
+    import pyarrow as pa
+
+    for table in shards:
+        with pa.ipc.new_stream(file, table.schema) as writer:
+            writer.write_table(table)


-def load_arrow(file: BinaryIO) -> pa.Table:
-    """Load batched data written to file back out into a table again
+def load_partition(file: BinaryIO) -> pa.Table:
+    """Load partition data written to file back out into a single table

     Example
     -------
-    >>> t = pa.Table.from_pandas(df)  # doctest: +SKIP
+    >>> tables = [pa.Table.from_pandas(df), pa.Table.from_pandas(df2)]  # doctest: +SKIP
     >>> with open("myfile", mode="wb") as f:  # doctest: +SKIP
-    ...     for batch in t.to_batches():  # doctest: +SKIP
-    ...         dump_batch(batch, f, schema=t.schema)  # doctest: +SKIP
+    ...     for table in tables:  # doctest: +SKIP
+    ...         dump_shards([table], f)  # doctest: +SKIP

     >>> with open("myfile", mode="rb") as f:  # doctest: +SKIP
-    ...     t = load_arrow(f)  # doctest: +SKIP
+    ...     t = load_partition(f)  # doctest: +SKIP

     See Also
     --------
-    dump_batch
+    dump_shards
     """
+    import os
+
     import pyarrow as pa

-    try:
+    pos = file.tell()
+    file.seek(0, os.SEEK_END)
+    end = file.tell()
+    file.seek(pos)
+    shards = []
+    while file.tell() < end:
         sr = pa.RecordBatchStreamReader(file)
-        return sr.read_all()
-    except Exception:
-        raise EOFError
+        shards.append(sr.read_all())
+    return pa.concat_tables(shards)


-def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
+def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
-    import io
-
     import pyarrow as pa

-    bio = io.BytesIO()
-    bio.write(schema.serialize())
-    for batch in data:
-        bio.write(batch)
-    bio.seek(0)
-    sr = pa.RecordBatchStreamReader(bio)
-    data = sr.read_all()
-    bio.close()
-    return data
+    return pa.concat_tables(deserialize_table(buffer) for buffer in data)


 def deserialize_schema(data: bytes) -> pa.Schema:
@@ -87,3 +86,21 @@ def deserialize_schema(data: bytes) -> pa.Schema:
     table = sr.read_all()
     bio.close()
     return table.schema
+
+
+def serialize_table(table: pa.Table) -> bytes:
+    import io
+
+    import pyarrow as pa
+
+    stream = io.BytesIO()
+    with pa.ipc.new_stream(stream, table.schema) as writer:
+        writer.write_table(table)
+    return stream.getvalue()
+
+
+def deserialize_table(buffer: bytes) -> pa.Table:
+    import pyarrow as pa
+
+    with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
+        return reader.read_all()
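
For reviewers who want to try the new helpers locally, here is a minimal round-trip sketch. It is not part of the diff; the DataFrame contents, the file name, and the repeated dump call are invented for illustration.

# Round-trip sketch for dump_shards/load_partition and the byte-level helpers.
import pandas as pd
import pyarrow as pa

from distributed.shuffle._arrow import (
    deserialize_table,
    dump_shards,
    load_partition,
    serialize_table,
)

df = pd.DataFrame({"_partitions": [0, 1, 0, 1], "x": range(4)})
shards = [pa.Table.from_pandas(df), pa.Table.from_pandas(df)]

# Disk path: each call appends one Arrow IPC stream per table ...
with open("partition-0.arrow", "wb") as f:
    dump_shards(shards, f)
    dump_shards(shards, f)  # appending more shards for the same partition is fine

# ... and load_partition reads every stream between the current offset and EOF,
# concatenating them into a single table.
with open("partition-0.arrow", "rb") as f:
    table = load_partition(f)
assert len(table) == 4 * len(df)

# Comm path: one table per bytes buffer.
buf = serialize_table(shards[0])
assert deserialize_table(buf).equals(shards[0])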
7 changes: 2 additions & 5 deletions distributed/shuffle/_buffer.py

@@ -5,14 +5,11 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Generic, Sized, TypeVar
+from typing import Any, Generic, Sized, TypeVar

 from distributed.metrics import time
 from distributed.shuffle._limiter import ResourceLimiter

-if TYPE_CHECKING:
-    import pyarrow as pa
-
 logger = logging.getLogger("distributed.shuffle")

 ShardType = TypeVar("ShardType", bound=Sized)
@@ -96,7 +93,7 @@ def heartbeat(self) -> dict[str, Any]:
             "memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0,
         }

-    async def process(self, id: str, shards: list[pa.Table], size: int) -> None:
+    async def process(self, id: str, shards: list[ShardType], size: int) -> None:
         try:
             start = time()
             try:
7 changes: 1 addition & 6 deletions distributed/shuffle/_comms.py

@@ -67,9 +67,4 @@ async def _process(self, address: str, shards: list[bytes]) -> None:

         # Consider boosting total_size a bit here to account for duplication
         with self.time("send"):
-            await self.send(address, [_join_shards(shards)])
-
-
-def _join_shards(shards: list[bytes]) -> bytes:
-    # This is just there for easier profiling
-    return b"".join(shards)
+            await self.send(address, shards)
30 changes: 10 additions & 20 deletions distributed/shuffle/_disk.py

@@ -4,12 +4,9 @@
 import os
 import pathlib
 import shutil
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable
+from typing import BinaryIO, Callable

-if TYPE_CHECKING:
-    import pyarrow as pa
-
-from distributed.shuffle._buffer import ShardsBuffer
+from distributed.shuffle._buffer import ShardsBuffer, ShardType
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.utils import log_errors

@@ -23,7 +20,7 @@ class DiskShardsBuffer(ShardsBuffer):

     **State**

-    - shards: dict[str, list[bytes]]
+    - shards: dict[str, list[ShardType]]

         This is our in-memory buffer of data waiting to be written to files.

@@ -53,8 +50,8 @@ class DiskShardsBuffer(ShardsBuffer):
     def __init__(
         self,
         directory: str,
-        dump: Callable[[Any, BinaryIO], None],
-        load: Callable[[BinaryIO], Any],
+        dump: Callable[[list[ShardType], BinaryIO], None],
+        load: Callable[[BinaryIO], list[ShardType]],
         memory_limiter: ResourceLimiter | None = None,
     ):
         super().__init__(
@@ -68,7 +65,7 @@ def __init__(
         self.dump = dump
         self.load = load

-    async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
+    async def _process(self, id: str, shards: list[ShardType]) -> None:
         """Write one buffer to file

         This function was built to offload the disk IO, but since then we've
@@ -88,11 +85,9 @@ async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
                 with open(
                     self.directory / str(id), mode="ab", buffering=100_000_000
                 ) as f:
-                    for shard in shards:
-                        self.dump(shard, f)
-                    # os.fsync(f)  # TODO: maybe?
+                    self.dump(shards, f)

-    def read(self, id: int | str) -> pa.Table:
+    def read(self, id: int | str) -> list[ShardType]:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -104,20 +99,15 @@ def read(self, id: int | str) -> pa.Table:
             with open(
                 self.directory / str(id), mode="rb", buffering=100_000_000
             ) as f:
-                while True:
-                    try:
-                        parts.append(self.load(f))
-                    except EOFError:
-                        break
+                parts = self.load(f)
                 size = f.tell()
         except FileNotFoundError:
             raise KeyError(id)

         # TODO: We could consider deleting the file at this point
         if parts:
             self.bytes_read += size
-            assert len(parts) == 1
-            return parts[0]
+            return parts
         else:
             raise KeyError(id)
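
A note on the new contract for the dump/load callables: dump now receives the whole list of buffered shards for an id in one call, and load returns the whole partition in one call instead of being retried until EOFError. Below is a minimal sketch with plain bytes shards; the length-prefixed framing and the file name are made up for illustration, the shuffle itself pairs dump_shards with load_partition as shown above.

# Sketch of the batched dump/load pairing DiskShardsBuffer now expects.
from typing import BinaryIO


def dump(shards: list[bytes], f: BinaryIO) -> None:
    # One call gets every shard buffered for an id; append-only, like the real buffer.
    for shard in shards:
        f.write(len(shard).to_bytes(8, "big"))
        f.write(shard)


def load(f: BinaryIO) -> list[bytes]:
    # One call returns everything written so far; no EOFError-driven retry loop.
    shards = []
    while header := f.read(8):
        shards.append(f.read(int.from_bytes(header, "big")))
    return shards


with open("shards.bin", "wb") as f:
    dump([b"abc", b"defg"], f)
    dump([b"hi"], f)  # a later flush for the same id just appends

with open("shards.bin", "rb") as f:
    assert load(f) == [b"abc", b"defg", b"hi"]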
30 changes: 11 additions & 19 deletions distributed/shuffle/_worker_extension.py

@@ -9,7 +9,7 @@
 from collections import defaultdict
 from collections.abc import Callable, Iterator
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, overload
+from typing import TYPE_CHECKING, Any, TypeVar, overload

 import toolz

@@ -19,9 +19,10 @@
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
     deserialize_schema,
-    dump_batch,
+    dump_shards,
     list_of_buffers_to_table,
-    load_arrow,
+    load_partition,
+    serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
@@ -121,12 +122,9 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

-        def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None:
-            return dump_batch(batch, file, self.schema)
-
         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_batch,
-            load=load_arrow,
+            dump=dump_shards,
+            load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,
         )
@@ -201,17 +199,14 @@ async def _receive(self, data: list[bytes]) -> None:
             self._exception = e
             raise

-    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[bytes]]:
-        table = list_of_buffers_to_table(data, self.schema)
+    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[pa.Table]]:
+        table = list_of_buffers_to_table(data)
         groups = split_by_partition(table, self.column)
         assert len(table) == sum(map(len, groups.values()))
         del data
-        return {
-            k: [batch.serialize() for batch in v.to_batches()]
-            for k, v in groups.items()
-        }
+        return {k: [v] for k, v in groups.items()}

-    async def _write_to_disk(self, data: dict[str, list[bytes]]) -> None:
+    async def _write_to_disk(self, data: dict[str, list[pa.Table]]) -> None:
         self.raise_if_closed()
         await self._disk_buffer.write(data)

@@ -234,10 +229,7 @@ def _() -> dict[str, list[bytes]]:
                 self.column,
                 self.worker_for,
             )
-            out = {
-                k: [b.serialize().to_pybytes() for b in t.to_batches()]
-                for k, t in out.items()
-            }
+            out = {k: [serialize_table(t)] for k, t in out.items()}
             return out

         out = await self.offload(_)
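
The upshot of these hunks is that shards are now whole tables rather than per-record-batch buffers: the disk path stores one pyarrow.Table per output partition, and the comm path sends one serialized IPC stream per partition. A small sketch follows; the table, the partition column, and the hand-rolled grouping are illustrative only, since in the extension the grouping comes from split_by_partition and split_by_worker.

# Sketch of the new shard shapes on the disk and comm paths.
import pyarrow as pa
import pyarrow.compute as pc

from distributed.shuffle._arrow import deserialize_table, serialize_table

table = pa.table({"_partitions": [0, 1, 0], "x": [1.0, 2.0, 3.0]})
groups = {
    part: table.filter(pc.equal(table["_partitions"], part)) for part in (0, 1)
}

# Disk path: one pyarrow.Table per output partition.
for_disk = {k: [v] for k, v in groups.items()}

# Comm path: one serialized IPC stream per output partition,
# instead of one buffer per record batch as before.
for_comms = {k: [serialize_table(v)] for k, v in groups.items()}
assert deserialize_table(for_comms[0][0]).equals(groups[0])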
1 change: 1 addition & 0 deletions distributed/shuffle/tests/test_disk_buffer.py

@@ -10,6 +10,7 @@


 def dump(data, f):
+    data = b"".join(data)
     f.write(data)
