Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions python/pyarrow/serialization.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,17 @@ cdef class SerializationContext:
cdef:
object type_to_type_id
object whitelisted_types
object types_to_pickle
object custom_serializers
object custom_deserializers

def __init__(self):
# Types with special serialization handlers
self.type_to_type_id = dict()
self.whitelisted_types = dict()
self.types_to_pickle = set()
self.custom_serializers = dict()
self.custom_deserializers = dict()

def register_type(self, type_, type_id, pickle=False,
def register_type(self, type_, type_id,
custom_serializer=None, custom_deserializer=None):
"""EXPERIMENTAL: Add type to the list of types we can serialize.

Expand All @@ -69,9 +67,6 @@ cdef class SerializationContext:
The type that we can serialize.
type_id : bytes
A string of bytes used to identify the type.
pickle : bool
True if the serialization should be done with pickle.
False if it should be done efficiently with Arrow.
custom_serializer : callable
This argument is optional, but can be provided to
serialize objects of the class in a particular way.
Expand All @@ -81,8 +76,6 @@ cdef class SerializationContext:
"""
self.type_to_type_id[type_] = type_id
self.whitelisted_types[type_id] = type_
if pickle:
self.types_to_pickle.add(type_id)
if custom_serializer is not None:
self.custom_serializers[type_id] = custom_serializer
self.custom_deserializers[type_id] = custom_deserializer
Expand All @@ -102,9 +95,7 @@ cdef class SerializationContext:

# use the closest match to type(obj)
type_id = self.type_to_type_id[type_]
if type_id in self.types_to_pickle:
serialized_obj = {"data": pickle.dumps(obj), "pickle": True}
elif type_id in self.custom_serializers:
if type_id in self.custom_serializers:
serialized_obj = {"data": self.custom_serializers[type_id](obj)}
else:
if is_named_tuple(type_):
Expand All @@ -125,7 +116,6 @@ cdef class SerializationContext:
# The object was pickled, so unpickle it.
obj = pickle.loads(serialized_obj["data"])
else:
assert type_id not in self.types_to_pickle
if type_id not in self.whitelisted_types:
msg = "Type ID " + str(type_id) + " not registered in " \
"deserialization callback"
Expand Down
13 changes: 11 additions & 2 deletions python/pyarrow/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

from collections import OrderedDict, defaultdict
import sys
import pickle

import numpy as np

from pyarrow import serialize_pandas, deserialize_pandas
from pyarrow.lib import _default_serialization_context

try:
import cloudpickle
except ImportError:
cloudpickle = pickle


def register_default_serialization_handlers(serialization_context):

Expand Down Expand Up @@ -67,9 +73,12 @@ def _deserialize_default_dict(data):

serialization_context.register_type(
type(lambda: 0), "function",
pickle=True)
custom_serializer=cloudpickle.dumps,
custom_deserializer=cloudpickle.loads)

serialization_context.register_type(type, "type", pickle=True)
serialization_context.register_type(type, "type",
custom_serializer=cloudpickle.dumps,
custom_deserializer=cloudpickle.loads)

# ----------------------------------------------------------------------
# Set up serialization for numpy with dtype object (primitive types are
Expand Down
9 changes: 6 additions & 3 deletions python/pyarrow/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import datetime
import string
import sys
import pickle

import pyarrow as pa
import numpy as np
Expand Down Expand Up @@ -197,7 +198,9 @@ def make_serialization_context():
context.register_type(Baz, "Baz")
context.register_type(Qux, "Quz")
context.register_type(SubQux, "SubQux")
context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True)
context.register_type(SubQuxPickle, "SubQuxPickle",
custom_serializer=pickle.dumps,
custom_deserializer=pickle.loads)
context.register_type(Exception, "Exception")
context.register_type(CustomError, "CustomError")
context.register_type(Point, "Point")
Expand Down Expand Up @@ -338,7 +341,7 @@ def deserialize_dummy_class(serialized_obj):
return serialized_obj

pa._default_serialization_context.register_type(
DummyClass, "DummyClass", pickle=False,
DummyClass, "DummyClass",
custom_serializer=serialize_dummy_class,
custom_deserializer=deserialize_dummy_class)

Expand All @@ -357,7 +360,7 @@ def deserialize_buffer_class(serialized_obj):
return serialized_obj

pa._default_serialization_context.register_type(
BufferClass, "BufferClass", pickle=False,
BufferClass, "BufferClass",
custom_serializer=serialize_buffer_class,
custom_deserializer=deserialize_buffer_class)

Expand Down