diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 24e980351d0..9947d5c8869 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -2,8 +2,6 @@ import logging -import msgpack - try: from cytoolz import get_in except ImportError: @@ -15,6 +13,8 @@ from .utils import frame_split_size, merge_frames from ..utils import nbytes +from . import msgpack + _deserialize = deserialize diff --git a/distributed/protocol/msgpack.py b/distributed/protocol/msgpack.py new file mode 100644 index 00000000000..b5ec47b3645 --- /dev/null +++ b/distributed/protocol/msgpack.py @@ -0,0 +1,46 @@ +from __future__ import print_function, division, absolute_import + +from functools import partial + +import msgpack + +_MSGPACK_EXT_TUPLE = 0 +_MSGPACK_EXT_SET = 1 +_MSGPACK_EXT_FROZENSET = 2 + + +def _msgpack_default(o): + """ Default handler to allow encoding some other collection types correctly + + """ + if isinstance(o, (tuple, set, frozenset)): + payload = msgpack.packb( + list(o), strict_types=True, use_bin_type=True, default=_msgpack_default) + if isinstance(o, tuple): + ext_type = _MSGPACK_EXT_TUPLE + elif isinstance(o, frozenset): + ext_type = _MSGPACK_EXT_FROZENSET + elif isinstance(o, set): + ext_type = _MSGPACK_EXT_SET + else: + raise TypeError("Unknown type %s" % type(o)) + return msgpack.ExtType(ext_type, payload) + else: + raise TypeError("Unknown type %s for %s" % (repr(o), type(o))) + + +def _msgpack_ext_hook(code, payload): + if code in {_MSGPACK_EXT_TUPLE, _MSGPACK_EXT_SET, _MSGPACK_EXT_FROZENSET}: + l = msgpack.unpackb(payload, encoding='utf-8', ext_hook=_msgpack_ext_hook) + if code == _MSGPACK_EXT_TUPLE: + return tuple(l) + elif code == _MSGPACK_EXT_SET: + return set(l) + elif code == _MSGPACK_EXT_FROZENSET: + return frozenset(l) + raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload)) + + +# msgpack's non-strict mode automatically treats tuple as a list which bypasses our default +dumps = partial(msgpack.dumps, default=_msgpack_default, strict_types=True) +loads = partial(msgpack.loads, ext_hook=_msgpack_ext_hook) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 4e397abee1a..9ce3207ddc1 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -3,13 +3,13 @@ import traceback from dask.base import normalize_token + try: from cytoolz import valmap, get_in except ImportError: from toolz import valmap, get_in -import msgpack - +from . import msgpack from . import pickle from ..compatibility import PY2 from .compression import maybe_compress, decompress diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index bc75082c350..f1f186f2137 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -178,6 +178,14 @@ def test_serialize_bytes(): assert str(x) == str(y) +@pytest.mark.parametrize('val', [tuple([1, 2]), set([1,2]), frozenset([1,2])]) +def test_serialize_msgpack(val): + from distributed.protocol.msgpack import loads, dumps + res = loads(dumps(val)) + assert type(res) == type(val) + assert res == val + + def test_serialize_list_compress(): pytest.importorskip('lz4') x = np.ones(1000000)