diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index bd31b21c196..a245fe67960 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -125,6 +125,7 @@ localfs = LocalFileSystem.get_instance() from pyarrow.serialization import (_default_serialization_context, + pandas_serialization_context, register_default_serialization_handlers) import pyarrow.types as types diff --git a/python/pyarrow/serialization.pxi b/python/pyarrow/serialization.pxi index bb266b2f928..faf164b3ebd 100644 --- a/python/pyarrow/serialization.pxi +++ b/python/pyarrow/serialization.pxi @@ -57,6 +57,22 @@ cdef class SerializationContext: self.custom_serializers = dict() self.custom_deserializers = dict() + def clone(self): + """ + Return copy of this SerializationContext + + Returns + ------- + clone : SerializationContext + """ + result = SerializationContext() + result.type_to_type_id = self.type_to_type_id.copy() + result.whitelisted_types = self.whitelisted_types.copy() + result.custom_serializers = self.custom_serializers.copy() + result.custom_deserializers = self.custom_deserializers.copy() + + return result + def register_type(self, type_, type_id, custom_serializer=None, custom_deserializer=None): """EXPERIMENTAL: Add type to the list of types we can serialize. diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py index ab25b63d571..08e6cce751b 100644 --- a/python/pyarrow/serialization.py +++ b/python/pyarrow/serialization.py @@ -22,7 +22,7 @@ import numpy as np from pyarrow import serialize_pandas, deserialize_pandas -from pyarrow.lib import _default_serialization_context +from pyarrow.lib import _default_serialization_context, frombuffer try: import cloudpickle @@ -30,6 +30,28 @@ cloudpickle = pickle +# ---------------------------------------------------------------------- +# Set up serialization for numpy with dtype object (primitive types are +# handled efficiently with Arrow's Tensor facilities, see +# python_to_arrow.cc) + +def _serialize_numpy_array_list(obj): + return obj.tolist(), obj.dtype.str + + +def _deserialize_numpy_array_list(data): + return np.array(data[0], dtype=np.dtype(data[1])) + + +def _serialize_numpy_array_pickle(obj): + pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + return frombuffer(pickled) + + +def _deserialize_numpy_array_pickle(data): + return pickle.loads(memoryview(data)) + + def register_default_serialization_handlers(serialization_context): # ---------------------------------------------------------------------- @@ -80,21 +102,10 @@ def _deserialize_default_dict(data): custom_serializer=cloudpickle.dumps, custom_deserializer=cloudpickle.loads) - # ---------------------------------------------------------------------- - # Set up serialization for numpy with dtype object (primitive types are - # handled efficiently with Arrow's Tensor facilities, see - # python_to_arrow.cc) - - def _serialize_numpy_array(obj): - return obj.tolist(), obj.dtype.str - - def _deserialize_numpy_array(data): - return np.array(data[0], dtype=np.dtype(data[1])) - serialization_context.register_type( np.ndarray, 'np.array', - custom_serializer=_serialize_numpy_array, - custom_deserializer=_deserialize_numpy_array) + custom_serializer=_serialize_numpy_array_list, + custom_deserializer=_deserialize_numpy_array_list) # ---------------------------------------------------------------------- # Set up serialization for pandas Series and DataFrame @@ -153,3 +164,10 @@ def _deserialize_torch_tensor(data): register_default_serialization_handlers(_default_serialization_context) + +pandas_serialization_context = _default_serialization_context.clone() + +pandas_serialization_context.register_type( + np.ndarray, 'np.array', + custom_serializer=_serialize_numpy_array_pickle, + custom_deserializer=_deserialize_numpy_array_pickle) diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py index d06beeac992..6d85621d411 100644 --- a/python/pyarrow/tests/test_serialization.py +++ b/python/pyarrow/tests/test_serialization.py @@ -212,11 +212,11 @@ def make_serialization_context(): serialization_context = make_serialization_context() -def serialization_roundtrip(value, f): +def serialization_roundtrip(value, f, ctx=serialization_context): f.seek(0) - pa.serialize_to(value, f, serialization_context) + pa.serialize_to(value, f, ctx) f.seek(0) - result = pa.deserialize_from(f, None, serialization_context) + result = pa.deserialize_from(f, None, ctx) assert_equal(value, result) _check_component_roundtrip(value) @@ -249,6 +249,7 @@ def test_primitive_serialization(large_memory_map): with pa.memory_map(large_memory_map, mode="r+") as mmap: for obj in PRIMITIVE_OBJECTS: serialization_roundtrip(obj, mmap) + serialization_roundtrip(obj, mmap, pa.pandas_serialization_context) def test_serialize_to_buffer():