From d0d4b15c25af54ca82cfe898baaf297a4ebf38ff Mon Sep 17 00:00:00 2001
From: AlenkaF
Date: Tue, 20 Jun 2023 14:38:00 +0200
Subject: [PATCH 1/2] Add __reduce__ method to ExtensionType class

---
Reviewer notes (free-form text in this position is not part of the diff):
- __reduce__ returns (self.__arrow_ext_deserialize__,
  (self.storage_type, self.__arrow_ext_serialize__())), so unpickling
  rebuilds the type by calling __arrow_ext_deserialize__ with the storage
  type and the serialized payload.
- NOTE(review): the unchanged context line `return NotImplementedError`
  returns the exception class instead of raising it; out of scope for this
  patch, confirm intent separately.

 python/pyarrow/tests/test_extension_type.py | 7 +++++++
 python/pyarrow/types.pxi                    | 3 +++
 2 files changed, 10 insertions(+)

diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 78618c19fff..4ddc366817f 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -867,6 +867,13 @@ def test_generic_ext_type_equality():
     assert not period_type == period_type3
 
 
+def test_generic_ext_type_pickling():
+    # GH-36038
+    period_type = PeriodType('D')
+    period_type_pickled = pickle.loads(pickle.dumps(period_type))
+    assert period_type == period_type_pickled
+
+
 def test_generic_ext_type_register(registered_period_type):
     # test that trying to register other type does not segfault
     with pytest.raises(TypeError):
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 48605c293ec..a3311cbbcf4 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1487,6 +1487,9 @@ cdef class ExtensionType(BaseExtensionType):
         """
         return NotImplementedError
 
+    def __reduce__(self):
+        return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
+
     def __arrow_ext_class__(self):
         """Return an extension array class to be used for building or
         deserializing arrays with this extension type.

From 7d7281730410415f5e6f87ecfcb90bb528470dcc Mon Sep 17 00:00:00 2001
From: AlenkaF
Date: Wed, 21 Jun 2023 08:53:25 +0200
Subject: [PATCH 2/2] Add test for Extension array with generic ExtensionType

---
Reviewer notes (free-form text in this position is not part of the diff):
- test_generic_ext_type_pickling now takes the type from the
  registered_period_type fixture and round-trips it through every pickle
  protocol from 0 to pickle.HIGHEST_PROTOCOL.
- test_generic_ext_array_pickling builds an ExtensionArray from int64
  storage, pickles it, drops the originals, and checks the restored array's
  type, storage type and values.

 python/pyarrow/tests/test_extension_type.py | 26 +++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 4ddc366817f..51bf57c9ba3 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -867,11 +867,29 @@ def test_generic_ext_type_equality():
     assert not period_type == period_type3
 
 
-def test_generic_ext_type_pickling():
+def test_generic_ext_type_pickling(registered_period_type):
     # GH-36038
-    period_type = PeriodType('D')
-    period_type_pickled = pickle.loads(pickle.dumps(period_type))
-    assert period_type == period_type_pickled
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        period_type, _ = registered_period_type
+        ser = pickle.dumps(period_type, protocol=proto)
+        period_type_pickled = pickle.loads(ser)
+        assert period_type == period_type_pickled
+
+
+def test_generic_ext_array_pickling(registered_period_type):
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        period_type, _ = registered_period_type
+        storage = pa.array([1, 2, 3, 4], pa.int64())
+        arr = pa.ExtensionArray.from_storage(period_type, storage)
+        ser = pickle.dumps(arr, protocol=proto)
+        del storage, arr
+        arr = pickle.loads(ser)
+        arr.validate()
+        assert isinstance(arr, pa.ExtensionArray)
+        assert arr.type == period_type
+        assert arr.type.storage_type == pa.int64()
+        assert arr.storage.type == pa.int64()
+        assert arr.storage.to_pylist() == [1, 2, 3, 4]
 
 
 def test_generic_ext_type_register(registered_period_type):