Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
localfs = LocalFileSystem.get_instance()

from pyarrow.serialization import (_default_serialization_context,
pandas_serialization_context,
register_default_serialization_handlers)

import pyarrow.types as types
Expand Down
16 changes: 16 additions & 0 deletions python/pyarrow/serialization.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ cdef class SerializationContext:
self.custom_serializers = dict()
self.custom_deserializers = dict()

def clone(self):
    """
    Return an independent copy of this SerializationContext.

    The registration tables are shallow-copied, so types registered on
    the clone afterwards do not affect the original context.

    Returns
    -------
    clone : SerializationContext
    """
    fresh = SerializationContext()
    fresh.type_to_type_id = dict(self.type_to_type_id)
    fresh.whitelisted_types = dict(self.whitelisted_types)
    fresh.custom_serializers = dict(self.custom_serializers)
    fresh.custom_deserializers = dict(self.custom_deserializers)

    return fresh

def register_type(self, type_, type_id,
custom_serializer=None, custom_deserializer=None):
"""EXPERIMENTAL: Add type to the list of types we can serialize.
Expand Down
46 changes: 32 additions & 14 deletions python/pyarrow/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,36 @@
import numpy as np

from pyarrow import serialize_pandas, deserialize_pandas
from pyarrow.lib import _default_serialization_context
from pyarrow.lib import _default_serialization_context, frombuffer

try:
import cloudpickle
except ImportError:
cloudpickle = pickle


# ----------------------------------------------------------------------
# Set up serialization for numpy with dtype object (primitive types are
# handled efficiently with Arrow's Tensor facilities, see
# python_to_arrow.cc)

def _serialize_numpy_array_list(obj):
return obj.tolist(), obj.dtype.str


def _deserialize_numpy_array_list(data):
return np.array(data[0], dtype=np.dtype(data[1]))


def _serialize_numpy_array_pickle(obj):
    # Pickle the ndarray at the highest protocol and hand the bytes to
    # Arrow as a buffer.
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return frombuffer(payload)


def _deserialize_numpy_array_pickle(data):
return pickle.loads(memoryview(data))


def register_default_serialization_handlers(serialization_context):

# ----------------------------------------------------------------------
Expand Down Expand Up @@ -80,21 +102,10 @@ def _deserialize_default_dict(data):
custom_serializer=cloudpickle.dumps,
custom_deserializer=cloudpickle.loads)

# ----------------------------------------------------------------------
# Set up serialization for numpy with dtype object (primitive types are
# handled efficiently with Arrow's Tensor facilities, see
# python_to_arrow.cc)

def _serialize_numpy_array(obj):
return obj.tolist(), obj.dtype.str

def _deserialize_numpy_array(data):
return np.array(data[0], dtype=np.dtype(data[1]))

serialization_context.register_type(
np.ndarray, 'np.array',
custom_serializer=_serialize_numpy_array,
custom_deserializer=_deserialize_numpy_array)
custom_serializer=_serialize_numpy_array_list,
custom_deserializer=_deserialize_numpy_array_list)

# ----------------------------------------------------------------------
# Set up serialization for pandas Series and DataFrame
Expand Down Expand Up @@ -153,3 +164,10 @@ def _deserialize_torch_tensor(data):


# Populate the shared default context with all built-in handlers.
register_default_serialization_handlers(_default_serialization_context)

# A specialized context that starts as a copy of the default one but
# overrides the list-based ndarray handler with the pickle-based one
# (presumably faster for pandas-style workloads — see the pickle
# helpers above).
pandas_serialization_context = _default_serialization_context.clone()

pandas_serialization_context.register_type(
np.ndarray, 'np.array',
custom_serializer=_serialize_numpy_array_pickle,
custom_deserializer=_deserialize_numpy_array_pickle)
7 changes: 4 additions & 3 deletions python/pyarrow/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ def make_serialization_context():
serialization_context = make_serialization_context()


def serialization_roundtrip(value, f):
def serialization_roundtrip(value, f, ctx=serialization_context):
    """Serialize ``value`` into ``f`` with ``ctx``, read it back, and
    assert the result equals the input; also check the component-level
    round trip."""
    f.seek(0)
    pa.serialize_to(value, f, ctx)
    f.seek(0)
    deserialized = pa.deserialize_from(f, None, ctx)
    assert_equal(value, deserialized)

    _check_component_roundtrip(value)
Expand Down Expand Up @@ -249,6 +249,7 @@ def test_primitive_serialization(large_memory_map):
with pa.memory_map(large_memory_map, mode="r+") as mmap:
for obj in PRIMITIVE_OBJECTS:
serialization_roundtrip(obj, mmap)
serialization_roundtrip(obj, mmap, pa.pandas_serialization_context)


def test_serialize_to_buffer():
Expand Down