Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

if sys.version_info.major == 2:
import cPickle as pickle
from pickle import load as pyload
else:
import pickle
from pickle import _load as pyload

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +29,67 @@ def _always_use_pickle_for(x):
return False


class _BytelistFile(object):

def __init__(self, chunks=None):
if chunks is None:
chunks = []
self._chunks = chunks
self._pos = sum(len(c) for c in chunks)

def __len__(self):
return sum(len(c) for c in self._chunks)

def write(self, chunk):
self._chunks.append(chunk)

def read(self, size=None):
return b''.join(self._collect_chunks(size=size))

def readline(self):
raise NotImplementedError

def _collect_chunks(self, size=None):
pos = self._pos
remainder = (len(self) - pos) if size is None else size
if remainder <= 0:
return []
collected = []
left_to_skip = pos
for chunk in self._chunks:
if remainder <= 0:
break
if left_to_skip > 0:
if left_to_skip > len(chunk):
left_to_skip -= len(chunk)
else:
chunk = chunk[left_to_skip:left_to_skip + remainder]
left_to_skip = 0
collected.append(chunk)
remainder -= len(chunk)
else:
if len(chunk) <= remainder:
collected.append(chunk)
remainder -= len(chunk)
else:
chunk = chunk[:remainder]
collected.append(chunk)
remainder = 0
self._pos += sum(len(c) for c in collected)
return collected

def tell(self):
return self._pos

def seek(self, pos):
if pos < 0:
raise ValueError("Negative position %d is invalid." % pos)
elif pos > len(self):
raise ValueError("Position %d is larger than size %d."
% (pos, len(self)))
self._pos = pos


def dumps(x):
""" Manage between cloudpickle and pickle

Expand Down Expand Up @@ -54,9 +117,40 @@ def dumps(x):
raise


def dump_bytelist(x):
    """Serialize *x* with cloudpickle into a list of byte chunks.

    Cloudpickle leverages no-copy semantics (memoryviews) for large
    contiguous data structures such as numpy arrays and derivatives, so
    returning the raw chunk list avoids concatenating them into one
    large bytestring.
    """
    # TODO: if Python 3 dump supports nocopy dump we should try use it first
    # and only fallback to cloudpickle
    try:
        buf = _BytelistFile()
        cloudpickle.dump(x, buf, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        logger.info("Failed to serialize %s. Exception: %s", x, e)
        raise
    return buf._chunks


def loads(x):
    """Deserialize a pickled bytestring, logging a preview on failure."""
    try:
        result = pickle.loads(x)
    except Exception:
        # Log only a bounded prefix so huge payloads do not flood logs.
        logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
        raise
    else:
        return result


def load_bytelist(bytelist):
    """Deserialize an object from a list of pickle-protocol byte chunks.

    Counterpart of ``dump_bytelist``; accepts the chunk list it returns.
    """
    try:
        reader = _BytelistFile(bytelist)
        # _BytelistFile starts positioned at the *end* of pre-supplied
        # chunks (its write position), so rewind before unpickling.
        reader.seek(0)
        # Use the Python-based Unpickler to avoid a memory-copy when loading
        # large binary data buffers that back numpy arrays and pandas data
        # frames.
        return pyload(reader)
    except Exception:
        # Guard against an empty chunk list: bytelist[0] would raise
        # IndexError here and mask the original deserialization error.
        head = bytelist[0][:10000] if bytelist else b''
        logger.info("Failed to deserialize %s", head, exc_info=True)
        raise
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


serializers = {}
deserializers = {None: lambda header, frames: pickle.loads(b''.join(frames))}
deserializers = {None: lambda header, frames: pickle.load_bytelist(frames)}

lazy_registrations = {}

Expand Down Expand Up @@ -132,7 +132,7 @@ def serialize(x):
else:
if _find_lazy_registration(name):
return serialize(x) # recurse
header, frames = {}, [pickle.dumps(x)]
header, frames = {}, pickle.dump_bytelist(x)

return header, frames

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
click >= 6.6
cloudpickle >= 0.2.2
git+https://github.com/cloudpipe/cloudpickle@nocopy-memoryviews#egg=cloudpickle
dask >= 0.16.0
msgpack-python
psutil
Expand Down