diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 78618c19fff..51bf57c9ba3 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -867,6 +867,31 @@ def test_generic_ext_type_equality(): assert not period_type == period_type3 +def test_generic_ext_type_pickling(registered_period_type): + # GH-36038 + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + period_type, _ = registered_period_type + ser = pickle.dumps(period_type, protocol=proto) + period_type_pickled = pickle.loads(ser) + assert period_type == period_type_pickled + + +def test_generic_ext_array_pickling(registered_period_type): + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + period_type, _ = registered_period_type + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + ser = pickle.dumps(arr, protocol=proto) + del storage, arr + arr = pickle.loads(ser) + arr.validate() + assert isinstance(arr, pa.ExtensionArray) + assert arr.type == period_type + assert arr.type.storage_type == pa.int64() + assert arr.storage.type == pa.int64() + assert arr.storage.to_pylist() == [1, 2, 3, 4] + + def test_generic_ext_type_register(registered_period_type): # test that trying to register other type does not segfault with pytest.raises(TypeError): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 48605c293ec..a3311cbbcf4 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1487,6 +1487,9 @@ cdef class ExtensionType(BaseExtensionType): """ return NotImplementedError + def __reduce__(self): + return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) + def __arrow_ext_class__(self): """Return an extension array class to be used for building or deserializing arrays with this extension type.