diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py
index d0717e02e88..59ba765c175 100644
--- a/python/pyarrow/interchange/dataframe.py
+++ b/python/pyarrow/interchange/dataframe.py
@@ -44,11 +44,13 @@ class _PyArrowDataFrame:
     """

     def __init__(
-        self, df: pa.Table, nan_as_null: bool = False, allow_copy: bool = True
+        self, df: pa.Table | pa.RecordBatch,
+        nan_as_null: bool = False,
+        allow_copy: bool = True
     ) -> None:
         """
         Constructor - an instance of this (private) class is returned from
-        `pa.Table.__dataframe__`.
+        `pa.Table.__dataframe__` or `pa.RecordBatch.__dataframe__`.
         """
         self._df = df
         # ``nan_as_null`` is a keyword intended for the consumer to tell the
@@ -114,18 +116,21 @@ def num_chunks(self) -> int:
         """
         Return the number of chunks the DataFrame consists of.
         """
-        # pyarrow.Table can have columns with different number
-        # of chunks so we take the number of chunks that
-        # .to_batches() returns as it takes the min chunk size
-        # of all the columns (to_batches is a zero copy method)
-        batches = self._df.to_batches()
-        return len(batches)
+        if isinstance(self._df, pa.RecordBatch):
+            return 1
+        else:
+            # pyarrow.Table can have columns with different number
+            # of chunks so we take the number of chunks that
+            # .to_batches() returns as it takes the min chunk size
+            # of all the columns (to_batches is a zero copy method)
+            batches = self._df.to_batches()
+            return len(batches)

     def column_names(self) -> Iterable[str]:
         """
         Return an iterator yielding the column names.
         """
-        return self._df.column_names
+        return self._df.schema.names

     def get_column(self, i: int) -> _PyArrowColumn:
         """
@@ -182,21 +187,31 @@ def get_chunks(
         Note that the producer must ensure that all columns are
         chunked the same way.
         """
+        # Subdivide chunks
         if n_chunks and n_chunks > 1:
             chunk_size = self.num_rows() // n_chunks
             if self.num_rows() % n_chunks != 0:
                 chunk_size += 1
-            batches = self._df.to_batches(max_chunksize=chunk_size)
+            if isinstance(self._df, pa.Table):
+                batches = self._df.to_batches(max_chunksize=chunk_size)
+            else:
+                batches = []
+                for start in range(0, chunk_size * n_chunks, chunk_size):
+                    batches.append(self._df.slice(start, chunk_size))
             # In case when the size of the chunk is such that the resulting
             # list is one less chunk then n_chunks -> append an empty chunk
             if len(batches) == n_chunks - 1:
                 batches.append(pa.record_batch([[]], schema=self._df.schema))
+        # yields the chunks that the data is stored as
         else:
-            batches = self._df.to_batches()
-
-        iterator_tables = [_PyArrowDataFrame(
-            pa.Table.from_batches([batch]), self._nan_as_null, self._allow_copy
-        )
-            for batch in batches
-        ]
-        return iterator_tables
+            if isinstance(self._df, pa.Table):
+                batches = self._df.to_batches()
+            else:
+                batches = [self._df]
+
+        # Create an iterator of RecordBatches
+        iterator = [_PyArrowDataFrame(batch,
+                                      self._nan_as_null,
+                                      self._allow_copy)
+                    for batch in batches]
+        return iterator
diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py
index 204530a3354..801d0dd452a 100644
--- a/python/pyarrow/interchange/from_dataframe.py
+++ b/python/pyarrow/interchange/from_dataframe.py
@@ -60,8 +60,7 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
     """
-    Build a ``pa.Table`` from any DataFrame supporting the interchange
-    protocol.
+    Build a ``pa.Table`` from any DataFrame supporting the interchange protocol.

     Parameters
     ----------
@@ -78,6 +77,8 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table:
     """
     if isinstance(df, pa.Table):
         return df
+    elif isinstance(df, pa.RecordBatch):
+        return pa.Table.from_batches([df])

     if not hasattr(df, "__dataframe__"):
         raise ValueError("`df` does not support __dataframe__")
@@ -108,8 +109,7 @@ def _from_dataframe(df: DataFrameObject, allow_copy=True):
         batch = protocol_df_chunk_to_pyarrow(chunk, allow_copy)
         batches.append(batch)

-    table = pa.Table.from_batches(batches)
-    return table
+    return pa.Table.from_batches(batches)


 def protocol_df_chunk_to_pyarrow(
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 5b7989e5825..e400605e566 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -1517,6 +1517,39 @@ cdef class RecordBatch(_PandasConvertible):
         self.sp_batch = batch
         self.batch = batch.get()

+    # ----------------------------------------------------------------------
+    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True):
+        """
+        Return the dataframe interchange object implementing the interchange protocol.
+
+        Parameters
+        ----------
+        nan_as_null : bool, default False
+            Whether to tell the DataFrame to overwrite null values in the data
+            with ``NaN`` (or ``NaT``).
+        allow_copy : bool, default True
+            Whether to allow memory copying when exporting. If set to False
+            it would cause non-zero-copy exports to fail.
+
+        Returns
+        -------
+        DataFrame interchange object
+            The object which the consuming library can use to ingress the dataframe.
+
+        Notes
+        -----
+        Details on the interchange protocol:
+        https://data-apis.org/dataframe-protocol/latest/index.html
+        `nan_as_null` currently has no effect; once support for nullable extension
+        dtypes is added, this value should be propagated to columns.
+        """
+
+        from pyarrow.interchange.dataframe import _PyArrowDataFrame
+
+        return _PyArrowDataFrame(self, nan_as_null, allow_copy)
+
+    # ----------------------------------------------------------------------
+
     @staticmethod
     def from_pydict(mapping, schema=None, metadata=None):
         """
diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py
index 0680d9c4ec1..089f316e508 100644
--- a/python/pyarrow/tests/interchange/test_conversion.py
+++ b/python/pyarrow/tests/interchange/test_conversion.py
@@ -108,6 +108,7 @@ def test_categorical_roundtrip():
     if Version(pd.__version__) < Version("1.5.0"):
         pytest.skip("__dataframe__ added to pandas in 1.5.0")
+
     arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]
     table = pa.table(
         {"weekday": pa.array(arr).dictionary_encode()}
     )
@@ -447,7 +448,7 @@ def test_pyarrow_roundtrip_categorical(offset, length):
     assert col_result.size() == col_table.size()
     assert col_result.offset == col_table.offset

-    desc_cat_table = col_result.describe_categorical
+    desc_cat_table = col_table.describe_categorical
     desc_cat_result = col_result.describe_categorical

     assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"]
diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py b/python/pyarrow/tests/interchange/test_interchange_spec.py
index 42ec8053599..7b2b8eb7208 100644
--- a/python/pyarrow/tests/interchange/test_interchange_spec.py
+++ b/python/pyarrow/tests/interchange/test_interchange_spec.py
@@ -76,8 +76,10 @@ def test_dtypes(arr):
 )
 @pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns'])
 @pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30'])
+@pytest.mark.parametrize("use_batch", [False, True])
 def test_mixed_dtypes(uint, uint_bw, int, int_bw,
-                      float, float_bw, np_float, unit, tz):
+                      float, float_bw, np_float, unit, tz,
+                      use_batch):
     from datetime import datetime as dt
     arr = [1, 2, 3]
     dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)]
@@ -91,6 +93,8 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw,
             "f": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz))
         }
     )
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT,
     # 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME
@@ -126,12 +130,15 @@ def test_noncategorical():
         col.describe_categorical


-def test_categorical():
+@pytest.mark.parametrize("use_batch", [False, True])
+def test_categorical(use_batch):
     import pyarrow as pa
     arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None]
     table = pa.table(
         {"weekday": pa.array(arr).dictionary_encode()}
     )
+    if use_batch:
+        table = table.to_batches()[0]

     col = table.__dataframe__().get_column_by_name("weekday")
     categorical = col.describe_categorical
@@ -139,34 +146,46 @@
     assert isinstance(categorical["is_dictionary"], bool)


-def test_dataframe():
+@pytest.mark.parametrize("use_batch", [False, True])
+def test_dataframe(use_batch):
     n = pa.chunked_array([[2, 2, 4], [4, 5, 100]])
     a = pa.chunked_array([["Flamingo", "Parrot", "Cow"],
                           ["Horse", "Brittle stars", "Centipede"]])
     table = pa.table([n, a], names=['n_legs', 'animals'])
+    if use_batch:
+        table = table.combine_chunks().to_batches()[0]
     df = table.__dataframe__()
     assert df.num_columns() == 2
     assert df.num_rows() == 6
-    assert df.num_chunks() == 2
+    if use_batch:
+        assert df.num_chunks() == 1
+    else:
+        assert df.num_chunks() == 2
     assert list(df.column_names()) == ['n_legs',
                                        'animals']
     assert list(df.select_columns((1,)).column_names()) == list(
         df.select_columns_by_name(("animals",)).column_names()
     )


+@pytest.mark.parametrize("use_batch", [False, True])
 @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
-def test_df_get_chunks(size, n_chunks):
+def test_df_get_chunks(use_batch, size, n_chunks):
     table = pa.table({"x": list(range(size))})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     chunks = list(df.get_chunks(n_chunks))
     assert len(chunks) == n_chunks
     assert sum(chunk.num_rows() for chunk in chunks) == size


+@pytest.mark.parametrize("use_batch", [False, True])
 @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)])
-def test_column_get_chunks(size, n_chunks):
+def test_column_get_chunks(use_batch, size, n_chunks):
     table = pa.table({"x": list(range(size))})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     chunks = list(df.get_column(0).get_chunks(n_chunks))
     assert len(chunks) == n_chunks
@@ -187,7 +206,8 @@ def test_column_get_chunks(size, n_chunks):
         (pa.float64(), np.float64)
     ]
 )
-def test_get_columns(uint, int, float, np_float):
+@pytest.mark.parametrize("use_batch", [False, True])
+def test_get_columns(uint, int, float, np_float, use_batch):
     arr = [[1, 2, 3], [4, 5]]
     arr_float = np.array([1, 2, 3, 4, 5], dtype=np_float)
     table = pa.table(
@@ -197,6 +217,8 @@ def test_get_columns(uint, int, float, np_float):
             "c": pa.array(arr_float, type=float)
         }
     )
+    if use_batch:
+        table = table.combine_chunks().to_batches()[0]
     df = table.__dataframe__()
     for col in df.get_columns():
         assert col.size() == 5
@@ -212,9 +234,12 @@ def test_get_columns(uint, int, float, np_float):
 @pytest.mark.parametrize(
     "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
 )
-def test_buffer(int):
+@pytest.mark.parametrize("use_batch", [False, True])
+def test_buffer(int, use_batch):
     arr = [0, 1, -1]
     table = pa.table({"a": pa.array(arr, type=int)})
+    if use_batch:
+        table = table.to_batches()[0]
     df = table.__dataframe__()
     col = df.get_column(0)
     buf = col.get_buffers()
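Reviewer note (not part of the patch): a minimal usage sketch of the behaviour this change is intended to enable, assuming the diff above is applied. The column names and data below are illustrative only and are not taken from the changed files.

import pyarrow as pa
from pyarrow.interchange import from_dataframe

# A RecordBatch is a single contiguous chunk of columnar data.
batch = pa.record_batch(
    [pa.array([2, 2, 4, 4, 5, 100]),
     pa.array(["Flamingo", "Parrot", "Cow", "Horse", "Brittle stars",
               "Centipede"])],
    names=["n_legs", "animals"],
)

# With this patch, RecordBatch exposes the interchange protocol directly.
interchange_df = batch.__dataframe__()
assert interchange_df.num_chunks() == 1    # a batch is always a single chunk
assert interchange_df.num_rows() == 6
assert list(interchange_df.column_names()) == ["n_legs", "animals"]

# pyarrow's own consumer now accepts a RecordBatch as well; per the
# from_dataframe change above it is wrapped into a pa.Table via
# Table.from_batches.
table = from_dataframe(batch)
assert table.num_rows == 6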