diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index f7531020274..72b4c23fd7e 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -6,62 +6,61 @@
     import pyarrow as pa
 
 
-def dump_batch(batch: pa.Buffer, file: BinaryIO, schema: pa.Schema) -> None:
+def dump_shards(shards: list[pa.Table], file: BinaryIO) -> None:
     """
-    Dump a batch to file, if we're the first, also write the schema
+    Write multiple shard tables to the file
 
-    This function is with respect to the open file object
+    Note: This function appends to the file and dumps each table as an individual stream.
+    This results in multiple end-of-stream signals in the file.
 
     See Also
     --------
-    load_arrow
+    load_partition
     """
-    if file.tell() == 0:
-        file.write(schema.serialize())
-    file.write(batch)
+    import pyarrow as pa
+
+    for table in shards:
+        with pa.ipc.new_stream(file, table.schema) as writer:
+            writer.write_table(table)
 
 
-def load_arrow(file: BinaryIO) -> pa.Table:
-    """Load batched data written to file back out into a table again
+def load_partition(file: BinaryIO) -> pa.Table:
+    """Load partition data written to file back out into a single table
 
     Example
     -------
-    >>> t = pa.Table.from_pandas(df)  # doctest: +SKIP
+    >>> tables = [pa.Table.from_pandas(df), pa.Table.from_pandas(df2)]  # doctest: +SKIP
     >>> with open("myfile", mode="wb") as f:  # doctest: +SKIP
-    ...     for batch in t.to_batches():  # doctest: +SKIP
-    ...         dump_batch(batch, f, schema=t.schema)  # doctest: +SKIP
+    ...     for table in tables:  # doctest: +SKIP
+    ...         dump_shards([table], f)  # doctest: +SKIP
 
     >>> with open("myfile", mode="rb") as f:  # doctest: +SKIP
-    ...     t = load_arrow(f)  # doctest: +SKIP
+    ...     t = load_partition(f)  # doctest: +SKIP
 
     See Also
     --------
-    dump_batch
+    dump_shards
     """
+    import os
+
     import pyarrow as pa
 
-    try:
+    pos = file.tell()
+    file.seek(0, os.SEEK_END)
+    end = file.tell()
+    file.seek(pos)
+    shards = []
+    while file.tell() < end:
         sr = pa.RecordBatchStreamReader(file)
-        return sr.read_all()
-    except Exception:
-        raise EOFError
+        shards.append(sr.read_all())
+    return pa.concat_tables(shards)
 
 
-def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
-    """Convert a list of arrow buffers and a schema to an Arrow Table"""
-    import io
-
+def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
+    """Convert a list of arrow buffers to an Arrow Table"""
     import pyarrow as pa
 
-    bio = io.BytesIO()
-    bio.write(schema.serialize())
-    for batch in data:
-        bio.write(batch)
-    bio.seek(0)
-    sr = pa.RecordBatchStreamReader(bio)
-    data = sr.read_all()
-    bio.close()
-    return data
+    return pa.concat_tables(deserialize_table(buffer) for buffer in data)
 
 
 def deserialize_schema(data: bytes) -> pa.Schema:
@@ -87,3 +86,21 @@ def deserialize_schema(data: bytes) -> pa.Schema:
     table = sr.read_all()
     bio.close()
     return table.schema
+
+
+def serialize_table(table: pa.Table) -> bytes:
+    import io
+
+    import pyarrow as pa
+
+    stream = io.BytesIO()
+    with pa.ipc.new_stream(stream, table.schema) as writer:
+        writer.write_table(table)
+    return stream.getvalue()
+
+
+def deserialize_table(buffer: bytes) -> pa.Table:
+    import pyarrow as pa
+
+    with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
+        return reader.read_all()
diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py
index 97e639f6237..f56d1fcb3da 100644
--- a/distributed/shuffle/_buffer.py
+++ b/distributed/shuffle/_buffer.py
@@ -5,14 +5,11 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Generic, Sized, TypeVar
+from typing import Any, Generic, Sized, TypeVar
 
 from distributed.metrics import time
 from distributed.shuffle._limiter import ResourceLimiter
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
 logger = logging.getLogger("distributed.shuffle")
 
 ShardType = TypeVar("ShardType", bound=Sized)
@@ -96,7 +93,7 @@ def heartbeat(self) -> dict[str, Any]:
             "memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0,
         }
 
-    async def process(self, id: str, shards: list[pa.Table], size: int) -> None:
+    async def process(self, id: str, shards: list[ShardType], size: int) -> None:
         try:
             start = time()
             try:
diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py
index e679cf34250..39b667de413 100644
--- a/distributed/shuffle/_comms.py
+++ b/distributed/shuffle/_comms.py
@@ -67,9 +67,4 @@ async def _process(self, address: str, shards: list[bytes]) -> None:
 
         # Consider boosting total_size a bit here to account for duplication
         with self.time("send"):
-            await self.send(address, [_join_shards(shards)])
-
-
-def _join_shards(shards: list[bytes]) -> bytes:
-    # This is just there for easier profiling
-    return b"".join(shards)
+            await self.send(address, shards)
diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py
index 71b9ae8b534..70487221e94 100644
--- a/distributed/shuffle/_disk.py
+++ b/distributed/shuffle/_disk.py
@@ -4,12 +4,9 @@
 import os
 import pathlib
 import shutil
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable
+from typing import BinaryIO, Callable
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
-from distributed.shuffle._buffer import ShardsBuffer
+from distributed.shuffle._buffer import ShardsBuffer, ShardType
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.utils import log_errors
 
@@ -23,7 +20,7 @@ class DiskShardsBuffer(ShardsBuffer):
 
     **State**
 
-    -   shards: dict[str, list[bytes]]
+    -   shards: dict[str, list[ShardType]]
 
         This is our in-memory buffer of data waiting to be written to files.
 
@@ -53,8 +50,8 @@ class DiskShardsBuffer(ShardsBuffer):
     def __init__(
         self,
         directory: str,
-        dump: Callable[[Any, BinaryIO], None],
-        load: Callable[[BinaryIO], Any],
+        dump: Callable[[list[ShardType], BinaryIO], None],
+        load: Callable[[BinaryIO], list[ShardType]],
         memory_limiter: ResourceLimiter | None = None,
     ):
         super().__init__(
@@ -68,7 +65,7 @@ def __init__(
         self.dump = dump
         self.load = load
 
-    async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
+    async def _process(self, id: str, shards: list[ShardType]) -> None:
         """Write one buffer to file
 
         This function was built to offload the disk IO, but since then we've
@@ -88,11 +85,9 @@ async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
                 with open(
                     self.directory / str(id), mode="ab", buffering=100_000_000
                 ) as f:
-                    for shard in shards:
-                        self.dump(shard, f)
-                    # os.fsync(f)  # TODO: maybe?
+                    self.dump(shards, f)
 
-    def read(self, id: int | str) -> pa.Table:
+    def read(self, id: int | str) -> list[ShardType]:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -104,11 +99,7 @@ def read(self, id: int | str) -> pa.Table:
                 with open(
                     self.directory / str(id), mode="rb", buffering=100_000_000
                 ) as f:
-                    while True:
-                        try:
-                            parts.append(self.load(f))
-                        except EOFError:
-                            break
+                    parts = self.load(f)
                     size = f.tell()
         except FileNotFoundError:
             raise KeyError(id)
@@ -116,8 +107,7 @@
 
         # TODO: We could consider deleting the file at this point
         if parts:
             self.bytes_read += size
-            assert len(parts) == 1
-            return parts[0]
+            return parts
         else:
             raise KeyError(id)
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 370531466fe..5ff16844532 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -9,7 +9,7 @@
 from collections import defaultdict
 from collections.abc import Callable, Iterator
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, overload
+from typing import TYPE_CHECKING, Any, TypeVar, overload
 
 import toolz
 
@@ -19,9 +19,10 @@
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
     deserialize_schema,
-    dump_batch,
+    dump_shards,
     list_of_buffers_to_table,
-    load_arrow,
+    load_partition,
+    serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
@@ -121,12 +122,9 @@
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False
 
-        def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None:
-            return dump_batch(batch, file, self.schema)
-
         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_batch,
-            load=load_arrow,
+            dump=dump_shards,
+            load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,
         )
@@ -201,17 +199,14 @@ async def _receive(self, data: list[bytes]) -> None:
             self._exception = e
             raise
 
-    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[bytes]]:
-        table = list_of_buffers_to_table(data, self.schema)
+    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[pa.Table]]:
+        table = list_of_buffers_to_table(data)
         groups = split_by_partition(table, self.column)
         assert len(table) == sum(map(len, groups.values()))
         del data
-        return {
-            k: [batch.serialize() for batch in v.to_batches()]
-            for k, v in groups.items()
-        }
+        return {k: [v] for k, v in groups.items()}
 
-    async def _write_to_disk(self, data: dict[str, list[bytes]]) -> None:
+    async def _write_to_disk(self, data: dict[str, list[pa.Table]]) -> None:
         self.raise_if_closed()
         await self._disk_buffer.write(data)
 
@@ -234,10 +229,7 @@ def _() -> dict[str, list[bytes]]:
                 self.column,
                 self.worker_for,
             )
-            out = {
-                k: [b.serialize().to_pybytes() for b in t.to_batches()]
-                for k, t in out.items()
-            }
+            out = {k: [serialize_table(t)] for k, t in out.items()}
             return out
 
         out = await self.offload(_)
diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py
index 915e8a5babe..7880a7e2f1c 100644
--- a/distributed/shuffle/tests/test_disk_buffer.py
+++ b/distributed/shuffle/tests/test_disk_buffer.py
@@ -10,6 +10,7 @@
 
 
 def dump(data, f):
+    data = b"".join(data)
     f.write(data)
 
 
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 9e290f03a4a..e87413d44bb 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -6,9 +6,11 @@
 import random
 import shutil
 from collections import defaultdict
+from itertools import count
 from typing import Any, Mapping
 from unittest import mock
 
+import numpy as np
 import pandas as pd
 import pytest
 
@@ -20,15 +22,16 @@
 from distributed.core import PooledRPCCall
 from distributed.scheduler import Scheduler
 from distributed.scheduler import TaskState as SchedulerTaskState
+from distributed.shuffle._arrow import serialize_table
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._scheduler_extension import get_worker_for
 from distributed.shuffle._shuffle import ShuffleId, barrier_key
 from distributed.shuffle._worker_extension import (
     Shuffle,
     ShuffleWorkerExtension,
-    dump_batch,
+    dump_shards,
     list_of_buffers_to_table,
-    load_arrow,
+    load_partition,
     split_by_partition,
     split_by_worker,
 )
@@ -654,10 +657,104 @@ def test_processing_chain():
     In practice this takes place on many different workers.
     Here we verify its accuracy in a single threaded situation.
     """
+
+    class Stub:
+        def __init__(self, value: int) -> None:
+            self.value = value
+
+    counter = count()
     workers = ["a", "b", "c"]
     npartitions = 5
-    df = pd.DataFrame({"x": range(100), "y": range(100)})
-    df["_partitions"] = df.x % npartitions
+
+    # Test the processing chain with a dataframe that contains all supported dtypes
+    df = pd.DataFrame(
+        {
+            # numpy dtypes
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float64"),
+            f"col{next(counter)}": pd.array(
+                [np.datetime64("2022-01-01") + i for i in range(100)],
+                dtype="datetime64",
+            ),
+            f"col{next(counter)}": pd.array(
+                [np.timedelta64(1, "D") + i for i in range(100)], dtype="timedelta64"
+            ),
+            # FIXME: PyArrow does not support complex numbers: https://issues.apache.org/jira/browse/ARROW-638
+            # f"col{next(counter)}": pd.array(range(100), dtype="csingle"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"),
+            # Nullable dtypes
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt64"),
+            # pandas dtypes
+            f"col{next(counter)}": pd.array(
+                [np.datetime64("2022-01-01") + i for i in range(100)],
+                dtype=pd.DatetimeTZDtype(tz="Europe/Berlin"),
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Period("2022-01-01", freq="D") + i for i in range(100)],
+                dtype="period[D]",
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Interval(left=i, right=i + 2) for i in range(100)], dtype="Interval"
+            ),
+            f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"),
+            f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
+            # FIXME: PyArrow does not support sparse data: https://issues.apache.org/jira/browse/ARROW-8679
+            # f"col{next(counter)}": pd.array(
+            #     [np.nan, np.nan, 1.0, np.nan, np.nan] * 20,
+            #     dtype="Sparse[float64]",
+            # ),
+            # PyArrow dtypes
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int8[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int16[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int64[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint8[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint16[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint64[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float64[pyarrow]"),
+            f"col{next(counter)}": pd.array(
+                [pd.Timestamp.fromtimestamp(1641034800 + i) for i in range(100)],
+                dtype=pd.ArrowDtype(pa.timestamp("ms")),
+            ),
+            # FIXME: distributed#7420
+            # f"col{next(counter)}": pd.array(
+            #     ["lorem ipsum"] * 100,
+            #     dtype="string[pyarrow]",
+            # ),
+            # f"col{next(counter)}": pd.array(
+            #     ["lorem ipsum"] * 100,
+            #     dtype=pd.StringDtype("pyarrow"),
+            # ),
+            # custom objects
+            # FIXME: Serializing custom objects is not supported in P2P shuffling
+            # f"col{next(counter)}": pd.array(
+            #     [Stub(i) for i in range(100)], dtype="object"
+            # ),
+        }
+    )
+    df["_partitions"] = df.col4 % npartitions
     schema = pa.Schema.from_pandas(df)
     worker_for = {i: random.choice(workers) for i in list(range(npartitions))}
     worker_for = pd.Series(worker_for, name="_worker").astype("category")
@@ -666,16 +763,13 @@
     assert set(data) == set(worker_for.cat.categories)
     assert sum(map(len, data.values())) == len(df)
 
-    batches = {
-        worker: [b.serialize().to_pybytes() for b in t.to_batches()]
-        for worker, t in data.items()
-    }
+    batches = {worker: [serialize_table(t)] for worker, t in data.items()}
 
     # Typically we communicate to different workers at this stage
     # We then receive them back and reconstute them
 
     by_worker = {
-        worker: list_of_buffers_to_table(list_of_batches, schema)
+        worker: list_of_buffers_to_table(list_of_batches)
         for worker, list_of_batches in batches.items()
     }
     assert sum(map(len, by_worker.values())) == len(df)
@@ -687,10 +781,7 @@
     }
 
     splits_by_worker = {
-        worker: {
-            partition: [batch.serialize() for batch in t.to_batches()]
-            for partition, t in d.items()
-        }
+        worker: {partition: [t] for partition, t in d.items()}
         for worker, d in splits_by_worker.items()
     }
 
@@ -707,16 +798,21 @@
     filesystem = defaultdict(io.BytesIO)
 
     for partitions in splits_by_worker.values():
-        for partition, batches in partitions.items():
-            for batch in batches:
-                dump_batch(batch, filesystem[partition], schema)
+        for partition, tables in partitions.items():
+            dump_shards(tables, filesystem[partition])
 
     out = {}
     for k, bio in filesystem.items():
         bio.seek(0)
-        out[k] = load_arrow(bio)
-
-    assert sum(map(len, out.values())) == len(df)
+        out[k] = load_partition(bio)
+
+    shuffled_df = pd.concat(table.to_pandas() for table in out.values())
+    pd.testing.assert_frame_equal(
+        df,
+        shuffled_df,
+        check_like=True,
+        check_exact=True,
+    )
 
 
 @gen_cluster(client=True)
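
Reviewer note, not part of the patch: the new helpers all lean on the same PyArrow IPC round-trip — several streams appended back to back can be read out again one stream at a time and concatenated. The sketch below only illustrates that behaviour with the public pyarrow calls the patch itself uses; the variable names (`shards`, `buf`, `payload`) are invented for the example and nothing here imports the patched modules.

```python
# Illustrative sketch of the IPC round-trip behind dump_shards/load_partition
# and serialize_table/deserialize_table (assumes pyarrow and pandas installed).
import io

import pandas as pd
import pyarrow as pa

shards = [
    pa.Table.from_pandas(pd.DataFrame({"x": [1, 2]})),
    pa.Table.from_pandas(pd.DataFrame({"x": [3, 4]})),
]

# "dump": append each shard as its own IPC stream; the buffer ends up holding
# one end-of-stream marker per shard.
buf = io.BytesIO()
for table in shards:
    with pa.ipc.new_stream(buf, table.schema) as writer:
        writer.write_table(table)

# "load": consume one stream at a time until the end of the buffer is reached,
# then concatenate the pieces into a single table.
buf.seek(0)
end = len(buf.getvalue())
parts = []
while buf.tell() < end:
    parts.append(pa.RecordBatchStreamReader(buf).read_all())
assert pa.concat_tables(parts).num_rows == 4

# The same stream format also round-trips through plain bytes, which is what
# serialize_table/deserialize_table do for shards sent through the comm buffer.
payload = io.BytesIO()
with pa.ipc.new_stream(payload, shards[0].schema) as writer:
    writer.write_table(shards[0])
restored = pa.ipc.open_stream(pa.py_buffer(payload.getvalue())).read_all()
assert restored.equals(shards[0])
```

Because every IPC stream carries its own schema, the explicit schema plumbing (the `schema` argument of `dump_batch`, the `_dump_batch` closure, and the `schema` parameter of `list_of_buffers_to_table`) becomes unnecessary, which is why this patch removes it.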