From fee7fa23df280f2d9309be6d425712570c0d5e90 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 30 Jan 2023 10:43:59 +0100 Subject: [PATCH 01/11] Update RecordBatch class and from_dataframe method --- python/pyarrow/interchange/from_dataframe.py | 16 ++++++----- python/pyarrow/table.pxi | 30 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py index 204530a3354..1a58f9b5e46 100644 --- a/python/pyarrow/interchange/from_dataframe.py +++ b/python/pyarrow/interchange/from_dataframe.py @@ -58,10 +58,10 @@ } -def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table: +def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table | pa.RecordBatch: """ - Build a ``pa.Table`` from any DataFrame supporting the interchange - protocol. + Build a ``pa.Table`` or a ``pa.RecordBatch`` from any DataFrame supporting + the interchange protocol. Parameters ---------- @@ -74,9 +74,9 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table: Returns ------- - pa.Table + pa.Table or pa.RecordBatch """ - if isinstance(df, pa.Table): + if isinstance(df, (pa.Table, pa.RecordBatch)): return df if not hasattr(df, "__dataframe__"): @@ -108,8 +108,10 @@ def _from_dataframe(df: DataFrameObject, allow_copy=True): batch = protocol_df_chunk_to_pyarrow(chunk, allow_copy) batches.append(batch) - table = pa.Table.from_batches(batches) - return table + if len(batches) == 1: + return batches[0] + else: + return pa.Table.from_batches(batches) def protocol_df_chunk_to_pyarrow( diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 5b7989e5825..506fdac3d74 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -1517,6 +1517,36 @@ cdef class RecordBatch(_PandasConvertible): self.sp_batch = batch self.batch = batch.get() + # ---------------------------------------------------------------------- + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Return the dataframe interchange object implementing the interchange protocol. + Parameters + ---------- + nan_as_null : bool, default False + Whether to tell the DataFrame to overwrite null values in the data + with ``NaN`` (or ``NaT``). + allow_copy : bool, default True + Whether to allow memory copying when exporting. If set to False + it would cause non-zero-copy exports to fail. + Returns + ------- + DataFrame interchange object + The object which consuming library can use to ingress the dataframe. + Notes + ----- + Details on the interchange protocol: + https://data-apis.org/dataframe-protocol/latest/index.html + `nan_as_null` currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns.
+ """ + + from pyarrow.interchange.dataframe import _PyArrowDataFrame + + return _PyArrowDataFrame(self, nan_as_null, allow_copy) + + # ---------------------------------------------------------------------- + @staticmethod def from_pydict(mapping, schema=None, metadata=None): """ From f2d82af571507377f1c8c8d8ab4ce646ad3eca73 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Tue, 31 Jan 2023 10:11:46 +0100 Subject: [PATCH 02/11] Update dataframe interchange classes to include RecordBatches --- python/pyarrow/interchange/dataframe.py | 54 ++++++++++++------- .../tests/interchange/test_conversion.py | 12 ++++- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py index d0717e02e88..ee389bb1c19 100644 --- a/python/pyarrow/interchange/dataframe.py +++ b/python/pyarrow/interchange/dataframe.py @@ -44,11 +44,13 @@ class _PyArrowDataFrame: """ def __init__( - self, df: pa.Table, nan_as_null: bool = False, allow_copy: bool = True + self, df: pa.Table | pa.RecordBatch, + nan_as_null: bool = False, + allow_copy: bool = True ) -> None: """ Constructor - an instance of this (private) class is returned from - `pa.Table.__dataframe__`. + `pa.Table.__dataframe__` or `pa.RecordBatch.__dataframe__`. """ self._df = df # ``nan_as_null`` is a keyword intended for the consumer to tell the @@ -114,18 +116,24 @@ def num_chunks(self) -> int: """ Return the number of chunks the DataFrame consists of. """ - # pyarrow.Table can have columns with different number - # of chunks so we take the number of chunks that - # .to_batches() returns as it takes the min chunk size - # of all the columns (to_batches is a zero copy method) - batches = self._df.to_batches() - return len(batches) + if isinstance(self._df, pa.RecordBatch): + return 1 + else: + # pyarrow.Table can have columns with different number + # of chunks so we take the number of chunks that + # .to_batches() returns as it takes the min chunk size + # of all the columns (to_batches is a zero copy method) + batches = self._df.to_batches() + return len(batches) def column_names(self) -> Iterable[str]: """ Return an iterator yielding the column names. """ - return self._df.column_names + if isinstance(self._df, pa.RecordBatch): + return self._df.schema.names + else: + return self._df.column_names def get_column(self, i: int) -> _PyArrowColumn: """ @@ -182,21 +190,31 @@ def get_chunks( Note that the producer must ensure that all columns are chunked the same way. 
""" + # Subdivide chunks if n_chunks and n_chunks > 1: chunk_size = self.num_rows() // n_chunks if self.num_rows() % n_chunks != 0: chunk_size += 1 - batches = self._df.to_batches(max_chunksize=chunk_size) + if isinstance(self._df, pa.Table): + batches = self._df.to_batches(max_chunksize=chunk_size) + else: + batches = [] + for start in range(0, chunk_size * n_chunks, chunk_size): + batches.append(self.slice(start, chunk_size)) # In case when the size of the chunk is such that the resulting # list is one less chunk then n_chunks -> append an empty chunk if len(batches) == n_chunks - 1: batches.append(pa.record_batch([[]], schema=self._df.schema)) + # yields the chunks that the data is stored as else: - batches = self._df.to_batches() - - iterator_tables = [_PyArrowDataFrame( - pa.Table.from_batches([batch]), self._nan_as_null, self._allow_copy - ) - for batch in batches - ] - return iterator_tables + if isinstance(self._df, pa.Table): + batches = self._df.to_batches() + else: + batches = self._df + + # Create an iterator of RecordBatches + iterator = [_PyArrowDataFrame(batch, + self._nan_as_null, + self._allow_copy) + for batch in batches] + return iterator \ No newline at end of file diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py index 0680d9c4ec1..9cd338a8f49 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -85,6 +85,7 @@ def test_offset_of_sliced_array(): assert col.offset == 2 result = _from_dataframe(table_sliced.__dataframe__()) + result = pa.Table.from_batches([result]) assert table_sliced.equals(result) assert not table.equals(result) @@ -176,6 +177,7 @@ def test_pandas_roundtrip(uint, int, float, np_float): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) + result = pa.Table.from_batches([result]) assert table.equals(result) table_protocol = table.__dataframe__() @@ -201,6 +203,7 @@ def test_roundtrip_pandas_string(): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) + result = pa.Table.from_batches([result]) assert result[0].to_pylist() == table[0].to_pylist() assert pa.types.is_string(table[0].type) @@ -227,6 +230,7 @@ def test_roundtrip_pandas_boolean(): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) + result = pa.Table.from_batches([result]) assert table.equals(result) @@ -263,6 +267,7 @@ def test_roundtrip_pandas_datetime(unit): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) + result = pa.Table.from_batches([result]) assert expected.equals(result) @@ -313,6 +318,7 @@ def test_pandas_to_pyarrow_with_missing(np_float): "dt": pa.array(datetime_array, type=pa.timestamp("ns")) }) result = pi.from_dataframe(df) + result = pa.Table.from_batches([result]) assert result.equals(expected) @@ -367,8 +373,8 @@ def test_pandas_to_pyarrow_categorical_with_missing(): expected_indices = pa.array([1, 4, 1, 5, 1, 3, 0, 2, None], type=pa.int8()) assert result[0].to_pylist() == arr - assert result[0].chunk(0).dictionary.to_pylist() == expected_dictionary - assert result[0].chunk(0).indices.equals(expected_indices) + assert result[0].dictionary.to_pylist() == expected_dictionary + assert result[0].indices.equals(expected_indices) @pytest.mark.parametrize( @@ -408,6 +414,7 @@ def test_pyarrow_roundtrip(uint, int, float, np_float, ) table = table.slice(offset, length) result = _from_dataframe(table.__dataframe__()) + result 
= pa.Table.from_batches([result]) assert table.equals(result) @@ -428,6 +435,7 @@ def test_pyarrow_roundtrip_categorical(offset, length): ) table = table.slice(offset, length) result = _from_dataframe(table.__dataframe__()) + result = pa.Table.from_batches([result]) assert table.equals(result) From 464e41b36bbfbfeaadfbebb4ebe0f6efca86ada2 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Wed, 22 Feb 2023 14:25:53 +0100 Subject: [PATCH 03/11] Update dataframe.py select methods --- python/pyarrow/interchange/dataframe.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py index ee389bb1c19..3776b5c3aad 100644 --- a/python/pyarrow/interchange/dataframe.py +++ b/python/pyarrow/interchange/dataframe.py @@ -162,9 +162,14 @@ def select_columns(self, indices: Sequence[int]) -> _PyArrowDataFrame: """ Create a new DataFrame by selecting a subset of columns by index. """ - return _PyArrowDataFrame( - self._df.select(list(indices)), self._nan_as_null, self._allow_copy - ) + if isinstance(self._df, pa.RecordBatch): + columns = [self._df.column(i) for i in indices] + names = [self._df.schema.names[i] for i in indices] + return _PyArrowDataFrame(pa.record_batch(columns, names=names)) + else: + return _PyArrowDataFrame( + self._df.select(list(indices)), self._nan_as_null, self._allow_copy + ) def select_columns_by_name( self, names: Sequence[str] @@ -172,9 +177,13 @@ def select_columns_by_name( """ Create a new DataFrame by selecting a subset of columns by name. """ - return _PyArrowDataFrame( - self._df.select(list(names)), self._nan_as_null, self._allow_copy - ) + if isinstance(self._df, pa.RecordBatch): + columns = [self._df[i] for i in names] + return _PyArrowDataFrame(pa.record_batch(columns, names=names)) + else: + return _PyArrowDataFrame( + self._df.select(list(names)), self._nan_as_null, self._allow_copy + ) def get_chunks( self, n_chunks: Optional[int] = None @@ -200,7 +209,7 @@ def get_chunks( else: batches = [] for start in range(0, chunk_size * n_chunks, chunk_size): - batches.append(self.slice(start, chunk_size)) + batches.append(self._df.slice(start, chunk_size)) # In case when the size of the chunk is such that the resulting # list is one less chunk then n_chunks -> append an empty chunk if len(batches) == n_chunks - 1: @@ -217,4 +226,4 @@ def get_chunks( self._nan_as_null, self._allow_copy) for batch in batches] - return iterator \ No newline at end of file + return iterator From a630e7a5f7248f0967c97bce861a805b47d92f9f Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Wed, 22 Feb 2023 14:33:05 +0100 Subject: [PATCH 04/11] Reset from_dataframe to return only Table, not RecordBatch --- python/pyarrow/interchange/from_dataframe.py | 5 +---- python/pyarrow/tests/interchange/test_conversion.py | 12 ++---------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py index 1a58f9b5e46..28a1abe4463 100644 --- a/python/pyarrow/interchange/from_dataframe.py +++ b/python/pyarrow/interchange/from_dataframe.py @@ -108,10 +108,7 @@ def _from_dataframe(df: DataFrameObject, allow_copy=True): batch = protocol_df_chunk_to_pyarrow(chunk, allow_copy) batches.append(batch) - if len(batches) == 1: - return batches[0] - else: - return pa.Table.from_batches(batches) + return pa.Table.from_batches(batches) def protocol_df_chunk_to_pyarrow( diff --git 
a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py index 9cd338a8f49..0680d9c4ec1 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -85,7 +85,6 @@ def test_offset_of_sliced_array(): assert col.offset == 2 result = _from_dataframe(table_sliced.__dataframe__()) - result = pa.Table.from_batches([result]) assert table_sliced.equals(result) assert not table.equals(result) @@ -177,7 +176,6 @@ def test_pandas_roundtrip(uint, int, float, np_float): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) - result = pa.Table.from_batches([result]) assert table.equals(result) table_protocol = table.__dataframe__() @@ -203,7 +201,6 @@ def test_roundtrip_pandas_string(): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) - result = pa.Table.from_batches([result]) assert result[0].to_pylist() == table[0].to_pylist() assert pa.types.is_string(table[0].type) @@ -230,7 +227,6 @@ def test_roundtrip_pandas_boolean(): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) - result = pa.Table.from_batches([result]) assert table.equals(result) @@ -267,7 +263,6 @@ def test_roundtrip_pandas_datetime(unit): ) pandas_df = pandas_from_dataframe(table) result = pi.from_dataframe(pandas_df) - result = pa.Table.from_batches([result]) assert expected.equals(result) @@ -318,7 +313,6 @@ def test_pandas_to_pyarrow_with_missing(np_float): "dt": pa.array(datetime_array, type=pa.timestamp("ns")) }) result = pi.from_dataframe(df) - result = pa.Table.from_batches([result]) assert result.equals(expected) @@ -373,8 +367,8 @@ def test_pandas_to_pyarrow_categorical_with_missing(): expected_indices = pa.array([1, 4, 1, 5, 1, 3, 0, 2, None], type=pa.int8()) assert result[0].to_pylist() == arr - assert result[0].dictionary.to_pylist() == expected_dictionary - assert result[0].indices.equals(expected_indices) + assert result[0].chunk(0).dictionary.to_pylist() == expected_dictionary + assert result[0].chunk(0).indices.equals(expected_indices) @pytest.mark.parametrize( @@ -414,7 +408,6 @@ def test_pyarrow_roundtrip(uint, int, float, np_float, ) table = table.slice(offset, length) result = _from_dataframe(table.__dataframe__()) - result = pa.Table.from_batches([result]) assert table.equals(result) @@ -435,7 +428,6 @@ def test_pyarrow_roundtrip_categorical(offset, length): ) table = table.slice(offset, length) result = _from_dataframe(table.__dataframe__()) - result = pa.Table.from_batches([result]) assert table.equals(result) From add9631b0341e39f7efe8f6d586847b3f8e36806 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Thu, 23 Feb 2023 10:55:52 +0100 Subject: [PATCH 05/11] Add tests --- python/pyarrow/interchange/dataframe.py | 2 +- .../tests/interchange/test_conversion.py | 169 +++++++++++++++++- .../interchange/test_interchange_spec.py | 103 +++++++++++ 3 files changed, 272 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py index 3776b5c3aad..c92c1a22555 100644 --- a/python/pyarrow/interchange/dataframe.py +++ b/python/pyarrow/interchange/dataframe.py @@ -219,7 +219,7 @@ def get_chunks( if isinstance(self._df, pa.Table): batches = self._df.to_batches() else: - batches = self._df + batches = [self._df] # Create an iterator of RecordBatches iterator = [_PyArrowDataFrame(batch, diff --git a/python/pyarrow/tests/interchange/test_conversion.py 
b/python/pyarrow/tests/interchange/test_conversion.py index 0680d9c4ec1..d55fc5842d7 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -49,6 +49,15 @@ def test_datetime(unit, tz): assert col.dtype[0] == DtypeKind.DATETIME assert col.describe_null == (ColumnNullType.USE_BITMASK, 0) + batch = table.to_batches()[0] + col = batch.__dataframe__().get_column_by_name("A") + + assert col.size() == 3 + assert col.offset == 0 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.DATETIME + assert col.describe_null == (ColumnNullType.USE_BITMASK, 0) + @pytest.mark.parametrize( ["test_data", "kind"], @@ -98,6 +107,15 @@ def test_offset_of_sliced_array(): # tm.assert_series_equal(df["arr"][2:4], df_sliced["arr_sliced"], # check_index=False, check_names=False) + batch_sliced = pa.record_batch([arr_sliced], names=["arr_sliced"]) + + col = batch_sliced.__dataframe__().get_column(0) + assert col.offset == 2 + + result = _from_dataframe(batch_sliced.__dataframe__()) + assert table_sliced.equals(result) + assert not table.equals(result) + # Currently errors due to string conversion # as col.size is called as a property not method in pandas @@ -108,6 +126,7 @@ def test_categorical_roundtrip(): if Version(pd.__version__) < Version("1.5.0"): pytest.skip("__dataframe__ added to pandas in 1.5.0") + arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"] table = pa.table( {"weekday": pa.array(arr).dictionary_encode()} @@ -144,6 +163,38 @@ def test_categorical_roundtrip(): assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] assert isinstance(desc_cat_result["categories"]._col, pa.Array) + batch = table.to_batches()[0] + pandas_df = batch.to_pandas() + result = pi.from_dataframe(pandas_df) + + # Checking equality for the values + # As the dtype of the indices is changed from int32 in pa.Table + # to int64 in pandas interchange protocol implementation + assert result[0].chunk(0).dictionary == table[0].chunk(0).dictionary + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + + col_table = batch_protocol.get_column(0) + col_result = result_protocol.get_column(0) + + assert col_result.dtype[0] == DtypeKind.CATEGORICAL + assert col_result.dtype[0] == col_table.dtype[0] + assert col_result.size == col_table.size + assert col_result.offset == col_table.offset + + desc_cat_table = col_result.describe_categorical + desc_cat_result = col_result.describe_categorical + + assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"] + assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] + assert isinstance(desc_cat_result["categories"]._col, pa.Array) + @pytest.mark.pandas @pytest.mark.parametrize( @@ -186,6 +237,19 @@ def test_pandas_roundtrip(uint, int, float, np_float): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() + batch = table.to_batches()[0] + pandas_df = pandas_from_dataframe(batch) + result = pi.from_dataframe(pandas_df) + assert table.equals(result) + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert 
batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + @pytest.mark.pandas def test_roundtrip_pandas_string(): @@ -214,6 +278,22 @@ def test_roundtrip_pandas_string(): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() + batch = table.to_batches()[0] + pandas_df = pandas_from_dataframe(batch) + result = pi.from_dataframe(pandas_df) + + assert result[0].to_pylist() == batch[0].to_pylist() + assert pa.types.is_string(table[0].type) + assert pa.types.is_large_string(result[0].type) + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + @pytest.mark.pandas def test_roundtrip_pandas_boolean(): @@ -238,6 +318,20 @@ def test_roundtrip_pandas_boolean(): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() + batch = table.to_batches()[0] + pandas_df = pandas_from_dataframe(batch) + result = pi.from_dataframe(pandas_df) + + assert table.equals(result) + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + @pytest.mark.pandas @pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns']) @@ -274,6 +368,21 @@ def test_roundtrip_pandas_datetime(unit): assert expected_protocol.num_chunks() == result_protocol.num_chunks() assert expected_protocol.column_names() == result_protocol.column_names() + batch = table.to_batches()[0] + pandas_df = pandas_from_dataframe(batch) + result = pi.from_dataframe(pandas_df) + + assert expected.equals(result) + + expected_batch = expected.to_batches()[0] + expected_protocol = expected_batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert expected_protocol.num_columns() == result_protocol.num_columns() + assert expected_protocol.num_rows() == result_protocol.num_rows() + assert expected_protocol.num_chunks() == result_protocol.num_chunks() + assert expected_protocol.column_names() == result_protocol.column_names() + @pytest.mark.large_memory @pytest.mark.pandas @@ -293,6 +402,10 @@ def test_pandas_assertion_error_large_string(): with pytest.raises(AssertionError): pandas_from_dataframe(table) + batch = table.to_batches()[0] + with pytest.raises(AssertionError): + pandas_from_dataframe(batch) + @pytest.mark.pandas @pytest.mark.parametrize( @@ -419,6 +532,19 @@ def test_pyarrow_roundtrip(uint, int, float, np_float, assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() + batch = table.to_batches()[0] + result = _from_dataframe(batch.__dataframe__()) + + assert table.equals(result) + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert 
batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + @pytest.mark.parametrize("offset, length", [(0, 10), (0, 2), (7, 3), (2, 1)]) def test_pyarrow_roundtrip_categorical(offset, length): @@ -447,13 +573,41 @@ def test_pyarrow_roundtrip_categorical(offset, length): assert col_result.size() == col_table.size() assert col_result.offset == col_table.offset - desc_cat_table = col_result.describe_categorical + desc_cat_table = col_table.describe_categorical desc_cat_result = col_result.describe_categorical assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"] assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] assert isinstance(desc_cat_result["categories"]._col, pa.Array) + batch = table.to_batches()[0] + result = _from_dataframe(batch.__dataframe__()) + + assert table.equals(result) + + batch_protocol = batch.__dataframe__() + result_protocol = result.__dataframe__() + + assert batch_protocol.num_columns() == result_protocol.num_columns() + assert batch_protocol.num_rows() == result_protocol.num_rows() + assert batch_protocol.num_chunks() == result_protocol.num_chunks() + assert batch_protocol.column_names() == result_protocol.column_names() + + col_batch = batch_protocol.get_column(0) + col_result = result_protocol.get_column(0) + + assert col_result.dtype[0] == DtypeKind.CATEGORICAL + assert col_result.dtype[0] == col_batch.dtype[0] + assert col_result.size() == col_batch.size() + assert col_result.offset == col_batch.offset + + desc_col_batch = col_batch.describe_categorical + desc_cat_result = col_result.describe_categorical + + assert desc_col_batch["is_ordered"] == desc_cat_result["is_ordered"] + assert desc_col_batch["is_dictionary"] == desc_cat_result["is_dictionary"] + assert isinstance(desc_cat_result["categories"]._col, pa.Array) + @pytest.mark.large_memory def test_pyarrow_roundtrip_large_string(): @@ -471,11 +625,24 @@ def test_pyarrow_roundtrip_large_string(): assert table.equals(result) + batch = table.to_batches()[0] + result = _from_dataframe(batch.__dataframe__()) + col = result.__dataframe__().get_column(0) + + assert col.size() == 3*1024**2 + assert pa.types.is_large_string(table[0].type) + assert pa.types.is_large_string(result[0].type) + + assert table.equals(result) + def test_nan_as_null(): table = pa.table({"a": [1, 2, 3, 4]}) with pytest.raises(RuntimeError): table.__dataframe__(nan_as_null=True) + batch = table.to_batches()[0] + with pytest.raises(RuntimeError): + batch.__dataframe__(nan_as_null=True) @pytest.mark.pandas diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py b/python/pyarrow/tests/interchange/test_interchange_spec.py index 42ec8053599..cd87d889754 100644 --- a/python/pyarrow/tests/interchange/test_interchange_spec.py +++ b/python/pyarrow/tests/interchange/test_interchange_spec.py @@ -50,6 +50,15 @@ def test_dtypes(arr): assert df.get_column(0).size() == 3 assert df.get_column(0).offset == 0 + batch = pa.record_batch([arr], names=["a"]) + df = batch.__dataframe__() + + null_count = df.get_column(0).null_count + assert null_count == arr.null_count + assert isinstance(null_count, int) + assert df.get_column(0).size() == 3 + assert df.get_column(0).offset == 0 + @pytest.mark.parametrize( "uint, uint_bw", @@ -109,6 +118,21 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw, assert 
df.get_column_by_name("b").dtype[1] == int_bw assert df.get_column_by_name("c").dtype[1] == float_bw + batch = table.to_batches()[0] + df = batch.__dataframe__() + + for column, kind in columns.items(): + col = df.get_column_by_name(column) + + assert col.null_count == 0 + assert col.size() == 3 + assert col.offset == 0 + assert col.dtype[0] == kind + + assert df.get_column_by_name("a").dtype[1] == uint_bw + assert df.get_column_by_name("b").dtype[1] == int_bw + assert df.get_column_by_name("c").dtype[1] == float_bw + def test_na_float(): table = pa.table({"a": [1.0, None, 2.0]}) @@ -117,6 +141,12 @@ def test_na_float(): assert col.null_count == 1 assert isinstance(col.null_count, int) + batch = table.to_batches()[0] + df = batch.__dataframe__() + col = df.get_column_by_name("a") + assert col.null_count == 1 + assert isinstance(col.null_count, int) + def test_noncategorical(): table = pa.table({"a": [1, 2, 3]}) @@ -125,6 +155,12 @@ def test_noncategorical(): with pytest.raises(TypeError, match=".*categorical.*"): col.describe_categorical + batch = table.to_batches()[0] + df = batch.__dataframe__() + col = df.get_column_by_name("a") + with pytest.raises(TypeError, match=".*categorical.*"): + col.describe_categorical + def test_categorical(): import pyarrow as pa @@ -138,6 +174,12 @@ def test_categorical(): assert isinstance(categorical["is_ordered"], bool) assert isinstance(categorical["is_dictionary"], bool) + batch = table.to_batches()[0] + col = batch.__dataframe__().get_column_by_name("weekday") + categorical = col.describe_categorical + assert isinstance(categorical["is_ordered"], bool) + assert isinstance(categorical["is_dictionary"], bool) + def test_dataframe(): n = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) @@ -154,6 +196,17 @@ def test_dataframe(): df.select_columns_by_name(("animals",)).column_names() ) + batch = table.combine_chunks().to_batches()[0] + df = batch.__dataframe__() + + assert df.num_columns() == 2 + assert df.num_rows() == 6 + assert df.num_chunks() == 1 + assert list(df.column_names()) == ['n_legs', 'animals'] + assert list(df.select_columns((1,)).column_names()) == list( + df.select_columns_by_name(("animals",)).column_names() + ) + @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) def test_df_get_chunks(size, n_chunks): @@ -163,6 +216,12 @@ def test_df_get_chunks(size, n_chunks): assert len(chunks) == n_chunks assert sum(chunk.num_rows() for chunk in chunks) == size + batch = table.to_batches()[0] + df = batch.__dataframe__() + chunks = list(df.get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.num_rows() for chunk in chunks) == size + @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) def test_column_get_chunks(size, n_chunks): @@ -172,6 +231,12 @@ def test_column_get_chunks(size, n_chunks): assert len(chunks) == n_chunks assert sum(chunk.size() for chunk in chunks) == size + batch = table.to_batches()[0] + df = batch.__dataframe__() + chunks = list(df.get_column(0).get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.size() for chunk in chunks) == size + @pytest.mark.pandas @pytest.mark.parametrize( @@ -208,6 +273,16 @@ def test_get_columns(uint, int, float, np_float): assert df.get_column(1).dtype[0] == 0 # INT assert df.get_column(2).dtype[0] == 2 # FLOAT + batch = table.combine_chunks().to_batches()[0] + df = batch.__dataframe__() + for col in df.get_columns(): + assert col.size() == 5 + assert col.num_chunks() == 1 + + assert df.get_column(0).dtype[0] == 1 + 
assert df.get_column(1).dtype[0] == 0 + assert df.get_column(2).dtype[0] == 2 + @pytest.mark.parametrize( "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()] @@ -241,3 +316,31 @@ def test_buffer(int): for idx, truth in enumerate(arr): val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value assert val == truth, f"Buffer at index {idx} mismatch" + + batch = table.to_batches()[0] + df = batch.__dataframe__() + col = df.get_column(0) + buf = col.get_buffers() + + dataBuf, dataDtype = buf["data"] + + assert dataBuf.bufsize > 0 + assert dataBuf.ptr != 0 + device, _ = dataBuf.__dlpack_device__() + + # 0 = DtypeKind.INT + # see DtypeKind class in column.py + assert dataDtype[0] == 0 + + if device == 1: # CPU-only as we're going to directly read memory here + bitwidth = dataDtype[1] + ctype = { + 8: ctypes.c_int8, + 16: ctypes.c_int16, + 32: ctypes.c_int32, + 64: ctypes.c_int64, + }[bitwidth] + + for idx, truth in enumerate(arr): + val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value + assert val == truth, f"Buffer at index {idx} mismatch" From 625cb859f68cf4ada0be604f935bc4e4741264e8 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 09:41:16 +0100 Subject: [PATCH 06/11] Update the column_names() method --- python/pyarrow/interchange/dataframe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py index c92c1a22555..f7e9bf4eb4b 100644 --- a/python/pyarrow/interchange/dataframe.py +++ b/python/pyarrow/interchange/dataframe.py @@ -130,10 +130,7 @@ def column_names(self) -> Iterable[str]: """ Return an iterator yielding the column names. """ - if isinstance(self._df, pa.RecordBatch): - return self._df.schema.names - else: - return self._df.column_names + return self._df.schema.names def get_column(self, i: int) -> _PyArrowColumn: """ From 3e9cf123ee73a56ba6cc82316bf99a5bf3b71c7f Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 09:42:19 +0100 Subject: [PATCH 07/11] Add missing blank lines in __dataframe__ method docstrings --- python/pyarrow/table.pxi | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 506fdac3d74..e400605e566 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -1521,6 +1521,7 @@ cdef class RecordBatch(_PandasConvertible): def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): """ Return the dataframe interchange object implementing the interchange protocol. + Parameters ---------- nan_as_null : bool, default False @@ -1529,10 +1530,12 @@ cdef class RecordBatch(_PandasConvertible): allow_copy : bool, default True Whether to allow memory copying when exporting. If set to False it would cause non-zero-copy exports to fail. + Returns ------- DataFrame interchange object The object which consuming library can use to ingress the dataframe. 
+ Notes ----- Details on the interchange protocol: From a740b681e7e6f21ea6b079b3eb04b803fa45263b Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 10:39:45 +0100 Subject: [PATCH 08/11] Update from_dataframe to have a consistent return type (pa.Table) --- python/pyarrow/interchange/from_dataframe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/interchange/from_dataframe.py b/python/pyarrow/interchange/from_dataframe.py index 28a1abe4463..801d0dd452a 100644 --- a/python/pyarrow/interchange/from_dataframe.py +++ b/python/pyarrow/interchange/from_dataframe.py @@ -58,10 +58,9 @@ } -def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table | pa.RecordBatch: +def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table: """ - Build a ``pa.Table`` or a ``pa.RecordBatch`` from any DataFrame supporting - the interchange protocol. + Build a ``pa.Table`` from any DataFrame supporting the interchange protocol. Parameters ---------- @@ -74,10 +73,12 @@ def from_dataframe(df: DataFrameObject, allow_copy=True) -> pa.Table | pa.Record Returns ------- - pa.Table or pa.RecordBatch + pa.Table """ - if isinstance(df, (pa.Table, pa.RecordBatch)): + if isinstance(df, pa.Table): return df + elif isinstance(df, pa.RecordBatch): + return pa.Table.from_batches([df]) if not hasattr(df, "__dataframe__"): raise ValueError("`df` does not support __dataframe__") From 47ff167ee102270ba65a333e701ef195029eb013 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 11:37:42 +0100 Subject: [PATCH 09/11] Update the tests --- .../tests/interchange/test_conversion.py | 166 ------------------ .../interchange/test_interchange_spec.py | 143 ++++----------- 2 files changed, 33 insertions(+), 276 deletions(-) diff --git a/python/pyarrow/tests/interchange/test_conversion.py b/python/pyarrow/tests/interchange/test_conversion.py index d55fc5842d7..089f316e508 100644 --- a/python/pyarrow/tests/interchange/test_conversion.py +++ b/python/pyarrow/tests/interchange/test_conversion.py @@ -49,15 +49,6 @@ def test_datetime(unit, tz): assert col.dtype[0] == DtypeKind.DATETIME assert col.describe_null == (ColumnNullType.USE_BITMASK, 0) - batch = table.to_batches()[0] - col = batch.__dataframe__().get_column_by_name("A") - - assert col.size() == 3 - assert col.offset == 0 - assert col.null_count == 1 - assert col.dtype[0] == DtypeKind.DATETIME - assert col.describe_null == (ColumnNullType.USE_BITMASK, 0) - @pytest.mark.parametrize( ["test_data", "kind"], @@ -107,15 +98,6 @@ def test_offset_of_sliced_array(): # tm.assert_series_equal(df["arr"][2:4], df_sliced["arr_sliced"], # check_index=False, check_names=False) - batch_sliced = pa.record_batch([arr_sliced], names=["arr_sliced"]) - - col = batch_sliced.__dataframe__().get_column(0) - assert col.offset == 2 - - result = _from_dataframe(batch_sliced.__dataframe__()) - assert table_sliced.equals(result) - assert not table.equals(result) - # Currently errors due to string conversion # as col.size is called as a property not method in pandas @@ -163,38 +145,6 @@ def test_categorical_roundtrip(): assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] assert isinstance(desc_cat_result["categories"]._col, pa.Array) - batch = table.to_batches()[0] - pandas_df = batch.to_pandas() - result = pi.from_dataframe(pandas_df) - - # Checking equality for the values - # As the dtype of the indices is changed from int32 in pa.Table - # to int64 in pandas interchange protocol implementation - assert 
result[0].chunk(0).dictionary == table[0].chunk(0).dictionary - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - - col_table = batch_protocol.get_column(0) - col_result = result_protocol.get_column(0) - - assert col_result.dtype[0] == DtypeKind.CATEGORICAL - assert col_result.dtype[0] == col_table.dtype[0] - assert col_result.size == col_table.size - assert col_result.offset == col_table.offset - - desc_cat_table = col_result.describe_categorical - desc_cat_result = col_result.describe_categorical - - assert desc_cat_table["is_ordered"] == desc_cat_result["is_ordered"] - assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] - assert isinstance(desc_cat_result["categories"]._col, pa.Array) - @pytest.mark.pandas @pytest.mark.parametrize( @@ -237,19 +187,6 @@ def test_pandas_roundtrip(uint, int, float, np_float): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() - batch = table.to_batches()[0] - pandas_df = pandas_from_dataframe(batch) - result = pi.from_dataframe(pandas_df) - assert table.equals(result) - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - @pytest.mark.pandas def test_roundtrip_pandas_string(): @@ -278,22 +215,6 @@ def test_roundtrip_pandas_string(): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() - batch = table.to_batches()[0] - pandas_df = pandas_from_dataframe(batch) - result = pi.from_dataframe(pandas_df) - - assert result[0].to_pylist() == batch[0].to_pylist() - assert pa.types.is_string(table[0].type) - assert pa.types.is_large_string(result[0].type) - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - @pytest.mark.pandas def test_roundtrip_pandas_boolean(): @@ -318,20 +239,6 @@ def test_roundtrip_pandas_boolean(): assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() - batch = table.to_batches()[0] - pandas_df = pandas_from_dataframe(batch) - result = pi.from_dataframe(pandas_df) - - assert table.equals(result) - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - @pytest.mark.pandas @pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns']) @@ -368,21 
+275,6 @@ def test_roundtrip_pandas_datetime(unit): assert expected_protocol.num_chunks() == result_protocol.num_chunks() assert expected_protocol.column_names() == result_protocol.column_names() - batch = table.to_batches()[0] - pandas_df = pandas_from_dataframe(batch) - result = pi.from_dataframe(pandas_df) - - assert expected.equals(result) - - expected_batch = expected.to_batches()[0] - expected_protocol = expected_batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert expected_protocol.num_columns() == result_protocol.num_columns() - assert expected_protocol.num_rows() == result_protocol.num_rows() - assert expected_protocol.num_chunks() == result_protocol.num_chunks() - assert expected_protocol.column_names() == result_protocol.column_names() - @pytest.mark.large_memory @pytest.mark.pandas @@ -402,10 +294,6 @@ def test_pandas_assertion_error_large_string(): with pytest.raises(AssertionError): pandas_from_dataframe(table) - batch = table.to_batches()[0] - with pytest.raises(AssertionError): - pandas_from_dataframe(batch) - @pytest.mark.pandas @pytest.mark.parametrize( @@ -532,19 +420,6 @@ def test_pyarrow_roundtrip(uint, int, float, np_float, assert table_protocol.num_chunks() == result_protocol.num_chunks() assert table_protocol.column_names() == result_protocol.column_names() - batch = table.to_batches()[0] - result = _from_dataframe(batch.__dataframe__()) - - assert table.equals(result) - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - @pytest.mark.parametrize("offset, length", [(0, 10), (0, 2), (7, 3), (2, 1)]) def test_pyarrow_roundtrip_categorical(offset, length): @@ -580,34 +455,6 @@ def test_pyarrow_roundtrip_categorical(offset, length): assert desc_cat_table["is_dictionary"] == desc_cat_result["is_dictionary"] assert isinstance(desc_cat_result["categories"]._col, pa.Array) - batch = table.to_batches()[0] - result = _from_dataframe(batch.__dataframe__()) - - assert table.equals(result) - - batch_protocol = batch.__dataframe__() - result_protocol = result.__dataframe__() - - assert batch_protocol.num_columns() == result_protocol.num_columns() - assert batch_protocol.num_rows() == result_protocol.num_rows() - assert batch_protocol.num_chunks() == result_protocol.num_chunks() - assert batch_protocol.column_names() == result_protocol.column_names() - - col_batch = batch_protocol.get_column(0) - col_result = result_protocol.get_column(0) - - assert col_result.dtype[0] == DtypeKind.CATEGORICAL - assert col_result.dtype[0] == col_batch.dtype[0] - assert col_result.size() == col_batch.size() - assert col_result.offset == col_batch.offset - - desc_col_batch = col_batch.describe_categorical - desc_cat_result = col_result.describe_categorical - - assert desc_col_batch["is_ordered"] == desc_cat_result["is_ordered"] - assert desc_col_batch["is_dictionary"] == desc_cat_result["is_dictionary"] - assert isinstance(desc_cat_result["categories"]._col, pa.Array) - @pytest.mark.large_memory def test_pyarrow_roundtrip_large_string(): @@ -625,24 +472,11 @@ def test_pyarrow_roundtrip_large_string(): assert table.equals(result) - batch = table.to_batches()[0] - result = _from_dataframe(batch.__dataframe__()) - col = result.__dataframe__().get_column(0) - - assert 
col.size() == 3*1024**2 - assert pa.types.is_large_string(table[0].type) - assert pa.types.is_large_string(result[0].type) - - assert table.equals(result) - def test_nan_as_null(): table = pa.table({"a": [1, 2, 3, 4]}) with pytest.raises(RuntimeError): table.__dataframe__(nan_as_null=True) - batch = table.to_batches()[0] - with pytest.raises(RuntimeError): - batch.__dataframe__(nan_as_null=True) @pytest.mark.pandas diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py b/python/pyarrow/tests/interchange/test_interchange_spec.py index cd87d889754..e8908aaa615 100644 --- a/python/pyarrow/tests/interchange/test_interchange_spec.py +++ b/python/pyarrow/tests/interchange/test_interchange_spec.py @@ -50,15 +50,6 @@ def test_dtypes(arr): assert df.get_column(0).size() == 3 assert df.get_column(0).offset == 0 - batch = pa.record_batch([arr], names=["a"]) - df = batch.__dataframe__() - - null_count = df.get_column(0).null_count - assert null_count == arr.null_count - assert isinstance(null_count, int) - assert df.get_column(0).size() == 3 - assert df.get_column(0).offset == 0 - @pytest.mark.parametrize( "uint, uint_bw", @@ -85,8 +76,10 @@ def test_dtypes(arr): ) @pytest.mark.parametrize("unit", ['s', 'ms', 'us', 'ns']) @pytest.mark.parametrize("tz", ['', 'America/New_York', '+07:30', '-04:30']) +@pytest.mark.parametrize("use_batch", [False, True]) def test_mixed_dtypes(uint, uint_bw, int, int_bw, - float, float_bw, np_float, unit, tz): + float, float_bw, np_float, unit, tz, + use_batch): from datetime import datetime as dt arr = [1, 2, 3] dt_arr = [dt(2007, 7, 13), dt(2007, 7, 14), dt(2007, 7, 15)] @@ -100,6 +93,8 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw, "f": pa.array(dt_arr, type=pa.timestamp(unit, tz=tz)) } ) + if use_batch: + table = table.to_batches()[0] df = table.__dataframe__() # 0 = DtypeKind.INT, 1 = DtypeKind.UINT, 2 = DtypeKind.FLOAT, # 20 = DtypeKind.BOOL, 21 = DtypeKind.STRING, 22 = DtypeKind.DATETIME @@ -118,21 +113,6 @@ def test_mixed_dtypes(uint, uint_bw, int, int_bw, assert df.get_column_by_name("b").dtype[1] == int_bw assert df.get_column_by_name("c").dtype[1] == float_bw - batch = table.to_batches()[0] - df = batch.__dataframe__() - - for column, kind in columns.items(): - col = df.get_column_by_name(column) - - assert col.null_count == 0 - assert col.size() == 3 - assert col.offset == 0 - assert col.dtype[0] == kind - - assert df.get_column_by_name("a").dtype[1] == uint_bw - assert df.get_column_by_name("b").dtype[1] == int_bw - assert df.get_column_by_name("c").dtype[1] == float_bw - def test_na_float(): table = pa.table({"a": [1.0, None, 2.0]}) @@ -141,12 +121,6 @@ def test_na_float(): assert col.null_count == 1 assert isinstance(col.null_count, int) - batch = table.to_batches()[0] - df = batch.__dataframe__() - col = df.get_column_by_name("a") - assert col.null_count == 1 - assert isinstance(col.null_count, int) - def test_noncategorical(): table = pa.table({"a": [1, 2, 3]}) @@ -155,88 +129,68 @@ def test_noncategorical(): with pytest.raises(TypeError, match=".*categorical.*"): col.describe_categorical - batch = table.to_batches()[0] - df = batch.__dataframe__() - col = df.get_column_by_name("a") - with pytest.raises(TypeError, match=".*categorical.*"): - col.describe_categorical - -def test_categorical(): +@pytest.mark.parametrize("use_batch", [False, True]) +def test_categorical(use_batch): import pyarrow as pa arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", None] table = pa.table( {"weekday": pa.array(arr).dictionary_encode()} ) + 
if use_batch: + table = table.to_batches()[0] col = table.__dataframe__().get_column_by_name("weekday") categorical = col.describe_categorical assert isinstance(categorical["is_ordered"], bool) assert isinstance(categorical["is_dictionary"], bool) - batch = table.to_batches()[0] - col = batch.__dataframe__().get_column_by_name("weekday") - categorical = col.describe_categorical - assert isinstance(categorical["is_ordered"], bool) - assert isinstance(categorical["is_dictionary"], bool) - -def test_dataframe(): +@pytest.mark.parametrize("use_batch", [False, True]) +def test_dataframe(use_batch): n = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) a = pa.chunked_array([["Flamingo", "Parrot", "Cow"], ["Horse", "Brittle stars", "Centipede"]]) table = pa.table([n, a], names=['n_legs', 'animals']) + if use_batch: + table = table.combine_chunks().to_batches()[0] df = table.__dataframe__() assert df.num_columns() == 2 assert df.num_rows() == 6 - assert df.num_chunks() == 2 - assert list(df.column_names()) == ['n_legs', 'animals'] - assert list(df.select_columns((1,)).column_names()) == list( - df.select_columns_by_name(("animals",)).column_names() - ) - - batch = table.combine_chunks().to_batches()[0] - df = batch.__dataframe__() - - assert df.num_columns() == 2 - assert df.num_rows() == 6 - assert df.num_chunks() == 1 + if use_batch: + assert df.num_chunks() == 1 + else: + assert df.num_chunks() == 2 assert list(df.column_names()) == ['n_legs', 'animals'] assert list(df.select_columns((1,)).column_names()) == list( df.select_columns_by_name(("animals",)).column_names() ) +@pytest.mark.parametrize("use_batch", [False, True]) @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) -def test_df_get_chunks(size, n_chunks): +def test_df_get_chunks(use_batch, size, n_chunks): table = pa.table({"x": list(range(size))}) + if use_batch: + table = table.to_batches()[0] df = table.__dataframe__() chunks = list(df.get_chunks(n_chunks)) assert len(chunks) == n_chunks assert sum(chunk.num_rows() for chunk in chunks) == size - batch = table.to_batches()[0] - df = batch.__dataframe__() - chunks = list(df.get_chunks(n_chunks)) - assert len(chunks) == n_chunks - assert sum(chunk.num_rows() for chunk in chunks) == size - +@pytest.mark.parametrize("use_batch", [False, True]) @pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) -def test_column_get_chunks(size, n_chunks): +def test_column_get_chunks(use_batch, size, n_chunks): table = pa.table({"x": list(range(size))}) + if use_batch: + table = table.to_batches()[0] df = table.__dataframe__() chunks = list(df.get_column(0).get_chunks(n_chunks)) assert len(chunks) == n_chunks assert sum(chunk.size() for chunk in chunks) == size - batch = table.to_batches()[0] - df = batch.__dataframe__() - chunks = list(df.get_column(0).get_chunks(n_chunks)) - assert len(chunks) == n_chunks - assert sum(chunk.size() for chunk in chunks) == size - @pytest.mark.pandas @pytest.mark.parametrize( @@ -252,7 +206,8 @@ def test_column_get_chunks(size, n_chunks): (pa.float64(), np.float64) ] ) -def test_get_columns(uint, int, float, np_float): +@pytest.mark.parametrize("use_batch", [False, True]) +def test_get_columns(uint, int, float, np_float, use_batch): arr = [[1, 2, 3], [4, 5]] arr_float = np.array([1, 2, 3, 4, 5], dtype=np_float) table = pa.table( @@ -262,6 +217,8 @@ def test_get_columns(uint, int, float, np_float): "c": pa.array(arr_float, type=float) } ) + if use_batch: + table = table.combine_chunks().to_batches()[0] df = table.__dataframe__() for 
col in df.get_columns(): assert col.size() == 5 @@ -273,23 +230,16 @@ def test_get_columns(uint, int, float, np_float): assert df.get_column(1).dtype[0] == 0 # INT assert df.get_column(2).dtype[0] == 2 # FLOAT - batch = table.combine_chunks().to_batches()[0] - df = batch.__dataframe__() - for col in df.get_columns(): - assert col.size() == 5 - assert col.num_chunks() == 1 - - assert df.get_column(0).dtype[0] == 1 - assert df.get_column(1).dtype[0] == 0 - assert df.get_column(2).dtype[0] == 2 - @pytest.mark.parametrize( "int", [pa.int8(), pa.int16(), pa.int32(), pa.int64()] ) -def test_buffer(int): +@pytest.mark.parametrize("use_batch", [False, True]) +def test_buffer(int, use_batch): arr = [0, 1, -1] table = pa.table({"a": pa.array(arr, type=int)}) + if use_batch: + table = table.to_batches()[0] df = table.__dataframe__() col = df.get_column(0) buf = col.get_buffers() @@ -317,30 +267,3 @@ def test_buffer(int): val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value assert val == truth, f"Buffer at index {idx} mismatch" - batch = table.to_batches()[0] - df = batch.__dataframe__() - col = df.get_column(0) - buf = col.get_buffers() - - dataBuf, dataDtype = buf["data"] - - assert dataBuf.bufsize > 0 - assert dataBuf.ptr != 0 - device, _ = dataBuf.__dlpack_device__() - - # 0 = DtypeKind.INT - # see DtypeKind class in column.py - assert dataDtype[0] == 0 - - if device == 1: # CPU-only as we're going to directly read memory here - bitwidth = dataDtype[1] - ctype = { - 8: ctypes.c_int8, - 16: ctypes.c_int16, - 32: ctypes.c_int32, - 64: ctypes.c_int64, - }[bitwidth] - - for idx, truth in enumerate(arr): - val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value - assert val == truth, f"Buffer at index {idx} mismatch" From ef6983d770fce550748b85e269b2e5d0a194f31a Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 11:46:44 +0100 Subject: [PATCH 10/11] Update select method --- python/pyarrow/interchange/dataframe.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/python/pyarrow/interchange/dataframe.py b/python/pyarrow/interchange/dataframe.py index f7e9bf4eb4b..59ba765c175 100644 --- a/python/pyarrow/interchange/dataframe.py +++ b/python/pyarrow/interchange/dataframe.py @@ -159,14 +159,9 @@ def select_columns(self, indices: Sequence[int]) -> _PyArrowDataFrame: """ Create a new DataFrame by selecting a subset of columns by index. """ - if isinstance(self._df, pa.RecordBatch): - columns = [self._df.column(i) for i in indices] - names = [self._df.schema.names[i] for i in indices] - return _PyArrowDataFrame(pa.record_batch(columns, names=names)) - else: - return _PyArrowDataFrame( - self._df.select(list(indices)), self._nan_as_null, self._allow_copy - ) + return _PyArrowDataFrame( + self._df.select(list(indices)), self._nan_as_null, self._allow_copy + ) def select_columns_by_name( self, names: Sequence[str] @@ -174,13 +169,9 @@ def select_columns_by_name( """ Create a new DataFrame by selecting a subset of columns by name. 
""" - if isinstance(self._df, pa.RecordBatch): - columns = [self._df[i] for i in names] - return _PyArrowDataFrame(pa.record_batch(columns, names=names)) - else: - return _PyArrowDataFrame( - self._df.select(list(names)), self._nan_as_null, self._allow_copy - ) + return _PyArrowDataFrame( + self._df.select(list(names)), self._nan_as_null, self._allow_copy + ) def get_chunks( self, n_chunks: Optional[int] = None From 821f9f3f3547ea2b792c7e8f535f04c6091fbde4 Mon Sep 17 00:00:00 2001 From: Alenka Frim Date: Mon, 27 Feb 2023 13:06:53 +0100 Subject: [PATCH 11/11] Correct linter error --- python/pyarrow/tests/interchange/test_interchange_spec.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyarrow/tests/interchange/test_interchange_spec.py b/python/pyarrow/tests/interchange/test_interchange_spec.py index e8908aaa615..7b2b8eb7208 100644 --- a/python/pyarrow/tests/interchange/test_interchange_spec.py +++ b/python/pyarrow/tests/interchange/test_interchange_spec.py @@ -266,4 +266,3 @@ def test_buffer(int, use_batch): for idx, truth in enumerate(arr): val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value assert val == truth, f"Buffer at index {idx} mismatch" -