277 changes: 277 additions & 0 deletions distributed/shuffle/_core.py
@@ -0,0 +1,277 @@
from __future__ import annotations

import abc
import asyncio
import contextlib
import itertools
import time
from collections import defaultdict
from collections.abc import Callable, Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import to_serialize
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter

if TYPE_CHECKING:
import pandas as pd
from typing_extensions import TypeAlias

# avoid circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")

NDIndex: TypeAlias = tuple[int, ...]

ShuffleId = NewType("ShuffleId", str)


class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
def __init__(
self,
id: ShuffleId,
run_id: int,
output_workers: set[str],
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
):
self.id = id
self.run_id = run_id
self.output_workers = output_workers
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.scheduler = scheduler
self.closed = False

self._disk_buffer = DiskShardsBuffer(
directory=directory,
memory_limiter=memory_limiter_disk,
)

self._comm_buffer = CommShardsBuffer(
send=self.send, memory_limiter=memory_limiter_comms
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)

self.diagnostics: dict[str, float] = defaultdict(float)
self.transferred = False
self.received: set[_T_partition_id] = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception: Exception | None = None
self._closed_event = asyncio.Event()

def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

def __hash__(self) -> int:
return self.run_id

@contextlib.contextmanager
def time(self, name: str) -> Iterator[None]:
start = time.time()
yield
stop = time.time()
self.diagnostics[name] += stop - start

async def barrier(self) -> None:
self.raise_if_closed()
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id)

async def send(
self, address: str, shards: list[tuple[_T_partition_id, bytes]]
) -> None:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)

async def offload(self, func: Callable[..., _T], *args: Any) -> _T:
self.raise_if_closed()
with self.time("cpu"):
return await asyncio.get_running_loop().run_in_executor(
self.executor,
func,
*args,
)

def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"diagnostics": self.diagnostics,
"start": self.start_time,
}

async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, bytes]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)

def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")

async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise

async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()

async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()

async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return

self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()

def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception

def _read_from_disk(self, id: NDIndex) -> bytes:
self.raise_if_closed()
data: bytes = self._disk_buffer.read("_".join(str(i) for i in id))
return data

async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
await self._receive(data)

async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
assigned_worker = self._get_assigned_worker(i)

if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()

@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""

@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""

@abc.abstractmethod
async def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
"""Add an input partition to the shuffle run"""

@abc.abstractmethod
async def get_output_partition(
self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None
Contributor:

pd.DataFrame seems oddly specific for a generic implementation.

Member Author (@hendrikmakait, Aug 14, 2023):

Very true, but I would like to make more involved changes such as adding generic typing in a separate PR (e.g., #8096); otherwise it will be impossible to spot any meaningful changes in this PR.

) -> _T_partition_type:
"""Get an output partition to the shuffle run"""


def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker

try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle") # type: ignore
if plugin is None:
raise RuntimeError(
f"The worker {worker.address} does not have a ShuffleExtension. "
"Is pandas installed on the worker?"
Contributor:

Is that now necessary even if only using shuffle extensions for array rechunking?

Member Author:

This exception is outdated, but the worker plugin is still involved in array rechunking. Note that we have conflicting names between the generic "P2P shuffle" approach, as in sending data all-to-all across the cluster, and the specific dataframe.shuffle operation. If you have naming suggestions for resolving this, please let me know. As with other comments, I would rather not address this here.

)
return plugin
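
For context (not part of this diff), a minimal usage sketch of the setup the error message above asks for, assuming dask and distributed are installed and a default LocalCluster is acceptable:

# Sketch only: create a distributed Client so that `shuffle="p2p"` tasks run
# on workers, where get_worker_plugin() can find the shuffle plugin.
from dask.datasets import timeseries
from distributed import Client

client = Client()  # starts a LocalCluster by default
df = timeseries()  # small synthetic dataframe with columns id, name, x, y
shuffled = df.shuffle("id", shuffle="p2p")  # route the shuffle through P2P
print(shuffled.head())
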


_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id


def id_from_key(key: str) -> ShuffleId:
assert key.startswith(_BARRIER_PREFIX)
return ShuffleId(key.replace(_BARRIER_PREFIX, ""))
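
A quick usage sketch (not part of the diff) showing that barrier_key and id_from_key round-trip, assuming the helpers above are imported from distributed.shuffle._core:

sid = ShuffleId("abc123")
key = barrier_key(sid)          # "shuffle-barrier-abc123"
assert id_from_key(key) == sid  # strips the prefix back off
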


class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"


@dataclass(eq=False)
class ShuffleState(abc.ABC):
_run_id_iterator: ClassVar[itertools.count] = itertools.count(1)

id: ShuffleId
run_id: int
output_workers: set[str]
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)

@abc.abstractmethod
def to_msg(self) -> dict[str, Any]:
"""Transform the shuffle state into a JSON-serializable message"""

def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

def __hash__(self) -> int:
return hash(self.run_id)
11 changes: 3 additions & 8 deletions distributed/shuffle/_merge.py
@@ -9,13 +9,8 @@
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer

from distributed.shuffle._shuffle import (
ShuffleId,
_get_worker_plugin,
barrier_key,
shuffle_barrier,
shuffle_transfer,
)
from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin
from distributed.shuffle._shuffle import shuffle_barrier, shuffle_transfer

if TYPE_CHECKING:
import pandas as pd
@@ -167,7 +162,7 @@ def merge_unpack(
):
from dask.dataframe.multi import merge_chunk

ext = _get_worker_plugin()
ext = get_worker_plugin()
# If the partition is empty, it doesn't contain the hash column name
left = ext.get_output_partition(
shuffle_id_left, barrier_left, output_partition, meta=meta_left