12 changes: 6 additions & 6 deletions distributed/shuffle/_merge.py
@@ -140,7 +140,6 @@ def merge_transfer(
     input_partition: int,
     npartitions: int,
     parts_out: set[int],
-    meta: pd.DataFrame,
 ):
     return shuffle_transfer(
         input=input,
@@ -149,7 +148,6 @@ def merge_transfer(
         npartitions=npartitions,
         column=_HASH_COLUMN_NAME,
         parts_out=parts_out,
-        meta=meta,
     )


@@ -162,6 +160,8 @@ def merge_unpack(
     how: MergeHow,
     left_on: IndexLabel,
     right_on: IndexLabel,
+    meta_left: pd.DataFrame,
+    meta_right: pd.DataFrame,
     result_meta: pd.DataFrame,
     suffixes: Suffixes,
 ):
@@ -170,10 +170,10 @@ def merge_unpack(
     ext = _get_worker_extension()
     # If the partition is empty, it doesn't contain the hash column name
     left = ext.get_output_partition(
-        shuffle_id_left, barrier_left, output_partition
+        shuffle_id_left, barrier_left, output_partition, meta=meta_left
     ).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
     right = ext.get_output_partition(
-        shuffle_id_right, barrier_right, output_partition
+        shuffle_id_right, barrier_right, output_partition, meta=meta_right
     ).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
     return merge_chunk(
         left,
@@ -355,7 +355,6 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 i,
                 self.npartitions,
                 self.parts_out,
-                self.meta_input_left,
             )
         for i in range(self.n_partitions_right):
             transfer_keys_right.append((name_right, i))
@@ -366,7 +365,6 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 i,
                 self.npartitions,
                 self.parts_out,
-                self.meta_input_right,
             )

         _barrier_key_left = barrier_key(ShuffleId(token_left))
@@ -386,6 +384,8 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 self.how,
                 self.left_on,
                 self.right_on,
+                self.meta_input_left,
+                self.meta_input_right,
                 self.meta_output,
                 self.suffixes,
             )
8 changes: 0 additions & 8 deletions distributed/shuffle/_scheduler_extension.py
@@ -12,7 +12,6 @@
 from typing import TYPE_CHECKING, Any, ClassVar

 from distributed.diagnostics.plugin import SchedulerPlugin
-from distributed.protocol import to_serialize
 from distributed.shuffle._rechunk import ChunkedAxes, NIndex
 from distributed.shuffle._shuffle import (
     ShuffleId,
@@ -22,8 +21,6 @@
 )

 if TYPE_CHECKING:
-    import pandas as pd
-
     from distributed.scheduler import (
         Recs,
         Scheduler,
@@ -53,7 +50,6 @@ def to_msg(self) -> dict[str, Any]:
 class DataFrameShuffleState(ShuffleState):
     type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME
     worker_for: dict[int, str]
-    meta: pd.DataFrame
     column: str

     def to_msg(self) -> dict[str, Any]:
@@ -63,7 +59,6 @@ def to_msg(self) -> dict[str, Any]:
             "run_id": self.run_id,
             "worker_for": self.worker_for,
             "column": self.column,
-            "meta": to_serialize(self.meta),
             "output_workers": self.output_workers,
         }

@@ -189,11 +184,9 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
     def _create_dataframe_shuffle_state(
         self, id: ShuffleId, spec: dict[str, Any]
     ) -> DataFrameShuffleState:
-        meta = spec["meta"]
         column = spec["column"]
         npartitions = spec["npartitions"]
         parts_out = spec["parts_out"]
-        assert meta is not None
         assert column is not None
         assert npartitions is not None
         assert parts_out is not None
@@ -207,7 +200,6 @@ def _create_dataframe_shuffle_state(
             id=id,
             run_id=next(ShuffleState._run_id_iterator),
             worker_for=mapping,
-            meta=meta,
             column=column,
             output_workers=output_workers,
             participating_workers=output_workers.copy(),
15 changes: 9 additions & 6 deletions distributed/shuffle/_shuffle.py
@@ -58,7 +58,6 @@ def shuffle_transfer(
     npartitions: int,
     column: str,
     parts_out: set[int],
-    meta: pd.DataFrame,
 ) -> int:
     try:
         return _get_worker_extension().add_partition(
@@ -69,18 +68,17 @@ def shuffle_transfer(
             npartitions=npartitions,
             column=column,
             parts_out=parts_out,
-            meta=meta,
         )
     except Exception as e:
         raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e


 def shuffle_unpack(
-    id: ShuffleId, output_partition: int, barrier_run_id: int
+    id: ShuffleId, output_partition: int, barrier_run_id: int, meta: pd.DataFrame
 ) -> pd.DataFrame:
     try:
         return _get_worker_extension().get_output_partition(
-            id, barrier_run_id, output_partition
+            id, barrier_run_id, output_partition, meta=meta
         )
     except Reschedule as e:
         raise e
@@ -251,14 +249,19 @@ def _construct_graph(self) -> _T_LowLevelGraph:
                 self.npartitions,
                 self.column,
                 self.parts_out,
-                self.meta_input,
             )

         dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)

         name = self.name
         for part_out in self.parts_out:
-            dsk[(name, part_out)] = (shuffle_unpack, token, part_out, _barrier_key)
+            dsk[(name, part_out)] = (
+                shuffle_unpack,
+                token,
+                part_out,
+                _barrier_key,
+                self.meta_input,
+            )
         return dsk


34 changes: 21 additions & 13 deletions distributed/shuffle/_worker_extension.py
@@ -234,7 +234,7 @@ async def add_partition(

     @abc.abstractmethod
     async def get_output_partition(
-        self, i: T_partition_id, key: str
+        self, i: T_partition_id, key: str, meta: pd.DataFrame | None = None
     ) -> T_partition_type:
         """Get an output partition to the shuffle run"""

@@ -380,8 +380,11 @@ def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]:
         await self._write_to_comm(out)
         return self.run_id

-    async def get_output_partition(self, i: NIndex, key: str) -> np.ndarray:
+    async def get_output_partition(
+        self, i: NIndex, key: str, meta: pd.DataFrame | None = None
+    ) -> np.ndarray:
         self.raise_if_closed()
+        assert meta is None
         assert self.transferred, "`get_output_partition` called before barrier task"

         await self._ensure_output_worker(i, key)
@@ -420,8 +423,6 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]):
         A set of all participating worker (addresses).
     column:
         The data column we split the input partition by.
-    meta:
-        Empty metadata of the input.
     id:
         A unique `ShuffleID` this belongs to.
     run_id:
@@ -450,7 +451,6 @@ def __init__(
         worker_for: dict[int, str],
         output_workers: set,
         column: str,
-        meta: pd.DataFrame,
         id: ShuffleId,
         run_id: int,
         local_address: str,
@@ -476,7 +476,6 @@ def __init__(
             memory_limiter_disk=memory_limiter_disk,
         )
         self.column = column
-        self.meta = meta
         partitions_of = defaultdict(list)
         for part, addr in worker_for.items():
             partitions_of[addr].append(part)
@@ -531,8 +530,11 @@ def _() -> dict[str, list[tuple[int, bytes]]]:
         await self._write_to_comm(out)
         return self.run_id

-    async def get_output_partition(self, i: int, key: str) -> pd.DataFrame:
+    async def get_output_partition(
+        self, i: int, key: str, meta: pd.DataFrame | None = None
+    ) -> pd.DataFrame:
         self.raise_if_closed()
+        assert meta is not None
         assert self.transferred, "`get_output_partition` called before barrier task"

         await self._ensure_output_worker(i, key)
@@ -542,11 +544,11 @@ async def get_output_partition(self, i: int, key: str) -> pd.DataFrame:
             data = self._read_from_disk((i,))

             def _() -> pd.DataFrame:
-                return convert_partition(data, self.meta)
+                return convert_partition(data, meta)  # type: ignore

             out = await self.offload(_)
         except KeyError:
-            out = self.meta.copy()
+            out = meta.copy()
         return out

     def _get_assigned_worker(self, i: int) -> str:
@@ -771,7 +773,6 @@ async def _refresh_shuffle(
                 id=shuffle_id,
                 type=type,
                 spec={
-                    "meta": to_serialize(kwargs["meta"]),
                     "npartitions": kwargs["npartitions"],
                     "column": kwargs["column"],
                     "parts_out": kwargs["parts_out"],
@@ -817,7 +818,6 @@ async def _(
                     column=result["column"],
                     worker_for=result["worker_for"],
                     output_workers=result["output_workers"],
-                    meta=result["meta"],
                     id=shuffle_id,
                     run_id=result["run_id"],
                     directory=os.path.join(
@@ -904,7 +904,11 @@ def get_or_create_shuffle(
         )

     def get_output_partition(
-        self, shuffle_id: ShuffleId, run_id: int, output_partition: int | NIndex
+        self,
+        shuffle_id: ShuffleId,
+        run_id: int,
+        output_partition: int | NIndex,
+        meta: pd.DataFrame | None = None,
     ) -> Any:
         """
         Task: Retrieve a shuffled output partition from the ShuffleExtension.
@@ -914,7 +918,11 @@ def get_output_partition(
         shuffle = self.get_shuffle_run(shuffle_id, run_id)
         key = thread_state.key
         return sync(
-            self.worker.loop, shuffle.get_output_partition, output_partition, key
+            self.worker.loop,
+            shuffle.get_output_partition,
+            output_partition,
+            key,
+            meta=meta,
         )


14 changes: 1 addition & 13 deletions distributed/shuffle/tests/test_shuffle.py
@@ -1190,7 +1190,6 @@ def new_shuffle(
         self,
         name,
         worker_for_mapping,
-        meta,
         directory,
         loop,
         Shuffle=DataFrameShuffleRun,
@@ -1200,7 +1199,6 @@ def new_shuffle(
             worker_for=worker_for_mapping,
             # FIXME: Is output_workers redundant with worker_for?
             output_workers=set(worker_for_mapping.values()),
-            meta=meta,
             directory=directory / name,
             id=ShuffleId(name),
             run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator),
@@ -1257,7 +1255,6 @@ async def test_basic_lowlevel_shuffle(
                 local_shuffle_pool.new_shuffle(
                     name=workers[ix],
                     worker_for_mapping=worker_for_mapping,
-                    meta=meta,
                     directory=tmp_path,
                     loop=loop_in_thread,
                 )
@@ -1290,7 +1287,7 @@ async def test_basic_lowlevel_shuffle(
         all_parts = []
         for part, worker in worker_for_mapping.items():
             s = local_shuffle_pool.shuffles[worker]
-            all_parts.append(s.get_output_partition(part, f"key-{part}"))
+            all_parts.append(s.get_output_partition(part, f"key-{part}", meta=meta))

         all_parts = await asyncio.gather(*all_parts)

@@ -1323,7 +1320,6 @@ async def test_error_offload(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorOffload(DataFrameShuffleRun):
         async def offload(self, func, *args):
@@ -1333,15 +1329,13 @@ async def offload(self, func, *args):
         sA = local_shuffle_pool.new_shuffle(
             name="A",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
             Shuffle=ErrorOffload,
         )
         sB = local_shuffle_pool.new_shuffle(
             name="B",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
         )
@@ -1377,7 +1371,6 @@ async def test_error_send(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorSend(DataFrameShuffleRun):
         async def send(self, *args: Any, **kwargs: Any) -> None:
@@ -1387,15 +1380,13 @@ async def send(self, *args: Any, **kwargs: Any) -> None:
         sA = local_shuffle_pool.new_shuffle(
             name="A",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
             Shuffle=ErrorSend,
         )
         sB = local_shuffle_pool.new_shuffle(
             name="B",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
         )
@@ -1430,7 +1421,6 @@ async def test_error_receive(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorReceive(DataFrameShuffleRun):
         async def receive(self, data: list[tuple[int, bytes]]) -> None:
@@ -1440,15 +1430,13 @@ async def receive(self, data: list[tuple[int, bytes]]) -> None:
         sA = local_shuffle_pool.new_shuffle(
             name="A",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
             Shuffle=ErrorReceive,
         )
         sB = local_shuffle_pool.new_shuffle(
             name="B",
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
         )