From a762becb8ba43d78d68ea45368315f1c4728a073 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 11:05:03 +0100
Subject: [PATCH 01/23] Extend test case

---
 distributed/shuffle/tests/test_shuffle.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 7f397c8c1f8..961f8d1faad 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -6,6 +6,7 @@
 import random
 import shutil
 from collections import defaultdict
+from itertools import chain
 from typing import Any, Mapping
 from unittest import mock
@@ -661,7 +662,10 @@ def test_processing_chain():
     """
     workers = ["a", "b", "c"]
     npartitions = 5
-    df = pd.DataFrame({"x": range(100), "y": range(100)})
+    df = pd.DataFrame(
+        {"x": range(100), "y": range(100), "z": chain(range(50), range(50))}
+    )
+    df["z"] = df["z"].astype("category")
     df["_partitions"] = df.x % npartitions
     schema = pa.Schema.from_pandas(df)
     worker_for = {i: random.choice(workers) for i in list(range(npartitions))}

From 106486471a36d7a46cdf9478557f9d8aab2366c2 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 11:09:52 +0100
Subject: [PATCH 02/23] Improve typing

---
 distributed/shuffle/_buffer.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py
index 97e639f6237..f56d1fcb3da 100644
--- a/distributed/shuffle/_buffer.py
+++ b/distributed/shuffle/_buffer.py
@@ -5,14 +5,11 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, Generic, Sized, TypeVar
+from typing import Any, Generic, Sized, TypeVar

 from distributed.metrics import time
 from distributed.shuffle._limiter import ResourceLimiter

-if TYPE_CHECKING:
-    import pyarrow as pa
-
 logger = logging.getLogger("distributed.shuffle")

 ShardType = TypeVar("ShardType", bound=Sized)
@@ -96,7 +93,7 @@ def heartbeat(self) -> dict[str, Any]:
             "memory_limit": self.memory_limiter._maxvalue if self.memory_limiter else 0,
         }

-    async def process(self, id: str, shards: list[pa.Table], size: int) -> None:
+    async def process(self, id: str, shards: list[ShardType], size: int) -> None:
         try:
             start = time()
             try:

From 4c23e3a2f6d4547de981ae9469012dc5468f945d Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 11:31:19 +0100
Subject: [PATCH 03/23] Fix network sending

---
 distributed/shuffle/_arrow.py             | 25 +++++++++++++----------
 distributed/shuffle/_shuffle_extension.py |  8 +++-----
 distributed/shuffle/tests/test_shuffle.py |  6 ++----
 3 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index f7531020274..01740242db7 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -49,19 +49,11 @@ def load_arrow(file: BinaryIO) -> pa.Table:

 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
-    import io
-
     import pyarrow as pa

-    bio = io.BytesIO()
-    bio.write(schema.serialize())
-    for batch in data:
-        bio.write(batch)
-    bio.seek(0)
-    sr = pa.RecordBatchStreamReader(bio)
-    data = sr.read_all()
-    bio.close()
-    return data
+    assert len(data) == 1
+    with pa.ipc.open_stream(pa.py_buffer(data[0])) as reader:
+        return reader.read_all()


 def deserialize_schema(data: bytes) -> pa.Schema:
@@ -87,3 +79,14 @@ def deserialize_schema(data: bytes) -> pa.Schema:
     table = sr.read_all()
     bio.close()
     return table.schema
+
+
+def serialize_table(table: pa.Table) -> bytes:
+    import io
+
+    import pyarrow as pa
+
+    stream = io.BytesIO()
+    with pa.ipc.new_stream(stream, table.schema) as writer:
+        writer.write_table(table)
+    return stream.getvalue()
diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
index 37ac35893c6..9cf8a877a80 100644
--- a/distributed/shuffle/_shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -23,6 +23,7 @@
     dump_batch,
     list_of_buffers_to_table,
     load_arrow,
+    serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
@@ -203,15 +204,12 @@ async def _receive(self, data: list[bytes]) -> None:
             self._exception = e
             raise

-    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[bytes]]:
+    def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[pa.Table]]:
         table = list_of_buffers_to_table(data, self.schema)
         groups = split_by_partition(table, self.column)
         assert len(table) == sum(map(len, groups.values()))
         del data
-        return {
-            k: [batch.serialize() for batch in v.to_batches()]
-            for k, v in groups.items()
-        }
+        return {k: [serialize_table(v)] for k, v in groups.items()}

     async def _write_to_disk(self, data: dict[str, list[bytes]]) -> None:
         self.raise_if_closed()
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 961f8d1faad..b2fd46569be 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -30,6 +30,7 @@
     get_worker_for,
     list_of_buffers_to_table,
     load_arrow,
+    serialize_table,
     split_by_partition,
     split_by_worker,
 )
@@ -675,10 +676,7 @@ def test_processing_chain():
     assert set(data) == set(worker_for.cat.categories)
     assert sum(map(len, data.values())) == len(df)

-    batches = {
-        worker: [b.serialize().to_pybytes() for b in t.to_batches()]
-        for worker, t in data.items()
-    }
+    batches = {worker: [serialize_table(t)] for worker, t in data.items()}

     # Typically we communicate to different workers at this stage
     # We then receive them back and reconstute them

From ab7a1f96e7108cec9d55c497219f7e95d63ad836 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 11:54:55 +0100
Subject: [PATCH 04/23] Fix disk serialization

---
 distributed/shuffle/_arrow.py             | 14 ++++++++------
 distributed/shuffle/_disk.py              |  3 +--
 distributed/shuffle/_shuffle_extension.py | 13 ++++++-------
 distributed/shuffle/tests/test_shuffle.py | 15 ++++++---------
 4 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 01740242db7..1c64b7d5850 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -6,19 +6,21 @@
 import pyarrow as pa


-def dump_batch(batch: pa.Buffer, file: BinaryIO, schema: pa.Schema) -> None:
+def dump_table(table: pa.Table, file: BinaryIO) -> None:
     """
-    Dump a batch to file, if we're the first, also write the schema
+    Dump a table to file

-    This function is with respect to the open file object
+    Note: This function appends to the file and signals end-of-stream when done.
+    This results in multiple end-of-stream signals in a stream.

     See Also
     --------
     load_arrow
     """
-    if file.tell() == 0:
-        file.write(schema.serialize())
-    file.write(batch)
+    import pyarrow as pa
+
+    with pa.ipc.new_stream(file, table.schema) as writer:
+        writer.write_table(table)


 def load_arrow(file: BinaryIO) -> pa.Table:
diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py
index 71b9ae8b534..73ebcb2f943 100644
--- a/distributed/shuffle/_disk.py
+++ b/distributed/shuffle/_disk.py
@@ -116,8 +116,7 @@ def read(self, id: int | str) -> pa.Table:
         # TODO: We could consider deleting the file at this point
         if parts:
             self.bytes_read += size
-            assert len(parts) == 1
-            return parts[0]
+            return pa.concat_tables(parts)
         else:
             raise KeyError(id)
diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
index 9cf8a877a80..608e249248e 100644
--- a/distributed/shuffle/_shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -20,10 +20,9 @@
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
     deserialize_schema,
-    dump_batch,
+    dump_table,
     list_of_buffers_to_table,
     load_arrow,
-    serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
@@ -124,11 +123,11 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

-        def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None:
-            return dump_batch(batch, file, self.schema)
+        def _dump_table(table: pa.Table, file: BinaryIO) -> None:
+            return dump_table(table, file)

         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_batch,
+            dump=_dump_table,
             load=load_arrow,
             directory=directory,
             memory_limiter=memory_limiter_disk,
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index b2fd46569be..cb3349f6578 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -21,16 +21,16 @@
 from distributed.core import PooledRPCCall
 from distributed.scheduler import Scheduler
 from distributed.scheduler import TaskState as SchedulerTaskState
+from distributed.shuffle._arrow import serialize_table
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._shuffle_extension import (
     Shuffle,
     ShuffleId,
     ShuffleWorkerExtension,
-    dump_batch,
+    dump_table,
     get_worker_for,
     list_of_buffers_to_table,
     load_arrow,
-    serialize_table,
     split_by_partition,
     split_by_worker,
 )
@@ -694,10 +694,7 @@ def test_processing_chain():

     splits_by_worker = {
-        worker: {
-            partition: [batch.serialize() for batch in t.to_batches()]
-            for partition, t in d.items()
-        }
+        worker: {partition: [t] for partition, t in d.items()}
         for worker, d in splits_by_worker.items()
     }

@@ -714,9 +711,9 @@ def test_processing_chain():
     filesystem = defaultdict(io.BytesIO)

     for partitions in splits_by_worker.values():
-        for partition, batches in partitions.items():
-            for batch in batches:
-                dump_batch(batch, filesystem[partition], schema)
+        for partition, tables in partitions.items():
+            for table in tables:
+                dump_table(table, filesystem[partition])

     out = {}
     for k, bio in filesystem.items():

From 3d6cbe8be70ec1d7f50ffc04d8889977f10b60ce Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 12:03:56 +0100
Subject: [PATCH 05/23] Fix

---
 distributed/shuffle/_disk.py              | 2 ++
 distributed/shuffle/_shuffle_extension.py | 6 ++----
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py
index 73ebcb2f943..dea144ec257 100644
--- a/distributed/shuffle/_disk.py
+++ b/distributed/shuffle/_disk.py
@@ -93,6 +93,8 @@ async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
                     # os.fsync(f)  # TODO: maybe?

     def read(self, id: int | str) -> pa.Table:
+        import pyarrow as pa
+
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
index 608e249248e..6be0b66479a 100644
--- a/distributed/shuffle/_shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -23,6 +23,7 @@
     dump_table,
     list_of_buffers_to_table,
     load_arrow,
+    serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
@@ -233,10 +234,7 @@ def _() -> dict[str, list[bytes]]:
                 self.column,
                 self.worker_for,
             )
-            out = {
-                k: [b.serialize().to_pybytes() for b in t.to_batches()]
-                for k, t in out.items()
-            }
+            out = {k: [serialize_table(t)] for k, t in out.items()}
             return out

         out = await self.offload(_)

From 5dff0d778040f7a83e5e43e86aaa14788f44353c Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 15 Dec 2022 16:33:58 +0100
Subject: [PATCH 06/23] Fix changes to buffer interfaces

---
 distributed/shuffle/_arrow.py                 | 37 ++++++++++---------
 distributed/shuffle/_disk.py                  | 31 +++++-----------
 distributed/shuffle/_shuffle_extension.py     | 12 +++---
 distributed/shuffle/tests/test_disk_buffer.py |  1 +
 distributed/shuffle/tests/test_shuffle.py     |  9 ++---
 5 files changed, 41 insertions(+), 49 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 1c64b7d5850..82566e80880 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -6,12 +6,12 @@
 import pyarrow as pa


-def dump_table(table: pa.Table, file: BinaryIO) -> None:
+def dump_table_batch(tables: list[pa.Table], file: BinaryIO) -> None:
     """
-    Dump a table to file
+    Dump multiple tables to the file

-    Note: This function appends to the file and signals end-of-stream when done.
-    This results in multiple end-of-stream signals in a stream.
+    Note: This function appends to the file and dumps each table as an individual stream.
+    This results in multiple end-of-stream signals in the file.

     See Also
     --------
     load_arrow
     """
     import pyarrow as pa

-    with pa.ipc.new_stream(file, table.schema) as writer:
-        writer.write_table(table)
+    for table in tables:
+        with pa.ipc.new_stream(file, table.schema) as writer:
+            writer.write_table(table)


-def load_arrow(file: BinaryIO) -> pa.Table:
-    """Load batched data written to file back out into a table again
+def load_into_table(file: BinaryIO) -> pa.Table:
+    """Load batched data written to file back out into a single table

     Example
     -------
-    >>> t = pa.Table.from_pandas(df)  # doctest: +SKIP
+    >>> tables = [pa.Table.from_pandas(df), pa.Table.from_pandas(df2)]  # doctest: +SKIP
     >>> with open("myfile", mode="wb") as f:  # doctest: +SKIP
-    ...     for batch in t.to_batches():  # doctest: +SKIP
+    ...     for table in tables:  # doctest: +SKIP
-    ...         dump_batch(batch, f, schema=t.schema)  # doctest: +SKIP
+    ...         dump_table_batch(tables, f, schema=t.schema)  # doctest: +SKIP

     >>> with open("myfile", mode="rb") as f:  # doctest: +SKIP
-    ...     t = load_arrow(f)  # doctest: +SKIP
+    ...     t = load_into_table(f)  # doctest: +SKIP

     See Also
     --------
-    dump_batch
+    dump_table_batch
     """
     import pyarrow as pa

+    tables = []
     try:
-        sr = pa.RecordBatchStreamReader(file)
-        return sr.read_all()
-    except Exception:
-        raise EOFError
+        while True:
+            sr = pa.RecordBatchStreamReader(file)
+            tables.append(sr.read_all())
+    except pa.ArrowInvalid:
+        return pa.concat_tables(tables)


 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py
index dea144ec257..70487221e94 100644
--- a/distributed/shuffle/_disk.py
+++ b/distributed/shuffle/_disk.py
@@ -4,12 +4,9 @@
 import os
 import pathlib
 import shutil
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable
+from typing import BinaryIO, Callable

-if TYPE_CHECKING:
-    import pyarrow as pa
-
-from distributed.shuffle._buffer import ShardsBuffer
+from distributed.shuffle._buffer import ShardsBuffer, ShardType
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.utils import log_errors
@@ -23,7 +20,7 @@ class DiskShardsBuffer(ShardsBuffer):

     **State**

-    - shards: dict[str, list[bytes]]
+    - shards: dict[str, list[ShardType]]

       This is our in-memory buffer of data waiting to be written to files.
@@ -53,8 +50,8 @@ class DiskShardsBuffer(ShardsBuffer):
     def __init__(
         self,
         directory: str,
-        dump: Callable[[Any, BinaryIO], None],
-        load: Callable[[BinaryIO], Any],
+        dump: Callable[[list[ShardType], BinaryIO], None],
+        load: Callable[[BinaryIO], list[ShardType]],
         memory_limiter: ResourceLimiter | None = None,
     ):
         super().__init__(
@@ -68,7 +65,7 @@ def __init__(
         self.dump = dump
         self.load = load

-    async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
+    async def _process(self, id: str, shards: list[ShardType]) -> None:
         """Write one buffer to file

         This function was built to offload the disk IO, but since then we've
@@ -88,13 +85,9 @@ async def _process(self, id: str, shards: list[pa.Buffer]) -> None:
                 with open(
                     self.directory / str(id), mode="ab", buffering=100_000_000
                 ) as f:
-                    for shard in shards:
-                        self.dump(shard, f)
-                    # os.fsync(f)  # TODO: maybe?
+                    self.dump(shards, f)

-    def read(self, id: int | str) -> pa.Table:
-        import pyarrow as pa
-
+    def read(self, id: int | str) -> list[ShardType]:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -106,11 +99,7 @@ def read(self, id: int | str) -> pa.Table:
             with open(
                 self.directory / str(id), mode="rb", buffering=100_000_000
             ) as f:
-                while True:
-                    try:
-                        parts.append(self.load(f))
-                    except EOFError:
-                        break
+                parts = self.load(f)
                 size = f.tell()
         except FileNotFoundError:
             raise KeyError(id)
@@ -118,7 +107,7 @@ def read(self, id: int | str) -> pa.Table:
         # TODO: We could consider deleting the file at this point
         if parts:
             self.bytes_read += size
-            return pa.concat_tables(parts)
+            return parts
         else:
             raise KeyError(id)
diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
index 6be0b66479a..fe0940e1845 100644
--- a/distributed/shuffle/_shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -20,9 +20,9 @@
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
     deserialize_schema,
-    dump_table,
+    dump_table_batch,
     list_of_buffers_to_table,
-    load_arrow,
+    load_into_table,
     serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
@@ -124,12 +124,12 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

-        def _dump_table(table: pa.Table, file: BinaryIO) -> None:
-            return dump_table(table, file)
+        def _dump_table_batch(tables: list[pa.Table], file: BinaryIO) -> None:
+            return dump_table_batch(tables, file)

         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_table,
-            load=load_arrow,
+            dump=_dump_table_batch,
+            load=load_into_table,
             directory=directory,
             memory_limiter=memory_limiter_disk,
         )
diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py
index 915e8a5babe..7880a7e2f1c 100644
--- a/distributed/shuffle/tests/test_disk_buffer.py
+++ b/distributed/shuffle/tests/test_disk_buffer.py
@@ -10,6 +10,7 @@

 def dump(data, f):
+    data = b"".join(data)
     f.write(data)
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index cb3349f6578..cdfc3571319 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -27,10 +27,10 @@
     Shuffle,
     ShuffleId,
     ShuffleWorkerExtension,
-    dump_table,
+    dump_table_batch,
     get_worker_for,
     list_of_buffers_to_table,
-    load_arrow,
+    load_into_table,
     split_by_partition,
     split_by_worker,
 )
@@ -712,13 +712,12 @@ def test_processing_chain():

     for partitions in splits_by_worker.values():
         for partition, tables in partitions.items():
-            for table in tables:
-                dump_table(table, filesystem[partition])
+            dump_table_batch(tables, filesystem[partition])

     out = {}
     for k, bio in filesystem.items():
         bio.seek(0)
-        out[k] = load_arrow(bio)
+        out[k] = load_into_table(bio)

     assert sum(map(len, out.values())) == len(df)

From d4d2e4c1b4018a4c817c5bb30c382264b6cef457 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 15:22:48 +0100
Subject: [PATCH 07/23] Add all dtypes

---
 distributed/shuffle/tests/test_shuffle.py | 60 +++++++++++++++++++++--
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 7216b7f1c55..5dbf03fb366 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -6,10 +6,11 @@
 import random
 import shutil
 from collections import defaultdict
-from itertools import chain
+from itertools import count
 from typing import Any, Mapping
 from unittest import mock

+import numpy as np
 import pandas as pd
 import pytest
@@ -656,13 +657,63 @@ def test_processing_chain():
     In practice this takes place on many different workers.
     Here we verify its accuracy in a single threaded situation.
     """
+    counter = count()
     workers = ["a", "b", "c"]
     npartitions = 5
     df = pd.DataFrame(
-        {"x": range(100), "y": range(100), "z": chain(range(50), range(50))}
+        {
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool"),
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float64"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="csingle"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"),
+            f"col{next(counter)}": pd.array(
+                [np.datetime64("2022-01-01")] * 100, dtype="datetime64"
+            ),
+            f"col{next(counter)}": pd.array(
+                [np.datetime64("2022-01-01")] * 100,
+                dtype=pd.DatetimeTZDtype(tz="UTC"),
+            ),
+            f"col{next(counter)}": pd.array(
+                [np.timedelta64(1, "D")] * 100, dtype="timedelta64"
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Period("2022-01-01", freq="D")] * 100, dtype="period[D]"
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Interval(left=0, right=5)] * 100, dtype="Interval"
+            ),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="Int64"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt8"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt16"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt32"),
+            f"col{next(counter)}": pd.array(range(100), dtype="UInt64"),
+            f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"),
+            # f"col{next(counter)}": pd.array(
+            #     [np.nan, np.nan, 1.0, np.nan, np.nan] * 20,
+            #     dtype="Sparse[float64]",
+            # ),
+            f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
+            # f"col{next(counter)}": pd.array(
+            #     [object() for _ in range(100)], dtype="object"
+            # ),
+        }
     )
-    df["z"] = df["z"].astype("category")
-    df["_partitions"] = df.x % npartitions
+    df["_partitions"] = df.col3 % npartitions
     schema = pa.Schema.from_pandas(df)
     worker_for = {i: random.choice(workers) for i in list(range(npartitions))}
     worker_for = pd.Series(worker_for, name="_worker").astype("category")

From 41ec668110669ed4a4cc4a7714650e7e128fe90b Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 15:27:26 +0100
Subject: [PATCH 08/23] Stub class

---
 distributed/shuffle/tests/test_shuffle.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 5dbf03fb366..0b4abb04d2e 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -657,6 +657,11 @@ def test_processing_chain():
     In practice this takes place on many different workers.
     Here we verify its accuracy in a single threaded situation.
     """
+
+    class Stub:
+        def __init__(self, value: int) -> None:
+            self.value = value
+
     counter = count()
     workers = ["a", "b", "c"]
     npartitions = 5
@@ -709,7 +714,7 @@ def test_processing_chain():
             # ),
             f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
             # f"col{next(counter)}": pd.array(
-            #     [object() for _ in range(100)], dtype="object"
+            #     [Stub(i) for i in range(100)], dtype="object"
             # ),
         }
     )

From c10de32c2c2dcd89d4d0d4ce3ea65df1f8310744 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 16:13:06 +0100
Subject: [PATCH 09/23] Renaming

---
 distributed/shuffle/_arrow.py             | 24 +++++++++++------------
 distributed/shuffle/_worker_extension.py  |  8 ++++----
 distributed/shuffle/tests/test_shuffle.py |  8 ++++----
 3 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 82566e80880..2d323045767 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -6,50 +6,50 @@
 import pyarrow as pa


-def dump_table_batch(tables: list[pa.Table], file: BinaryIO) -> None:
+def dump_shards(shards: list[pa.Table], file: BinaryIO) -> None:
     """
-    Dump multiple tables to the file
+    Write multiple shard tables to the file

     Note: This function appends to the file and dumps each table as an individual stream.
     This results in multiple end-of-stream signals in the file.

     See Also
     --------
-    load_arrow
+    load_partition
     """
     import pyarrow as pa

-    for table in tables:
+    for table in shards:
         with pa.ipc.new_stream(file, table.schema) as writer:
             writer.write_table(table)


-def load_into_table(file: BinaryIO) -> pa.Table:
-    """Load batched data written to file back out into a single table
+def load_partition(file: BinaryIO) -> pa.Table:
+    """Load partition data written to file back out into a single table

     Example
     -------
     >>> tables = [pa.Table.from_pandas(df), pa.Table.from_pandas(df2)]  # doctest: +SKIP
     >>> with open("myfile", mode="wb") as f:  # doctest: +SKIP
     ...     for table in tables:  # doctest: +SKIP
-    ...         dump_shards(tables, f, schema=t.schema)  # doctest: +SKIP
+    ...         dump_shards(tables, f, schema=t.schema)  # doctest: +SKIP

     >>> with open("myfile", mode="rb") as f:  # doctest: +SKIP
-    ...     t = load_into_table(f)  # doctest: +SKIP
+    ...     t = load_partition(f)  # doctest: +SKIP

     See Also
     --------
-    dump_table_batch
+    dump_shards
     """
     import pyarrow as pa

-    tables = []
+    shards = []
     try:
         while True:
             sr = pa.RecordBatchStreamReader(file)
-            tables.append(sr.read_all())
+            shards.append(sr.read_all())
     except pa.ArrowInvalid:
-        return pa.concat_tables(tables)
+        return pa.concat_tables(shards)


 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 9632277c7ed..1235f60161f 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -19,9 +19,9 @@
 from distributed.protocol import to_serialize
 from distributed.shuffle._arrow import (
     deserialize_schema,
-    dump_table_batch,
+    dump_shards,
     list_of_buffers_to_table,
-    load_into_table,
+    load_partition,
     serialize_table,
 )
 from distributed.shuffle._comms import CommShardsBuffer
@@ -123,11 +123,11 @@ def __init__(
         self.closed = False

         def _dump_table_batch(tables: list[pa.Table], file: BinaryIO) -> None:
-            return dump_table_batch(tables, file)
+            return dump_shards(tables, file)

         self._disk_buffer = DiskShardsBuffer(
             dump=_dump_table_batch,
-            load=load_into_table,
+            load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,
         )
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 0b4abb04d2e..59c9b50745e 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -29,9 +29,9 @@
 from distributed.shuffle._worker_extension import (
     Shuffle,
     ShuffleWorkerExtension,
-    dump_table_batch,
+    dump_shards,
     list_of_buffers_to_table,
-    load_into_table,
+    load_partition,
     split_by_partition,
     split_by_worker,
 )
@@ -763,12 +763,12 @@ def test_processing_chain():

     for partitions in splits_by_worker.values():
         for partition, tables in partitions.items():
-            dump_table_batch(tables, filesystem[partition])
+            dump_shards(tables, filesystem[partition])

     out = {}
     for k, bio in filesystem.items():
         bio.seek(0)
-        out[k] = load_into_table(bio)
+        out[k] = load_partition(bio)

     assert sum(map(len, out.values())) == len(df)
     assert all(v.to_pandas().dtypes.equals(df.dtypes) for v in out.values())

From fe1fcbdc66f630db9020b407e12aa2e2c947e048 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 16:21:41 +0100
Subject: [PATCH 10/23] improve check

---
 distributed/shuffle/_arrow.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 2d323045767..c1a6f677b74 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -48,8 +48,12 @@ def load_partition(file: BinaryIO) -> pa.Table:
         while True:
             sr = pa.RecordBatchStreamReader(file)
             shards.append(sr.read_all())
-    except pa.ArrowInvalid:
-        return pa.concat_tables(shards)
+    # Since we write multiple streams to the same file, we have to read until
+    # there is nothing to read anymore. At that point, pa.ArrowInvalid is raised
+    except pa.ArrowInvalid as e:
+        if str(e) == "Tried reading schema message, was null or length 0":
+            return pa.concat_tables(shards)
+        raise


 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:

From d734ec3af741642c218d83d4046aa9922d16c048 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 17:13:27 +0100
Subject: [PATCH 11/23] Naming

---
 distributed/shuffle/_worker_extension.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 1235f60161f..e5e206e3eeb 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -122,11 +122,11 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

-        def _dump_table_batch(tables: list[pa.Table], file: BinaryIO) -> None:
-            return dump_shards(tables, file)
+        def _dump_shards(shards: list[pa.Table], file: BinaryIO) -> None:
+            return dump_shards(shards, file)

         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_table_batch,
+            dump=_dump_shards,
             load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,

From a224f3b2def65945c55893de5d13101d79facbf2 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 17:28:43 +0100
Subject: [PATCH 12/23] improve test

---
 distributed/shuffle/tests/test_shuffle.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 59c9b50745e..e531bad3f84 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -665,6 +665,8 @@ def __init__(self, value: int) -> None:
     counter = count()
     workers = ["a", "b", "c"]
     npartitions = 5
+
+    # Test the processing chain with a dataframe that contains all possible dtypes
     df = pd.DataFrame(
         {
             f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool"),
@@ -684,20 +686,22 @@ def __init__(self, value: int) -> None:
             # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"),
             # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"),
             f"col{next(counter)}": pd.array(
-                [np.datetime64("2022-01-01")] * 100, dtype="datetime64"
+                [np.datetime64("2022-01-01") + i for i in range(100)],
+                dtype="datetime64",
             ),
             f"col{next(counter)}": pd.array(
-                [np.datetime64("2022-01-01")] * 100,
+                [np.datetime64("2022-01-01") + i for i in range(100)],
                 dtype=pd.DatetimeTZDtype(tz="UTC"),
             ),
             f"col{next(counter)}": pd.array(
-                [np.timedelta64(1, "D")] * 100, dtype="timedelta64"
+                [np.timedelta64(1, "D") + i for i in range(100)], dtype="timedelta64"
             ),
             f"col{next(counter)}": pd.array(
-                [pd.Period("2022-01-01", freq="D")] * 100, dtype="period[D]"
+                [pd.Period("2022-01-01", freq="D") + i for i in range(100)],
+                dtype="period[D]",
             ),
             f"col{next(counter)}": pd.array(
-                [pd.Interval(left=0, right=5)] * 100, dtype="Interval"
+                [pd.Interval(left=i, right=i + 2) for i in range(100)], dtype="Interval"
             ),
             f"col{next(counter)}": pd.array(range(100), dtype="Int8"),
             f"col{next(counter)}": pd.array(range(100), dtype="Int16"),

From 841144d6574d65611c83f3ae14f03c8b0f95cd1a Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 17:29:07 +0100
Subject: [PATCH 13/23] Minor

---
 distributed/shuffle/tests/test_shuffle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index e531bad3f84..91208ce3f58 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -666,7 +666,7 @@ def __init__(self, value: int) -> None:
     workers = ["a", "b", "c"]
     npartitions = 5

-    # Test the processing chain with a dataframe that contains all possible dtypes
+    # Test the processing chain with a dataframe that contains all supported dtypes
     df = pd.DataFrame(
         {
             f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool"),

From 4ddfbc6bf5fefa3c2031add1b949c895019b388b Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 20:10:05 +0100
Subject: [PATCH 14/23] Fix receiving

---
 distributed/shuffle/_arrow.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index c1a6f677b74..43c2241a076 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -58,11 +58,21 @@ def load_partition(file: BinaryIO) -> pa.Table:

 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
+    import io
+
     import pyarrow as pa

+    tables = []
-    assert len(data) == 1
-    with pa.ipc.open_stream(pa.py_buffer(data[0])) as reader:
-        return reader.read_all()
+    buffer = data[0]
+    with io.BytesIO(buffer) as stream:
+        while True:
+            try:
+                with pa.ipc.open_stream(stream) as reader:
+                    tables.append(reader.read_all())
+            except Exception:
+                break
+    return pa.concat_tables(tables)


 def deserialize_schema(data: bytes) -> pa.Schema:

From 5a21c5188c98d3f84cf1882c57e771ef50bcfd19 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 16 Dec 2022 20:14:21 +0100
Subject: [PATCH 15/23] Simplify

---
 distributed/shuffle/_arrow.py | 14 +++-----------
 distributed/shuffle/_comms.py |  7 +------
 2 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 43c2241a076..2a145adf472 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -58,20 +58,12 @@ def load_partition(file: BinaryIO) -> pa.Table:

 def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
-    import io
-
     import pyarrow as pa

     tables = []
-    assert len(data) == 1
-    buffer = data[0]
-    with io.BytesIO(buffer) as stream:
-        while True:
-            try:
-                with pa.ipc.open_stream(stream) as reader:
-                    tables.append(reader.read_all())
-            except Exception:
-                break
+    for buffer in data:
+        with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
+            tables.append(reader.read_all())
     return pa.concat_tables(tables)
diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py
index e679cf34250..39b667de413 100644
--- a/distributed/shuffle/_comms.py
+++ b/distributed/shuffle/_comms.py
@@ -67,9 +67,4 @@ async def _process(self, address: str, shards: list[bytes]) -> None:
         # Consider boosting total_size a bit here to account for duplication

         with self.time("send"):
-            await self.send(address, [_join_shards(shards)])
-
-
-def _join_shards(shards: list[bytes]) -> bytes:
-    # This is just there for easier profiling
-    return b"".join(shards)
+            await self.send(address, shards)

From b5c8e143da56865799a00e30ce57d84bb65d4dc0 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 19 Dec 2022 18:29:11 +0100
Subject: [PATCH 16/23] Add pyarrow dtypes

---
 distributed/shuffle/tests/test_shuffle.py | 56 +++++++++++++----------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 91208ce3f58..be00c6d1f64 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -666,11 +666,12 @@ def __init__(self, value: int) -> None:
     workers = ["a", "b", "c"]
     npartitions = 5

+    # FIXME: csingle, cdouble, clongdouble, sparse and object not supported
     # Test the processing chain with a dataframe that contains all supported dtypes
     df = pd.DataFrame(
         {
+            # numpy dtypes
             f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool"),
-            f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"),
             f"col{next(counter)}": pd.array(range(100), dtype="int8"),
             f"col{next(counter)}": pd.array(range(100), dtype="int16"),
             f"col{next(counter)}": pd.array(range(100), dtype="int32"),
@@ -682,27 +683,15 @@ def __init__(self, value: int) -> None:
             f"col{next(counter)}": pd.array(range(100), dtype="float16"),
             f"col{next(counter)}": pd.array(range(100), dtype="float32"),
             f"col{next(counter)}": pd.array(range(100), dtype="float64"),
-            # f"col{next(counter)}": pd.array(range(100), dtype="csingle"),
-            # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"),
-            # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"),
             f"col{next(counter)}": pd.array(
                 [np.datetime64("2022-01-01") + i for i in range(100)],
                 dtype="datetime64",
             ),
-            f"col{next(counter)}": pd.array(
-                [np.datetime64("2022-01-01") + i for i in range(100)],
-                dtype=pd.DatetimeTZDtype(tz="UTC"),
-            ),
             f"col{next(counter)}": pd.array(
                 [np.timedelta64(1, "D") + i for i in range(100)], dtype="timedelta64"
             ),
-            f"col{next(counter)}": pd.array(
-                [pd.Period("2022-01-01", freq="D") + i for i in range(100)],
-                dtype="period[D]",
-            ),
-            f"col{next(counter)}": pd.array(
-                [pd.Interval(left=i, right=i + 2) for i in range(100)], dtype="Interval"
-            ),
+            # Nullable dtypes
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"),
             f"col{next(counter)}": pd.array(range(100), dtype="Int8"),
             f"col{next(counter)}": pd.array(range(100), dtype="Int16"),
             f"col{next(counter)}": pd.array(range(100), dtype="Int32"),
@@ -711,18 +700,39 @@ def __init__(self, value: int) -> None:
             f"col{next(counter)}": pd.array(range(100), dtype="UInt16"),
             f"col{next(counter)}": pd.array(range(100), dtype="UInt32"),
             f"col{next(counter)}": pd.array(range(100), dtype="UInt64"),
+            # pandas dtypes
+            f"col{next(counter)}": pd.array(
+                [np.datetime64("2022-01-01") + i for i in range(100)],
+                dtype=pd.DatetimeTZDtype(tz="Europe/Berlin"),
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Period("2022-01-01", freq="D") + i for i in range(100)],
+                dtype="period[D]",
+            ),
+            f"col{next(counter)}": pd.array(
+                [pd.Interval(left=i, right=i + 2) for i in range(100)], dtype="Interval"
+            ),
             f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"),
             f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
+            # PyArrow dtypes
+            f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int8[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int16[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="int64[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint8[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint16[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="uint64[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float32[pyarrow]"),
+            f"col{next(counter)}": pd.array(range(100), dtype="float64[pyarrow]"),
+            f"col{next(counter)}": pd.array(
+                [pd.Timestamp.fromtimestamp(1641034800 + i) for i in range(100)],
+                dtype=pd.ArrowDtype(pa.timestamp("ms")),
+            ),
         }
     )
-    df["_partitions"] = df.col3 % npartitions
+    df["_partitions"] = df.col4 % npartitions
     schema = pa.Schema.from_pandas(df)
     worker_for = {i: random.choice(workers) for i in list(range(npartitions))}
     worker_for = pd.Series(worker_for, name="_worker").astype("category")

From 9101b975deff58145fc19f8dcd3c152261a6dc49 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 19 Dec 2022 18:47:09 +0100
Subject: [PATCH 17/23] string[pyarrow]

---
 distributed/shuffle/tests/test_shuffle.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index be00c6d1f64..bf59a1a56ab 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -730,6 +730,15 @@ def __init__(self, value: int) -> None:
                 [pd.Timestamp.fromtimestamp(1641034800 + i) for i in range(100)],
                 dtype=pd.ArrowDtype(pa.timestamp("ms")),
             ),
+            # FIXME: distributed#7420
+            # f"col{next(counter)}": pd.array(
+            #     ["lorem ipsum"] * 100,
+            #     dtype="string[pyarrow]",
+            # ),
+            # f"col{next(counter)}": pd.array(
+            #     ["lorem ipsum"] * 100,
+            #     dtype=pd.StringDtype("pyarrow"),
+            # ),
         }
     )
     df["_partitions"] = df.col4 % npartitions

From 60257dec858e63f76e45925d5f11c7bdf1b22060 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Mon, 19 Dec 2022 18:53:22 +0100
Subject: [PATCH 18/23] Add commented-out blocks for unsupported dtypes

---
 distributed/shuffle/tests/test_shuffle.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index bf59a1a56ab..3f30ac101e4 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -666,7 +666,6 @@ def __init__(self, value: int) -> None:
     workers = ["a", "b", "c"]
     npartitions = 5

-    # FIXME: csingle, cdouble, clongdouble, sparse and object not supported
     # Test the processing chain with a dataframe that contains all supported dtypes
     df = pd.DataFrame(
         {
@@ -690,6 +689,10 @@ def __init__(self, value: int) -> None:
             f"col{next(counter)}": pd.array(
                 [np.timedelta64(1, "D") + i for i in range(100)], dtype="timedelta64"
             ),
+            # FIXME: PyArrow does not support complex numbers: https://issues.apache.org/jira/browse/ARROW-638
+            # f"col{next(counter)}": pd.array(range(100), dtype="csingle"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="cdouble"),
+            # f"col{next(counter)}": pd.array(range(100), dtype="clongdouble"),
             # Nullable dtypes
             f"col{next(counter)}": pd.array([True, False] * 50, dtype="boolean"),
             f"col{next(counter)}": pd.array(range(100), dtype="Int8"),
@@ -714,6 +717,11 @@ def __init__(self, value: int) -> None:
             ),
             f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"),
             f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"),
+            # FIXME: PyArrow does not support sparse data: https://issues.apache.org/jira/browse/ARROW-8679
+            # f"col{next(counter)}": pd.array(
+            #     [np.nan, np.nan, 1.0, np.nan, np.nan] * 20,
+            #     dtype="Sparse[float64]",
+            # ),
             # PyArrow dtypes
             f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool[pyarrow]"),
             f"col{next(counter)}": pd.array(range(100), dtype="int8[pyarrow]"),
             f"col{next(counter)}": pd.array(range(100), dtype="int16[pyarrow]"),
             f"col{next(counter)}": pd.array(range(100), dtype="int32[pyarrow]"),
@@ -739,6 +747,11 @@ def __init__(self, value: int) -> None:
             #     ["lorem ipsum"] * 100,
             #     dtype=pd.StringDtype("pyarrow"),
             # ),
+            # custom objects
+            # FIXME: Serializing custom objects is not supported in P2P shuffling
+            # f"col{next(counter)}": pd.array(
+            #     [Stub(i) for i in range(100)], dtype="object"
+            # ),
         }
     )
     df["_partitions"] = df.col4 % npartitions

From c5d996da159c0cf8db9fa55216a068a6dba26973 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Tue, 20 Dec 2022 15:44:05 +0100
Subject: [PATCH 19/23] Remove unnecessary param

---
 distributed/shuffle/_arrow.py             | 4 ++--
 distributed/shuffle/_worker_extension.py  | 2 +-
 distributed/shuffle/tests/test_shuffle.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 2a145adf472..8c86d523a0f 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -32,7 +32,7 @@ def load_partition(file: BinaryIO) -> pa.Table:
     >>> tables = [pa.Table.from_pandas(df), pa.Table.from_pandas(df2)]  # doctest: +SKIP
     >>> with open("myfile", mode="wb") as f:  # doctest: +SKIP
     ...     for table in tables:  # doctest: +SKIP
-    ...         dump_shards(tables, f, schema=t.schema)  # doctest: +SKIP
+    ...         dump_shards(tables, f)  # doctest: +SKIP

     >>> with open("myfile", mode="rb") as f:  # doctest: +SKIP
     ...     t = load_partition(f)  # doctest: +SKIP
@@ -56,7 +56,7 @@ def load_partition(file: BinaryIO) -> pa.Table:
         raise


-def list_of_buffers_to_table(data: list[bytes], schema: pa.Schema) -> pa.Table:
+def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
     import pyarrow as pa
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index e5e206e3eeb..750b4fec8ec 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -203,7 +203,7 @@ async def _receive(self, data: list[bytes]) -> None:
             raise

     def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[pa.Table]]:
-        table = list_of_buffers_to_table(data, self.schema)
+        table = list_of_buffers_to_table(data)
         groups = split_by_partition(table, self.column)
         assert len(table) == sum(map(len, groups.values()))
         del data
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 3f30ac101e4..58c97689460 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -769,7 +769,7 @@ def __init__(self, value: int) -> None:

     by_worker = {
-        worker: list_of_buffers_to_table(list_of_batches, schema)
+        worker: list_of_buffers_to_table(list_of_batches)
         for worker, list_of_batches in batches.items()
     }
     assert sum(map(len, by_worker.values())) == len(df)

From 7a777f47425b2d966a4ed368fb11623ab16442d6 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Tue, 20 Dec 2022 17:41:36 +0100
Subject: [PATCH 20/23] Deserialization method

---
 distributed/shuffle/_arrow.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 8c86d523a0f..db66879526f 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -62,8 +62,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:

     tables = []
     for buffer in data:
-        with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
-            tables.append(reader.read_all())
+        tables.append(deserialize_table(buffer))
     return pa.concat_tables(tables)
@@ -101,3 +100,10 @@ def serialize_table(table: pa.Table) -> bytes:
     with pa.ipc.new_stream(stream, table.schema) as writer:
         writer.write_table(table)
     return stream.getvalue()
+
+
+def deserialize_table(buffer: bytes) -> pa.Table:
+    import pyarrow as pa
+
+    with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
+        return reader.read_all()

From 9a2576f823bc6040ab8835fdf6b9238150cdf6ad Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Wed, 21 Dec 2022 09:25:10 +0100
Subject: [PATCH 21/23] Apply suggestions from code review

Co-authored-by: James Bourbeau

---
 distributed/shuffle/_arrow.py            | 5 +----
 distributed/shuffle/_worker_extension.py | 2 --
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index db66879526f..69aba85ad7e 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -60,10 +60,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
     import pyarrow as pa

-    tables = []
-    for buffer in data:
-        tables.append(deserialize_table(buffer))
-    return pa.concat_tables(tables)
+    return pa.concat_tables(deserialize_table(buffer) for buffer in data)
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 750b4fec8ec..1bd0f4e3ef0 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -122,8 +122,6 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

-        def _dump_shards(shards: list[pa.Table], file: BinaryIO) -> None:
-            return dump_shards(shards, file)

         self._disk_buffer = DiskShardsBuffer(
             dump=_dump_shards,
             load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,

From b825d41bb1755d43e4d79aef1fe1c6474a594991 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Wed, 21 Dec 2022 10:20:25 +0100
Subject: [PATCH 22/23] Improve loading

---
 distributed/shuffle/_arrow.py            | 20 ++++++++++----------
 distributed/shuffle/_worker_extension.py |  5 ++---
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py
index 69aba85ad7e..72b4c23fd7e 100644
--- a/distributed/shuffle/_arrow.py
+++ b/distributed/shuffle/_arrow.py
@@ -41,19 +41,19 @@ def load_partition(file: BinaryIO) -> pa.Table:
     --------
     dump_shards
     """
+    import os
+
     import pyarrow as pa

+    pos = file.tell()
+    file.seek(0, os.SEEK_END)
+    end = file.tell()
+    file.seek(pos)
     shards = []
-    try:
-        while True:
-            sr = pa.RecordBatchStreamReader(file)
-            shards.append(sr.read_all())
-    # Since we write multiple streams to the same file, we have to read until
-    # there is nothing to read anymore. At that point, pa.ArrowInvalid is raised
-    except pa.ArrowInvalid as e:
-        if str(e) == "Tried reading schema message, was null or length 0":
-            return pa.concat_tables(shards)
-        raise
+    while file.tell() < end:
+        sr = pa.RecordBatchStreamReader(file)
+        shards.append(sr.read_all())
+    return pa.concat_tables(shards)
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 1bd0f4e3ef0..5ff16844532 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -9,7 +9,7 @@
 from collections import defaultdict
 from collections.abc import Callable, Iterator
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, overload
+from typing import TYPE_CHECKING, Any, TypeVar, overload

 import toolz
@@ -122,9 +122,8 @@ def __init__(
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
         self.closed = False

         self._disk_buffer = DiskShardsBuffer(
-            dump=_dump_shards,
+            dump=dump_shards,
             load=load_partition,
             directory=directory,
             memory_limiter=memory_limiter_disk,

From 5fb858d17ce68c59fa18e0a760a67b1b36a77b45 Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Wed, 21 Dec 2022 10:51:06 +0100
Subject: [PATCH 23/23] Improve test

---
 distributed/shuffle/tests/test_shuffle.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 58c97689460..e87413d44bb 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -806,8 +806,13 @@ def __init__(self, value: int) -> None:
         bio.seek(0)
         out[k] = load_partition(bio)

-    assert sum(map(len, out.values())) == len(df)
-    assert all(v.to_pandas().dtypes.equals(df.dtypes) for v in out.values())
+    shuffled_df = pd.concat(table.to_pandas() for table in out.values())
+    pd.testing.assert_frame_equal(
+        df,
+        shuffled_df,
+        check_like=True,
+        check_exact=True,
+    )


 @gen_cluster(client=True)