Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion distributed/shuffle/_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def check_minimal_arrow_version() -> None:


def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
import pandas as pd
import pyarrow as pa

file = BytesIO(data)
Expand All @@ -56,7 +57,19 @@ def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
shards.append(sr.read_all())
table = pa.concat_tables(shards)
df = table.to_pandas(self_destruct=True)
return df.astype(meta.dtypes)

def default_types_mapper(pyarrow_dtype: pa.DataType) -> object:
# Avoid converting strings from `string[pyarrow]` to `string[python]`
# if we have *some* `string[pyarrow]`
if (
pyarrow_dtype in {pa.large_string(), pa.string()}
and pd.StringDtype("pyarrow") in meta.dtypes.values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are actually two implementations of pyarrow-strings in pandas. The one you have here and also pd.ArrowDtype(pa.string()). What you have here is fine for now, especially since it's just a performance optimization. Over in dask/dask we're also using pd.StringDtype("pyarrow") as it, historically, has been more feature complete than pd.ArrowDtype(pa.string()). That said, I think the situation has changed in pandas=2, so we may switch to pd.ArrowDtype(pa.string()) at some point in the future. This is mostly just an FYI in case we need to circle back to here in the future.

):
return pd.StringDtype("pyarrow")
return None

df = table.to_pandas(self_destruct=True, types_mapper=default_types_mapper)
return df.astype(meta.dtypes, copy=False)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy=False to make this a no-op for columns with matching dtypes.



def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
Expand Down