Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import logging

import msgpack

try:
from cytoolz import get_in
except ImportError:
Expand All @@ -15,6 +13,8 @@
from .utils import frame_split_size, merge_frames
from ..utils import nbytes

from . import msgpack

_deserialize = deserialize


Expand Down
46 changes: 46 additions & 0 deletions distributed/protocol/msgpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import print_function, division, absolute_import

from functools import partial

import msgpack

# msgpack extension type codes. Each one tags an ExtType payload with the kind
# of collection (tuple, set or frozenset) that was packed, so the decoding
# hook can rebuild the same kind of collection.
_MSGPACK_EXT_TUPLE = 0
_MSGPACK_EXT_SET = 1
_MSGPACK_EXT_FROZENSET = 2


def _msgpack_default(o):
""" Default handler to allow encoding some other collection types correctly

"""
if isinstance(o, (tuple, set, frozenset)):
payload = msgpack.packb(
list(o), strict_types=True, use_bin_type=True, default=_msgpack_default)
if isinstance(o, tuple):
ext_type = _MSGPACK_EXT_TUPLE
elif isinstance(o, frozenset):
ext_type = _MSGPACK_EXT_FROZENSET
elif isinstance(o, set):
ext_type = _MSGPACK_EXT_SET
else:
raise TypeError("Unknown type %s" % type(o))
return msgpack.ExtType(ext_type, payload)
else:
raise TypeError("Unknown type %s for %s" % (repr(o), type(o)))


def _msgpack_ext_hook(code, payload):
    """ Rebuild a tuple, set or frozenset from a msgpack extension record

    Parameters
    ----------
    code : int
        Extension type code read from the stream.
    payload : bytes
        Extension data; for the known codes this is a msgpack-packed list of
        the collection's items.

    Returns
    -------
    tuple, set or frozenset
        Collection of the kind selected by ``code``.

    Raises
    ------
    ValueError
        If ``code`` is not one of the extension codes defined in this module.
    """
    if code == _MSGPACK_EXT_TUPLE:
        factory = tuple
    elif code == _MSGPACK_EXT_SET:
        factory = set
    elif code == _MSGPACK_EXT_FROZENSET:
        factory = frozenset
    else:
        raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload))

    # raw=False decodes str items as UTF-8. It replaces the ``encoding``
    # keyword, which msgpack deprecated and later removed. Passing this hook
    # again lets nested collections be rebuilt as well.
    items = msgpack.unpackb(payload, raw=False, ext_hook=_msgpack_ext_hook)
    return factory(items)


# msgpack's non-strict mode automatically treats a tuple as a list, which
# bypasses our default hook; strict_types=True makes tuples reach it.
dumps = partial(msgpack.dumps, default=_msgpack_default, strict_types=True)
# Decoding counterpart: extension records are passed to _msgpack_ext_hook.
loads = partial(msgpack.loads, ext_hook=_msgpack_ext_hook)
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import traceback

from dask.base import normalize_token

try:
from cytoolz import valmap, get_in
except ImportError:
from toolz import valmap, get_in

import msgpack

from . import msgpack
from . import pickle
from ..compatibility import PY2
from .compression import maybe_compress, decompress
Expand Down
8 changes: 8 additions & 0 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def test_serialize_bytes():
assert str(x) == str(y)


@pytest.mark.parametrize('val', [(1, 2), {1, 2}, frozenset({1, 2})])
def test_serialize_msgpack(val):
    """ A tuple, set and frozenset each keep their exact type and value
    after a dumps/loads round trip.
    """
    from distributed.protocol.msgpack import dumps, loads
    roundtripped = loads(dumps(val))
    assert type(roundtripped) is type(val)
    assert roundtripped == val


def test_serialize_list_compress():
pytest.importorskip('lz4')
x = np.ones(1000000)
Expand Down