From 40ec84f02ee469a726f67381fa19b9a9c5087f98 Mon Sep 17 00:00:00 2001 From: AlenkaF Date: Tue, 6 Jun 2023 11:39:35 +0200 Subject: [PATCH] Start with easy fix for FixedShapeTensorType --- python/pyarrow/tests/test_extension_type.py | 16 ++++++++++++++++ python/pyarrow/types.pxi | 4 ++++ 2 files changed, 20 insertions(+) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 023968b20df..78618c19fff 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1306,3 +1306,19 @@ def test_extension_to_pandas_storage_type(registered_period_type): # Check the usage of types_mapper result = table.to_pandas(types_mapper=pd.ArrowDtype) assert isinstance(result["ext"].dtype, pd.ArrowDtype) + + +def test_tensor_type_is_picklable(): + # GH-35599 + + expected_type = pa.fixed_shape_tensor(pa.int32(), (2, 2)) + result = pickle.loads(pickle.dumps(expected_type)) + + assert result == expected_type + + arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]] + storage = pa.array(arr, pa.list_(pa.int32(), 4)) + expected_arr = pa.ExtensionArray.from_storage(expected_type, storage) + result = pickle.loads(pickle.dumps(expected_arr)) + + assert result == expected_arr diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index bcd358c9a5b..48605c293ec 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1586,6 +1586,10 @@ cdef class FixedShapeTensorType(BaseExtensionType): def __arrow_ext_class__(self): return FixedShapeTensorArray + def __reduce__(self): + return fixed_shape_tensor, (self.value_type, self.shape, + self.dim_names, self.permutation) + cdef class PyExtensionType(ExtensionType): """