From 5cece3efd440c4412354f12ee2e83163cccc9d25 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 23 May 2018 22:26:43 -0400 Subject: [PATCH 1/4] Use ext types with msgpack to move tuples, sets and frozensets nicely --- distributed/protocol/core.py | 45 ++++++++++++++++++-- distributed/protocol/serialize.py | 5 ++- distributed/protocol/tests/test_serialize.py | 9 ++++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 24e980351d0..57519d5af3d 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -130,6 +130,43 @@ def loads(frames, deserialize=True, deserializers=None): raise +_MSGPACK_EXT_TUPLE = 0 +_MSGPACK_EXT_SET = 1 +_MSGPACK_EXT_FROZENSET = 2 + + +def msgpack_default(o): + """ Default handler to allow encoding some other collection types correctly + + """ + if isinstance(o, (tuple, set, frozenset)): + payload = msgpack.packb( + list(o), strict_types=True, use_bin_type=True, default=msgpack_default) + if isinstance(o, tuple): + ext_type = _MSGPACK_EXT_TUPLE + elif isinstance(o, frozenset): + ext_type = _MSGPACK_EXT_FROZENSET + elif isinstance(o, set): + ext_type = _MSGPACK_EXT_SET + else: + raise TypeError("Unknown type %s" % type(o)) + return msgpack.ExtType(ext_type, payload) + else: + raise TypeError("Unknown type %s for %s" % (repr(o), type(o))) + + +def msgpack_ext_hook(code, payload): + if code in {_MSGPACK_EXT_TUPLE, _MSGPACK_EXT_SET, _MSGPACK_EXT_FROZENSET}: + l = msgpack.unpackb(payload, encoding='utf-8', ext_hook=msgpack_ext_hook) + if code == _MSGPACK_EXT_TUPLE: + return tuple(l) + elif code == _MSGPACK_EXT_SET: + return set(l) + elif code == _MSGPACK_EXT_FROZENSET: + return frozenset(l) + raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload)) + + def dumps_msgpack(msg): """ Dump msg into header and payload, both bytestrings @@ -139,14 +176,14 @@ def dumps_msgpack(msg): loads_msgpack """ header = {} - payload = msgpack.dumps(msg, use_bin_type=True) + payload = msgpack.dumps(msg, use_bin_type=True, default=msgpack_default) fmt, payload = maybe_compress(payload) if fmt: header['compression'] = fmt if header: - header_bytes = msgpack.dumps(header, use_bin_type=True) + header_bytes = msgpack.dumps(header, use_bin_type=True, default=msgpack_default) else: header_bytes = b'' @@ -160,7 +197,7 @@ def loads_msgpack(header, payload): dumps_msgpack """ if header: - header = msgpack.loads(header, encoding='utf8') + header = msgpack.loads(header, encoding='utf8', ext_hook=msgpack_ext_hook) else: header = {} @@ -172,4 +209,4 @@ def loads_msgpack(header, payload): raise ValueError("Data is compressed as %s but we don't have this" " installed" % str(header['compression'])) - return msgpack.loads(payload, encoding='utf8') + return msgpack.loads(payload, encoding='utf8', ext_hook=msgpack_ext_hook) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 4e397abee1a..f94e31a552c 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -10,6 +10,7 @@ import msgpack +from . import core from . import pickle from ..compatibility import PY2 from .compression import maybe_compress, decompress @@ -59,11 +60,11 @@ def pickle_loads(header, frames): def msgpack_dumps(x): - return {'serializer': 'msgpack'}, [msgpack.dumps(x, use_bin_type=True)] + return {'serializer': 'msgpack'}, [msgpack.dumps(x, use_bin_type=True, default=core.msgpack_default)] def msgpack_loads(header, frames): - return msgpack.loads(b''.join(frames), encoding='utf8') + return msgpack.loads(b''.join(frames), encoding='utf8', ext_hook=core.msgpack_ext_hook) def serialization_error_loads(header, frames): diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index bc75082c350..a6293494bd1 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -178,6 +178,15 @@ def test_serialize_bytes(): assert str(x) == str(y) +@pytest.mark.parametrize('val', [tuple([1, 2]), set([1,2]), frozenset([1,2])]) +def test_serialize_tuple(val): + from distributed.protocol import loads, dumps + reslist = loads(dumps([to_serialize(val)])) + res = reslist[0] + assert type(res) == type(val) + assert res == val + + def test_serialize_list_compress(): pytest.importorskip('lz4') x = np.ones(1000000) From 2b16a1684127d503080b5fdc7a2daaa1d77941a5 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 23 May 2018 22:34:48 -0400 Subject: [PATCH 2/4] Added a few missing spots --- distributed/protocol/core.py | 4 ++-- distributed/protocol/serialize.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 57519d5af3d..35c13cda1a6 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -81,7 +81,7 @@ def dumps(msg, serializers=None, on_error='message'): out_frames[i] = frame return [small_header, small_payload, - msgpack.dumps(header, use_bin_type=True)] + out_frames + msgpack.dumps(header, use_bin_type=True, default=msgpack_default)] + out_frames except Exception: logger.critical("Failed to Serialize", exc_info=True) raise @@ -100,7 +100,7 @@ def loads(frames, deserialize=True, deserializers=None): return msg header = frames.pop() - header = msgpack.loads(header, encoding='utf8', use_list=False) + header = msgpack.loads(header, encoding='utf8', use_list=False, ext_hook=msgpack_ext_hook) keys = header['keys'] headers = header['headers'] bytestrings = set(header['bytestrings']) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index f94e31a552c..30bbf8f0d03 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -426,7 +426,7 @@ def serialize_bytelist(x, **kwargs): header['compression'] = compression header['count'] = len(frames) - header = msgpack.dumps(header, use_bin_type=True) + header = msgpack.dumps(header, use_bin_type=True, default=core.msgpack_default) frames2 = [header] + list(frames) return [pack_frames_prelude(frames2)] + frames2 @@ -442,7 +442,7 @@ def deserialize_bytes(b): frames = unpack_frames(b) header, frames = frames[0], frames[1:] if header: - header = msgpack.loads(header, encoding='utf8') + header = msgpack.loads(header, encoding='utf8', ext_hook=core.msgpack_ext_hook) else: header = {} frames = decompress(header, frames) From 6576f0e9cd7592f016939fa7685dc981060a6899 Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 23 May 2018 23:13:50 -0400 Subject: [PATCH 3/4] Refactor a bit to move msgpack to a wrapper similar to pickle --- distributed/protocol/core.py | 53 +++++-------------------------- distributed/protocol/msgpack.py | 45 ++++++++++++++++++++++++++ distributed/protocol/serialize.py | 13 ++++---- 3 files changed, 59 insertions(+), 52 deletions(-) create mode 100644 distributed/protocol/msgpack.py diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 35c13cda1a6..9947d5c8869 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -2,8 +2,6 @@ import logging -import msgpack - try: from cytoolz import get_in except ImportError: @@ -15,6 +13,8 @@ from .utils import frame_split_size, merge_frames from ..utils import nbytes +from . import msgpack + _deserialize = deserialize @@ -81,7 +81,7 @@ def dumps(msg, serializers=None, on_error='message'): out_frames[i] = frame return [small_header, small_payload, - msgpack.dumps(header, use_bin_type=True, default=msgpack_default)] + out_frames + msgpack.dumps(header, use_bin_type=True)] + out_frames except Exception: logger.critical("Failed to Serialize", exc_info=True) raise @@ -100,7 +100,7 @@ def loads(frames, deserialize=True, deserializers=None): return msg header = frames.pop() - header = msgpack.loads(header, encoding='utf8', use_list=False, ext_hook=msgpack_ext_hook) + header = msgpack.loads(header, encoding='utf8', use_list=False) keys = header['keys'] headers = header['headers'] bytestrings = set(header['bytestrings']) @@ -130,43 +130,6 @@ def loads(frames, deserialize=True, deserializers=None): raise -_MSGPACK_EXT_TUPLE = 0 -_MSGPACK_EXT_SET = 1 -_MSGPACK_EXT_FROZENSET = 2 - - -def msgpack_default(o): - """ Default handler to allow encoding some other collection types correctly - - """ - if isinstance(o, (tuple, set, frozenset)): - payload = msgpack.packb( - list(o), strict_types=True, use_bin_type=True, default=msgpack_default) - if isinstance(o, tuple): - ext_type = _MSGPACK_EXT_TUPLE - elif isinstance(o, frozenset): - ext_type = _MSGPACK_EXT_FROZENSET - elif isinstance(o, set): - ext_type = _MSGPACK_EXT_SET - else: - raise TypeError("Unknown type %s" % type(o)) - return msgpack.ExtType(ext_type, payload) - else: - raise TypeError("Unknown type %s for %s" % (repr(o), type(o))) - - -def msgpack_ext_hook(code, payload): - if code in {_MSGPACK_EXT_TUPLE, _MSGPACK_EXT_SET, _MSGPACK_EXT_FROZENSET}: - l = msgpack.unpackb(payload, encoding='utf-8', ext_hook=msgpack_ext_hook) - if code == _MSGPACK_EXT_TUPLE: - return tuple(l) - elif code == _MSGPACK_EXT_SET: - return set(l) - elif code == _MSGPACK_EXT_FROZENSET: - return frozenset(l) - raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload)) - - def dumps_msgpack(msg): """ Dump msg into header and payload, both bytestrings @@ -176,14 +139,14 @@ def dumps_msgpack(msg): loads_msgpack """ header = {} - payload = msgpack.dumps(msg, use_bin_type=True, default=msgpack_default) + payload = msgpack.dumps(msg, use_bin_type=True) fmt, payload = maybe_compress(payload) if fmt: header['compression'] = fmt if header: - header_bytes = msgpack.dumps(header, use_bin_type=True, default=msgpack_default) + header_bytes = msgpack.dumps(header, use_bin_type=True) else: header_bytes = b'' @@ -197,7 +160,7 @@ def loads_msgpack(header, payload): dumps_msgpack """ if header: - header = msgpack.loads(header, encoding='utf8', ext_hook=msgpack_ext_hook) + header = msgpack.loads(header, encoding='utf8') else: header = {} @@ -209,4 +172,4 @@ def loads_msgpack(header, payload): raise ValueError("Data is compressed as %s but we don't have this" " installed" % str(header['compression'])) - return msgpack.loads(payload, encoding='utf8', ext_hook=msgpack_ext_hook) + return msgpack.loads(payload, encoding='utf8') diff --git a/distributed/protocol/msgpack.py b/distributed/protocol/msgpack.py new file mode 100644 index 00000000000..804f3a581bd --- /dev/null +++ b/distributed/protocol/msgpack.py @@ -0,0 +1,45 @@ +from __future__ import print_function, division, absolute_import + +from functools import partial + +import msgpack + +_MSGPACK_EXT_TUPLE = 0 +_MSGPACK_EXT_SET = 1 +_MSGPACK_EXT_FROZENSET = 2 + + +def _msgpack_default(o): + """ Default handler to allow encoding some other collection types correctly + + """ + if isinstance(o, (tuple, set, frozenset)): + payload = msgpack.packb( + list(o), strict_types=True, use_bin_type=True, default=_msgpack_default) + if isinstance(o, tuple): + ext_type = _MSGPACK_EXT_TUPLE + elif isinstance(o, frozenset): + ext_type = _MSGPACK_EXT_FROZENSET + elif isinstance(o, set): + ext_type = _MSGPACK_EXT_SET + else: + raise TypeError("Unknown type %s" % type(o)) + return msgpack.ExtType(ext_type, payload) + else: + raise TypeError("Unknown type %s for %s" % (repr(o), type(o))) + + +def _msgpack_ext_hook(code, payload): + if code in {_MSGPACK_EXT_TUPLE, _MSGPACK_EXT_SET, _MSGPACK_EXT_FROZENSET}: + l = msgpack.unpackb(payload, encoding='utf-8', ext_hook=_msgpack_ext_hook) + if code == _MSGPACK_EXT_TUPLE: + return tuple(l) + elif code == _MSGPACK_EXT_SET: + return set(l) + elif code == _MSGPACK_EXT_FROZENSET: + return frozenset(l) + raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload)) + + +dumps = partial(msgpack.dumps, default=_msgpack_default) +loads = partial(msgpack.loads, ext_hook=_msgpack_ext_hook) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 30bbf8f0d03..9ce3207ddc1 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -3,14 +3,13 @@ import traceback from dask.base import normalize_token + try: from cytoolz import valmap, get_in except ImportError: from toolz import valmap, get_in -import msgpack - -from . import core +from . import msgpack from . import pickle from ..compatibility import PY2 from .compression import maybe_compress, decompress @@ -60,11 +59,11 @@ def pickle_loads(header, frames): def msgpack_dumps(x): - return {'serializer': 'msgpack'}, [msgpack.dumps(x, use_bin_type=True, default=core.msgpack_default)] + return {'serializer': 'msgpack'}, [msgpack.dumps(x, use_bin_type=True)] def msgpack_loads(header, frames): - return msgpack.loads(b''.join(frames), encoding='utf8', ext_hook=core.msgpack_ext_hook) + return msgpack.loads(b''.join(frames), encoding='utf8') def serialization_error_loads(header, frames): @@ -426,7 +425,7 @@ def serialize_bytelist(x, **kwargs): header['compression'] = compression header['count'] = len(frames) - header = msgpack.dumps(header, use_bin_type=True, default=core.msgpack_default) + header = msgpack.dumps(header, use_bin_type=True) frames2 = [header] + list(frames) return [pack_frames_prelude(frames2)] + frames2 @@ -442,7 +441,7 @@ def deserialize_bytes(b): frames = unpack_frames(b) header, frames = frames[0], frames[1:] if header: - header = msgpack.loads(header, encoding='utf8', ext_hook=core.msgpack_ext_hook) + header = msgpack.loads(header, encoding='utf8') else: header = {} frames = decompress(header, frames) From 934832e239d408e2ba51c5de98529b152d1a181c Mon Sep 17 00:00:00 2001 From: Marius van Niekerk Date: Wed, 23 May 2018 23:47:11 -0400 Subject: [PATCH 4/4] Actually test msgpack --- distributed/protocol/msgpack.py | 3 ++- distributed/protocol/tests/test_serialize.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/protocol/msgpack.py b/distributed/protocol/msgpack.py index 804f3a581bd..b5ec47b3645 100644 --- a/distributed/protocol/msgpack.py +++ b/distributed/protocol/msgpack.py @@ -41,5 +41,6 @@ def _msgpack_ext_hook(code, payload): raise ValueError("Unknown Ext code %s, payload: %s" % (code, payload)) -dumps = partial(msgpack.dumps, default=_msgpack_default) +# msgpack's non-strict mode automatically treats tuple as a list which bypasses our default +dumps = partial(msgpack.dumps, default=_msgpack_default, strict_types=True) loads = partial(msgpack.loads, ext_hook=_msgpack_ext_hook) diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index a6293494bd1..f1f186f2137 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -179,10 +179,9 @@ def test_serialize_bytes(): @pytest.mark.parametrize('val', [tuple([1, 2]), set([1,2]), frozenset([1,2])]) -def test_serialize_tuple(val): - from distributed.protocol import loads, dumps - reslist = loads(dumps([to_serialize(val)])) - res = reslist[0] +def test_serialize_msgpack(val): + from distributed.protocol.msgpack import loads, dumps + res = loads(dumps(val)) assert type(res) == type(val) assert res == val