-
-
Notifications
You must be signed in to change notification settings - Fork 748
Restructure P2P code #8098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restructure P2P code #8098
Changes from all commits
d2892d2
11408be
a5f0817
121bdc8
e922db4
f361e71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,277 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import abc | ||
| import asyncio | ||
| import contextlib | ||
| import itertools | ||
| import time | ||
| from collections import defaultdict | ||
| from collections.abc import Callable, Iterator | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Any, ClassVar, Generic, NewType, TypeVar | ||
|
|
||
| from distributed.core import PooledRPCCall | ||
| from distributed.exceptions import Reschedule | ||
| from distributed.protocol import to_serialize | ||
| from distributed.shuffle._comms import CommShardsBuffer | ||
| from distributed.shuffle._disk import DiskShardsBuffer | ||
| from distributed.shuffle._exceptions import ShuffleClosedError | ||
| from distributed.shuffle._limiter import ResourceLimiter | ||
|
|
||
| if TYPE_CHECKING: | ||
| import pandas as pd | ||
| from typing_extensions import TypeAlias | ||
|
|
||
| # avoid circular dependencies | ||
| from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin | ||
|
|
||
| _T_partition_id = TypeVar("_T_partition_id") | ||
| _T_partition_type = TypeVar("_T_partition_type") | ||
| _T = TypeVar("_T") | ||
|
|
||
| NDIndex: TypeAlias = tuple[int, ...] | ||
|
|
||
| ShuffleId = NewType("ShuffleId", str) | ||
|
|
||
|
|
||
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
    """State and logic of a single P2P shuffle run executing on one worker.

    A run moves data in three phases:

    1. ``add_partition`` hands input partitions to the run; outgoing shards
       are batched in the comm buffer and transferred to peers via ``send``.
    2. ``receive`` accepts shards from peer workers; they are spilled to
       disk via the disk buffer.
    3. After ``barrier``, ``get_output_partition`` reassembles output
       partitions from the shards stored locally.

    Subclasses implement the abstract hooks (``_receive``,
    ``_get_assigned_worker``, ``add_partition``, ``get_output_partition``)
    for a concrete partition type.
    """

    def __init__(
        self,
        id: ShuffleId,
        run_id: int,
        output_workers: set[str],
        local_address: str,
        directory: str,
        executor: ThreadPoolExecutor,
        rpc: Callable[[str], PooledRPCCall],
        scheduler: PooledRPCCall,
        memory_limiter_disk: ResourceLimiter,
        memory_limiter_comms: ResourceLimiter,
    ):
        self.id = id
        self.run_id = run_id
        self.output_workers = output_workers
        self.local_address = local_address
        self.executor = executor
        self.rpc = rpc
        self.scheduler = scheduler
        self.closed = False

        # Stores received shards on disk under `directory`, subject to
        # `memory_limiter_disk`.
        self._disk_buffer = DiskShardsBuffer(
            directory=directory,
            memory_limiter=memory_limiter_disk,
        )

        # Buffers outgoing shards and hands them to `self.send`, subject to
        # `memory_limiter_comms`.
        self._comm_buffer = CommShardsBuffer(
            send=self.send, memory_limiter=memory_limiter_comms
        )
        # TODO: reduce number of connections to number of workers
        # MultiComm.max_connections = min(10, n_workers)

        # Per-activity wall-clock accounting, populated via `self.time(...)`.
        self.diagnostics: dict[str, float] = defaultdict(float)
        self.transferred = False
        self.received: set[_T_partition_id] = set()
        self.total_recvd = 0
        self.start_time = time.time()
        # Set by `fail` / `inputs_done`; re-raised by `raise_if_closed`.
        self._exception: Exception | None = None
        # Set once `close` has finished; lets concurrent closers wait.
        self._closed_event = asyncio.Event()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

    def __hash__(self) -> int:
        # The integer run_id serves directly as the hash value.
        return self.run_id

    @contextlib.contextmanager
    def time(self, name: str) -> Iterator[None]:
        """Add the wall-clock time of the wrapped block to ``diagnostics[name]``.

        NOTE(review): there is no try/finally, so if the timed block raises,
        no time is recorded for it.
        """
        start = time.time()
        yield
        stop = time.time()
        self.diagnostics[name] += stop - start

    async def barrier(self) -> None:
        """Notify the scheduler that this worker reached the shuffle barrier."""
        self.raise_if_closed()
        # TODO: Consider broadcast pinging once when the shuffle starts to warm
        # up the comm pool on scheduler side
        await self.scheduler.shuffle_barrier(id=self.id, run_id=self.run_id)

    async def send(
        self, address: str, shards: list[tuple[_T_partition_id, bytes]]
    ) -> None:
        """Send ``shards`` to the peer worker at ``address`` via RPC."""
        self.raise_if_closed()
        return await self.rpc(address).shuffle_receive(
            data=to_serialize(shards),
            shuffle_id=self.id,
            run_id=self.run_id,
        )

    async def offload(self, func: Callable[..., _T], *args: Any) -> _T:
        """Run blocking ``func(*args)`` on the executor, timed as ``"cpu"``."""
        self.raise_if_closed()
        with self.time("cpu"):
            return await asyncio.get_running_loop().run_in_executor(
                self.executor,
                func,
                *args,
            )

    def heartbeat(self) -> dict[str, Any]:
        """Collect diagnostics for the periodic heartbeat message."""
        comm_heartbeat = self._comm_buffer.heartbeat()
        comm_heartbeat["read"] = self.total_recvd
        return {
            "disk": self._disk_buffer.heartbeat(),
            "comm": comm_heartbeat,
            "diagnostics": self.diagnostics,
            "start": self.start_time,
        }

    async def _write_to_comm(
        self, data: dict[str, tuple[_T_partition_id, bytes]]
    ) -> None:
        """Enqueue outgoing shards; keys are worker addresses (see ``send``)."""
        self.raise_if_closed()
        await self._comm_buffer.write(data)

    async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None:
        """Spill received shards to disk, flattening tuple keys to strings.

        The key ``(1, 2)`` becomes ``"1_2"``; ``_read_from_disk`` applies the
        same flattening when reading back.
        """
        self.raise_if_closed()
        await self._disk_buffer.write(
            {"_".join(str(i) for i in k): v for k, v in data.items()}
        )

    def raise_if_closed(self) -> None:
        """If closed, raise the stored failure or a ``ShuffleClosedError``."""
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")

    async def inputs_done(self) -> None:
        """Mark all inputs as transferred and flush outgoing comm buffers.

        A failure surfaced by the comm buffer is remembered in
        ``self._exception`` so later calls fail fast via ``raise_if_closed``.
        """
        self.raise_if_closed()
        self.transferred = True
        await self._flush_comm()
        try:
            self._comm_buffer.raise_on_exception()
        except Exception as e:
            self._exception = e
            raise

    async def _flush_comm(self) -> None:
        # Push any buffered outgoing shards to their destinations.
        self.raise_if_closed()
        await self._comm_buffer.flush()

    async def flush_receive(self) -> None:
        """Flush shards buffered for disk so that reads see complete data."""
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        """Close both buffers; concurrent callers wait for the first closer."""
        if self.closed:  # pragma: no cover
            # Another task is (or was) closing; wait until it finishes.
            await self._closed_event.wait()
            return

        self.closed = True
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        self._closed_event.set()

    def fail(self, exception: Exception) -> None:
        # Record the failure only; `raise_if_closed` surfaces it once the run
        # is closed. Does not close the run itself.
        if not self.closed:
            self._exception = exception

    def _read_from_disk(self, id: NDIndex) -> bytes:
        """Read back the shard data spilled under tuple index ``id``."""
        self.raise_if_closed()
        data: bytes = self._disk_buffer.read("_".join(str(i) for i in id))
        return data

    async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
        """Public entry point; delegates to the subclass hook ``_receive``."""
        await self._receive(data)

    async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
        """Reschedule task ``key`` onto the worker assigned to partition ``i``.

        If this worker is not the assigned one, restrict the task on the
        scheduler and raise ``Reschedule`` so it reruns on the right worker.
        """
        assigned_worker = self._get_assigned_worker(i)

        if assigned_worker != self.local_address:
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()

    @abc.abstractmethod
    def _get_assigned_worker(self, i: _T_partition_id) -> str:
        """Get the address of the worker assigned to the output partition"""

    @abc.abstractmethod
    async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
        """Receive shards belonging to output partitions of this shuffle run"""

    @abc.abstractmethod
    async def add_partition(
        self, data: _T_partition_type, partition_id: _T_partition_id
    ) -> int:
        """Add an input partition to the shuffle run"""

    @abc.abstractmethod
    async def get_output_partition(
        self, partition_id: _T_partition_id, key: str, meta: pd.DataFrame | None = None
    ) -> _T_partition_type:
        """Get an output partition of the shuffle run"""
|
|
||
|
|
||
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the P2P shuffle plugin of the worker running the current task.

    Returns
    -------
    The ``ShuffleWorkerPlugin`` registered under the name ``"shuffle"``.

    Raises
    ------
    RuntimeError
        If this is not running inside a task on a distributed worker, or if
        the worker has no P2P shuffle plugin registered.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e
    plugin: ShuffleWorkerPlugin | None = worker.plugins.get("shuffle")  # type: ignore
    if plugin is None:
        # The plugin serves both dataframe shuffles and array rechunking, so
        # the message must not single out pandas (the previous wording,
        # "Is pandas installed on the worker?", was outdated).
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        )
    return plugin
|
|
||
|
|
||
# Prefix shared by the task keys of all P2P shuffle barrier tasks
# (see barrier_key / id_from_key).
_BARRIER_PREFIX = "shuffle-barrier-"
|
|
||
|
|
||
def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the task key of the barrier task for the given shuffle ID."""
    return f"{_BARRIER_PREFIX}{shuffle_id}"
|
|
||
|
|
||
def id_from_key(key: str) -> ShuffleId:
    """Extract the shuffle ID from a barrier task key.

    Inverse of :func:`barrier_key`. Slices off exactly the leading prefix;
    the previous ``str.replace`` would also have stripped any *interior*
    occurrences of the prefix from the ID itself.
    """
    assert key.startswith(_BARRIER_PREFIX)
    return ShuffleId(key[len(_BARRIER_PREFIX):])
|
|
||
|
|
||
class ShuffleType(Enum):
    """Discriminator for the two kinds of P2P data movement supported."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
|
|
||
|
|
||
@dataclass(eq=False)
class ShuffleState(abc.ABC):
    """Base class holding the identifying state of one shuffle run.

    ``eq=False`` keeps the default identity-based equality; hashing is by
    ``run_id`` so distinct runs of the same shuffle hash differently.
    """

    # Shared, monotonically increasing source of run IDs (starts at 1).
    _run_id_iterator: ClassVar[itertools.count] = itertools.count(1)

    id: ShuffleId
    run_id: int
    output_workers: set[str]
    participating_workers: set[str]
    # NOTE(review): set after initialization when the state is archived;
    # the archiving logic lives outside this file — confirm semantics there.
    _archived_by: str | None = field(default=None, init=False)

    @abc.abstractmethod
    def to_msg(self) -> dict[str, Any]:
        """Transform the shuffle state into a JSON-serializable message"""

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`pd.DataFrame` seems oddly specific for a generic implementation. Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very true, but I would like to make more involved changes such as adding generic typing in a separate PR (e.g., #8096). It will be impossible to spot any meaningful changes in this PR.