diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 6a0bc81e8d3..f00bd05caed 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -46,6 +46,7 @@ def check_minimal_arrow_version() -> None: def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: + import pandas as pd import pyarrow as pa file = BytesIO(data) @@ -56,7 +57,19 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame: shards.append(sr.read_all()) table = pa.concat_tables(shards) df = table.to_pandas(self_destruct=True) - return df.astype(meta.dtypes) + + def default_types_mapper(pyarrow_dtype: pa.DataType) -> object: + # Avoid converting strings from `string[pyarrow]` to `string[python]` + # if we have *some* `string[pyarrow]` + if ( + pyarrow_dtype in {pa.large_string(), pa.string()} + and pd.StringDtype("pyarrow") in meta.dtypes.values + ): + return pd.StringDtype("pyarrow") + return None + + df = table.to_pandas(self_destruct=True, types_mapper=default_types_mapper) + return df.astype(meta.dtypes, copy=False) def list_of_buffers_to_table(data: list[bytes]) -> pa.Table: