diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py
index 8419541687f..1a0ce6c6371 100644
--- a/distributed/protocol/pickle.py
+++ b/distributed/protocol/pickle.py
@@ -7,8 +7,10 @@
 
 if sys.version_info.major == 2:
     import cPickle as pickle
+    from pickle import load as pyload
 else:
     import pickle
+    from pickle import _load as pyload
 
 logger = logging.getLogger(__name__)
 
@@ -27,6 +29,74 @@ def _always_use_pickle_for(x):
         return False
 
 
+class _BytelistFile(object):
+    """File-like object backed by a list of byte chunks.
+
+    write() appends chunks without copying them; read() joins only the
+    chunks that overlap the requested range.
+    """
+
+    def __init__(self, chunks=None):
+        if chunks is None:
+            chunks = []
+        self._chunks = chunks
+        # Start positioned at the end of the data, ready for appends.
+        self._pos = sum(len(c) for c in chunks)
+
+    def __len__(self):
+        return sum(len(c) for c in self._chunks)
+
+    def write(self, chunk):
+        self._chunks.append(chunk)
+
+    def read(self, size=None):
+        return b''.join(self._collect_chunks(size=size))
+
+    def readline(self):
+        raise NotImplementedError
+
+    def _collect_chunks(self, size=None):
+        pos = self._pos
+        remainder = (len(self) - pos) if size is None else size
+        if remainder <= 0:
+            return []
+        collected = []
+        left_to_skip = pos
+        for chunk in self._chunks:
+            if remainder <= 0:
+                break
+            if left_to_skip > 0:
+                if left_to_skip >= len(chunk):
+                    # This chunk lies entirely before the current position.
+                    left_to_skip -= len(chunk)
+                else:
+                    chunk = chunk[left_to_skip:left_to_skip + remainder]
+                    left_to_skip = 0
+                    collected.append(chunk)
+                    remainder -= len(chunk)
+            else:
+                if len(chunk) <= remainder:
+                    collected.append(chunk)
+                    remainder -= len(chunk)
+                else:
+                    chunk = chunk[:remainder]
+                    collected.append(chunk)
+                    remainder = 0
+        self._pos += sum(len(c) for c in collected)
+        return collected
+
+    def tell(self):
+        return self._pos
+
+    def seek(self, pos):
+        if pos < 0:
+            raise ValueError("Negative position %d is invalid." % pos)
+        elif pos > len(self):
+            raise ValueError("Position %d is larger than size %d."
+                             % (pos, len(self)))
+        self._pos = pos
+
+
 def dumps(x):
     """ Manage between cloudpickle and pickle
 
@@ -54,9 +124,43 @@ def dumps(x):
         raise
 
 
+def dump_bytelist(x):
+    """Serialize x to a list of byte chunks using the pickle protocol.
+
+    Note that cloudpickle leverages no-copy semantics via memoryviews on
+    large contiguous data structures such as numpy arrays and derivatives.
+    """
+    # TODO: if the Python 3 pickle dump ever supports no-copy output, try
+    # it first and only fall back to cloudpickle.
+    try:
+        writer = _BytelistFile()
+        cloudpickle.dump(x, writer, protocol=pickle.HIGHEST_PROTOCOL)
+        return writer._chunks
+    except Exception as e:
+        logger.info("Failed to serialize %s. Exception: %s", x, e)
+        raise
+
+
 def loads(x):
     try:
         return pickle.loads(x)
     except Exception:
         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
         raise
+
+
+def load_bytelist(bytelist):
+    try:
+        reader = _BytelistFile(bytelist)
+        reader.seek(0)
+        # Use the Python-based Unpickler to avoid a memory copy when loading
+        # large binary data buffers that back numpy arrays and pandas data
+        # frames.
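+        # (Assumption worth verifying: the C-accelerated Unpickler copies
+        # large buffers out of its internal read buffer, while the pure
+        # Python one returns whatever read() hands back, uncopied.)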
+        return pyload(reader)
+    except Exception:
+        logger.info("Failed to deserialize %s", bytelist[0][:10000],
+                    exc_info=True)
+        raise
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index 1b7b99b9edb..089b7420198 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -17,7 +17,7 @@
 
 
 serializers = {}
-deserializers = {None: lambda header, frames: pickle.loads(b''.join(frames))}
+deserializers = {None: lambda header, frames: pickle.load_bytelist(frames)}
 lazy_registrations = {}
 
 
@@ -132,7 +132,7 @@ def serialize(x):
     else:
         if _find_lazy_registration(name):
             return serialize(x)  # recurse
-        header, frames = {}, [pickle.dumps(x)]
+        header, frames = {}, pickle.dump_bytelist(x)
     return header, frames
 
 
diff --git a/requirements.txt b/requirements.txt
index 79b2ed02d12..0a0796cb1e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 click >= 6.6
-cloudpickle >= 0.2.2
+git+https://github.com/cloudpipe/cloudpickle@nocopy-memoryviews#egg=cloudpickle
 dask >= 0.16.0
 msgpack-python
 psutil
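
Reviewer note: below is a minimal, hypothetical round-trip sketch for the new
helpers. It assumes this branch of distributed is importable and that the
cloudpickle nocopy-memoryviews branch from requirements.txt is installed; the
dictionary key 'payload' is illustrative only.

    import numpy as np

    from distributed.protocol import pickle

    x = {'payload': np.arange(1000000, dtype='float64')}

    # dump_bytelist returns a list of byte chunks rather than one bytestring;
    # with the nocopy cloudpickle branch, large buffers should arrive as
    # chunks of their own instead of being copied into a contiguous stream.
    frames = pickle.dump_bytelist(x)
    assert isinstance(frames, list)

    # load_bytelist feeds the chunks back through _BytelistFile and the
    # pure-Python Unpickler, avoiding the upfront b''.join() over all frames
    # that the previous default deserializer performed.
    y = pickle.load_bytelist(frames)
    assert np.array_equal(x['payload'], y['payload'])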