diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py
index 8a68b08afd4..d73fc03663a 100644
--- a/distributed/shuffle/_merge.py
+++ b/distributed/shuffle/_merge.py
@@ -140,7 +140,6 @@ def merge_transfer(
     input_partition: int,
     npartitions: int,
     parts_out: set[int],
-    meta: pd.DataFrame,
 ):
     return shuffle_transfer(
         input=input,
@@ -149,7 +148,6 @@ def merge_transfer(
         npartitions=npartitions,
         column=_HASH_COLUMN_NAME,
         parts_out=parts_out,
-        meta=meta,
     )


@@ -162,6 +160,8 @@ def merge_unpack(
     how: MergeHow,
     left_on: IndexLabel,
     right_on: IndexLabel,
+    meta_left: pd.DataFrame,
+    meta_right: pd.DataFrame,
     result_meta: pd.DataFrame,
     suffixes: Suffixes,
 ):
@@ -170,10 +170,10 @@ def merge_unpack(
     ext = _get_worker_extension()
     # If the partition is empty, it doesn't contain the hash column name
     left = ext.get_output_partition(
-        shuffle_id_left, barrier_left, output_partition
+        shuffle_id_left, barrier_left, output_partition, meta=meta_left
     ).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
     right = ext.get_output_partition(
-        shuffle_id_right, barrier_right, output_partition
+        shuffle_id_right, barrier_right, output_partition, meta=meta_right
     ).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
     return merge_chunk(
         left,
@@ -355,7 +355,6 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 i,
                 self.npartitions,
                 self.parts_out,
-                self.meta_input_left,
             )
         for i in range(self.n_partitions_right):
             transfer_keys_right.append((name_right, i))
@@ -366,7 +365,6 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 i,
                 self.npartitions,
                 self.parts_out,
-                self.meta_input_right,
             )

         _barrier_key_left = barrier_key(ShuffleId(token_left))
@@ -386,6 +384,8 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
                 self.how,
                 self.left_on,
                 self.right_on,
+                self.meta_input_left,
+                self.meta_input_right,
                 self.meta_output,
                 self.suffixes,
             )
diff --git a/distributed/shuffle/_scheduler_extension.py b/distributed/shuffle/_scheduler_extension.py
index ba511b63de6..67f51cbeb6d 100644
--- a/distributed/shuffle/_scheduler_extension.py
+++ b/distributed/shuffle/_scheduler_extension.py
@@ -12,7 +12,6 @@
 from typing import TYPE_CHECKING, Any, ClassVar

 from distributed.diagnostics.plugin import SchedulerPlugin
-from distributed.protocol import to_serialize
 from distributed.shuffle._rechunk import ChunkedAxes, NIndex
 from distributed.shuffle._shuffle import (
     ShuffleId,
@@ -22,8 +21,6 @@
 )

 if TYPE_CHECKING:
-    import pandas as pd
-
     from distributed.scheduler import (
         Recs,
         Scheduler,
@@ -53,7 +50,6 @@ def to_msg(self) -> dict[str, Any]:
 class DataFrameShuffleState(ShuffleState):
     type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME
     worker_for: dict[int, str]
-    meta: pd.DataFrame
     column: str

     def to_msg(self) -> dict[str, Any]:
@@ -63,7 +59,6 @@
             "run_id": self.run_id,
             "worker_for": self.worker_for,
             "column": self.column,
-            "meta": to_serialize(self.meta),
             "output_workers": self.output_workers,
         }

@@ -189,11 +184,9 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
     def _create_dataframe_shuffle_state(
         self, id: ShuffleId, spec: dict[str, Any]
     ) -> DataFrameShuffleState:
-        meta = spec["meta"]
        column = spec["column"]
         npartitions = spec["npartitions"]
         parts_out = spec["parts_out"]
-        assert meta is not None
         assert column is not None
         assert npartitions is not None
         assert parts_out is not None
@@ -207,7 +200,6 @@ def _create_dataframe_shuffle_state(
             id=id,
             run_id=next(ShuffleState._run_id_iterator),
             worker_for=mapping,
-            meta=meta,
             column=column,
             output_workers=output_workers,
             participating_workers=output_workers.copy(),
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index ad3dc08f485..d49706e5a53 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -58,7 +58,6 @@ def shuffle_transfer(
     npartitions: int,
     column: str,
     parts_out: set[int],
-    meta: pd.DataFrame,
 ) -> int:
     try:
         return _get_worker_extension().add_partition(
@@ -69,18 +68,17 @@ def shuffle_transfer(
             npartitions=npartitions,
             column=column,
             parts_out=parts_out,
-            meta=meta,
         )
     except Exception as e:
         raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e


 def shuffle_unpack(
-    id: ShuffleId, output_partition: int, barrier_run_id: int
+    id: ShuffleId, output_partition: int, barrier_run_id: int, meta: pd.DataFrame
 ) -> pd.DataFrame:
     try:
         return _get_worker_extension().get_output_partition(
-            id, barrier_run_id, output_partition
+            id, barrier_run_id, output_partition, meta=meta
         )
     except Reschedule as e:
         raise e
@@ -251,14 +249,19 @@ def _construct_graph(self) -> _T_LowLevelGraph:
                 self.npartitions,
                 self.column,
                 self.parts_out,
-                self.meta_input,
             )

         dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)

         name = self.name
         for part_out in self.parts_out:
-            dsk[(name, part_out)] = (shuffle_unpack, token, part_out, _barrier_key)
+            dsk[(name, part_out)] = (
+                shuffle_unpack,
+                token,
+                part_out,
+                _barrier_key,
+                self.meta_input,
+            )
         return dsk
diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 105b40e2f1f..28470647d22 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -234,7 +234,7 @@ async def add_partition(

     @abc.abstractmethod
     async def get_output_partition(
-        self, i: T_partition_id, key: str
+        self, i: T_partition_id, key: str, meta: pd.DataFrame | None = None
     ) -> T_partition_type:
         """Get an output partition to the shuffle run"""

@@ -380,8 +380,11 @@ def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]:
         await self._write_to_comm(out)
         return self.run_id

-    async def get_output_partition(self, i: NIndex, key: str) -> np.ndarray:
+    async def get_output_partition(
+        self, i: NIndex, key: str, meta: pd.DataFrame | None = None
+    ) -> np.ndarray:
         self.raise_if_closed()
+        assert meta is None
         assert self.transferred, "`get_output_partition` called before barrier task"

         await self._ensure_output_worker(i, key)
@@ -420,8 +423,6 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]):
         A set of all participating worker (addresses).
     column:
         The data column we split the input partition by.
-    meta:
-        Empty metadata of the input.
     id:
         A unique `ShuffleID` this belongs to.
     run_id:
@@ -450,7 +451,6 @@ def __init__(
         worker_for: dict[int, str],
         output_workers: set,
         column: str,
-        meta: pd.DataFrame,
         id: ShuffleId,
         run_id: int,
         local_address: str,
@@ -476,7 +476,6 @@ def __init__(
             memory_limiter_disk=memory_limiter_disk,
         )
         self.column = column
-        self.meta = meta
         partitions_of = defaultdict(list)
         for part, addr in worker_for.items():
             partitions_of[addr].append(part)
@@ -531,8 +530,11 @@ def _() -> dict[str, list[tuple[int, bytes]]]:
         await self._write_to_comm(out)
         return self.run_id

-    async def get_output_partition(self, i: int, key: str) -> pd.DataFrame:
+    async def get_output_partition(
+        self, i: int, key: str, meta: pd.DataFrame | None = None
+    ) -> pd.DataFrame:
         self.raise_if_closed()
+        assert meta is not None
         assert self.transferred, "`get_output_partition` called before barrier task"

         await self._ensure_output_worker(i, key)
@@ -542,11 +544,11 @@ async def get_output_partition(self, i: int, key: str) -> pd.DataFrame:
             data = self._read_from_disk((i,))

             def _() -> pd.DataFrame:
-                return convert_partition(data, self.meta)
+                return convert_partition(data, meta)  # type: ignore

             out = await self.offload(_)
         except KeyError:
-            out = self.meta.copy()
+            out = meta.copy()
         return out

     def _get_assigned_worker(self, i: int) -> str:
@@ -771,7 +773,6 @@ async def _refresh_shuffle(
             id=shuffle_id,
             type=type,
             spec={
-                "meta": to_serialize(kwargs["meta"]),
                 "npartitions": kwargs["npartitions"],
                 "column": kwargs["column"],
                 "parts_out": kwargs["parts_out"],
@@ -817,7 +818,6 @@ async def _(
                 column=result["column"],
                 worker_for=result["worker_for"],
                 output_workers=result["output_workers"],
-                meta=result["meta"],
                 id=shuffle_id,
                 run_id=result["run_id"],
                 directory=os.path.join(
@@ -904,7 +904,11 @@ def get_or_create_shuffle(
         )

     def get_output_partition(
-        self, shuffle_id: ShuffleId, run_id: int, output_partition: int | NIndex
+        self,
+        shuffle_id: ShuffleId,
+        run_id: int,
+        output_partition: int | NIndex,
+        meta: pd.DataFrame | None = None,
     ) -> Any:
         """
         Task: Retrieve a shuffled output partition from the ShuffleExtension.
@@ -914,7 +918,11 @@ def get_output_partition(
         shuffle = self.get_shuffle_run(shuffle_id, run_id)
         key = thread_state.key
         return sync(
-            self.worker.loop, shuffle.get_output_partition, output_partition, key
+            self.worker.loop,
+            shuffle.get_output_partition,
+            output_partition,
+            key,
+            meta=meta,
         )
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 28f94aeabc2..3192ef149a5 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -1190,7 +1190,6 @@ def new_shuffle(
         self,
         name,
         worker_for_mapping,
-        meta,
         directory,
         loop,
         Shuffle=DataFrameShuffleRun,
@@ -1200,7 +1199,6 @@ def new_shuffle(
             worker_for=worker_for_mapping,
             # FIXME: Is output_workers redundant with worker_for?
             output_workers=set(worker_for_mapping.values()),
-            meta=meta,
             directory=directory / name,
             id=ShuffleId(name),
             run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator),
@@ -1257,7 +1255,6 @@ async def test_basic_lowlevel_shuffle(
         local_shuffle_pool.new_shuffle(
             name=workers[ix],
             worker_for_mapping=worker_for_mapping,
-            meta=meta,
             directory=tmp_path,
             loop=loop_in_thread,
         )
@@ -1290,7 +1287,7 @@ async def test_basic_lowlevel_shuffle(
     all_parts = []
     for part, worker in worker_for_mapping.items():
         s = local_shuffle_pool.shuffles[worker]
-        all_parts.append(s.get_output_partition(part, f"key-{part}"))
+        all_parts.append(s.get_output_partition(part, f"key-{part}", meta=meta))

     all_parts = await asyncio.gather(*all_parts)

@@ -1323,7 +1320,6 @@ async def test_error_offload(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorOffload(DataFrameShuffleRun):
         async def offload(self, func, *args):
@@ -1333,7 +1329,6 @@ async def offload(self, func, *args):
     sA = local_shuffle_pool.new_shuffle(
         name="A",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
         Shuffle=ErrorOffload,
@@ -1341,7 +1336,6 @@ async def offload(self, func, *args):
     sB = local_shuffle_pool.new_shuffle(
         name="B",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
     )
@@ -1377,7 +1371,6 @@ async def test_error_send(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorSend(DataFrameShuffleRun):
         async def send(self, *args: Any, **kwargs: Any) -> None:
@@ -1387,7 +1380,6 @@ async def send(self, *args: Any, **kwargs: Any) -> None:
     sA = local_shuffle_pool.new_shuffle(
         name="A",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
         Shuffle=ErrorSend,
@@ -1395,7 +1387,6 @@ async def send(self, *args: Any, **kwargs: Any) -> None:
     sB = local_shuffle_pool.new_shuffle(
         name="B",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
     )
@@ -1430,7 +1421,6 @@ async def test_error_receive(tmp_path, loop_in_thread):
             npartitions, part, workers
         )
         partitions_for_worker[w].append(part)
-    meta = dfs[0].head(0)

     class ErrorReceive(DataFrameShuffleRun):
         async def receive(self, data: list[tuple[int, bytes]]) -> None:
@@ -1440,7 +1430,6 @@ async def receive(self, data: list[tuple[int, bytes]]) -> None:
     sA = local_shuffle_pool.new_shuffle(
         name="A",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
         Shuffle=ErrorReceive,
@@ -1448,7 +1437,6 @@ async def receive(self, data: list[tuple[int, bytes]]) -> None:
     sB = local_shuffle_pool.new_shuffle(
         name="B",
         worker_for_mapping=worker_for_mapping,
-        meta=meta,
         directory=tmp_path,
         loop=loop_in_thread,
     )
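
Reviewer note (not part of the diff): the change moves `meta` out of the scheduler round-trip and into the unpack tasks of the graph, so empty output partitions are reconstructed worker-side from the task-supplied meta. Below is a minimal, self-contained sketch of that fallback path; the `partition_shards` store, the `unpack` helper, and the example schema are hypothetical, for illustration only.

import pandas as pd

# Empty frame carrying only the output schema; stands in for the
# `meta_input` that the graph now ships with every unpack task.
meta = pd.DataFrame({"id": pd.Series(dtype="int64"), "x": pd.Series(dtype="float64")})

# Hypothetical shard store; a missing key models an output partition
# that received no rows during the transfer phase.
partition_shards: dict[int, pd.DataFrame] = {
    0: pd.DataFrame({"id": [1, 2], "x": [0.5, 1.5]})
}

def unpack(partition: int, meta: pd.DataFrame) -> pd.DataFrame:
    # Mirrors the KeyError fallback in DataFrameShuffleRun.get_output_partition:
    # because meta arrives with the task, an empty partition is simply
    # meta.copy(), and the scheduler never has to hold a pandas object.
    try:
        return partition_shards[partition]
    except KeyError:
        return meta.copy()

assert unpack(0, meta).shape == (2, 2)  # populated partition
assert unpack(7, meta).empty  # empty partition...
assert list(unpack(7, meta).columns) == ["id", "x"]  # ...still has the schema

The trade-off is a small per-task payload in exchange for dropping `to_serialize(meta)` from scheduler messages, which is what lets the diff remove the `pandas` import and the `meta` field from `_scheduler_extension.py` entirely.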