diff --git a/distributed/client.py b/distributed/client.py
index 8be4f945fb9..87255dd3499 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -960,8 +960,8 @@ def _handle_report(self):
                             continue
                         else:
                             break
-                    if not isinstance(msgs, list):
-                        msgs = [msgs]
+                    if not isinstance(msgs, (list, tuple)):
+                        msgs = (msgs,)
 
                     breakout = False
                     for msg in msgs:
@@ -2665,7 +2665,7 @@ def ncores(self, workers=None, **kwargs):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.ncores, workers=workers, **kwargs)
 
@@ -2731,7 +2731,7 @@ def has_what(self, workers=None, **kwargs):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.has_what, workers=workers, **kwargs)
 
@@ -2760,7 +2760,7 @@ def processing(self, workers=None):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.processing, workers=workers)
 
@@ -2910,8 +2910,8 @@ def get_metadata(self, keys, default=no_default):
         --------
         Client.set_metadata
         """
-        if not isinstance(keys, list):
-            keys = [keys]
+        if not isinstance(keys, (list, tuple)):
+            keys = (keys,)
         return self.sync(self.scheduler.get_metadata, keys=keys,
                          default=default)
 
@@ -3012,7 +3012,7 @@ def set_metadata(self, key, value):
         get_metadata
         """
         if not isinstance(key, list):
-            key = [key]
+            key = (key,)
         return self.sync(self.scheduler.set_metadata, keys=key, value=value)
 
     def get_versions(self, check=False):
@@ -3040,7 +3040,8 @@ def get_versions(self, check=False):
         if check:
             # we care about the required & optional packages matching
             def to_packages(d):
-                return dict(sum(d['packages'].values(), []))
+                L = list(d['packages'].values())
+                return dict(sum(L, type(L[0])()))
             client_versions = to_packages(result['client'])
             versions = [('scheduler', to_packages(result['scheduler']))]
             versions.extend((w, to_packages(d))
diff --git a/distributed/core.py b/distributed/core.py
index c00ecf088b9..b152a58cf6d 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -359,8 +359,8 @@ def handle_stream(self, comm, extra=None, every_cycle=[]):
         try:
             while not closed:
                 msgs = yield comm.read()
-                if not isinstance(msgs, list):
-                    msgs = [msgs]
+                if not isinstance(msgs, (tuple, list)):
+                    msgs = (msgs,)
 
                 if not comm.closed():
                     for msg in msgs:
diff --git a/distributed/diagnostics/tests/test_eventstream.py b/distributed/diagnostics/tests/test_eventstream.py
index ee7f343d421..7504a59846b 100644
--- a/distributed/diagnostics/tests/test_eventstream.py
+++ b/distributed/diagnostics/tests/test_eventstream.py
@@ -57,7 +57,7 @@ def test_eventstream_remote(c, s, a, b):
     total = []
     while len(total) < 10:
         msgs = yield comm.read()
-        assert isinstance(msgs, list)
+        assert isinstance(msgs, tuple)
         total.extend(msgs)
         assert time() < start + 5
 
diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index 24e980351d0..9209aa06184 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -1,13 +1,14 @@
 from __future__ import print_function, division, absolute_import
 
 import logging
+import operator
 
 import msgpack
 
 try:
-    from cytoolz import get_in
+    from cytoolz import reduce
 except ImportError:
-    from toolz import get_in
+    from toolz import reduce
 
 from .compression import compressions, maybe_compress, decompress
 from .serialize import (serialize, deserialize, Serialize, Serialized,
@@ -122,7 +123,19 @@ def loads(frames, deserialize=True, deserializers=None):
             else:
                 value = Serialized(head, fs)
 
-            get_in(key[:-1], msg)[key[-1]] = value
+            def put_in(keys, coll, val):
+                """Inverse of get_in, but does type promotion in the case of lists"""
+                if keys:
+                    holder = reduce(operator.getitem, keys[:-1], coll)
+                    if isinstance(holder, tuple):
+                        holder = list(holder)
+                        coll = put_in(keys[:-1], coll, holder)
+                    holder[keys[-1]] = val
+                else:
+                    coll = val
+                return coll
+
+            msg = put_in(key, msg, value)
 
         return msg
     except Exception:
@@ -160,7 +173,7 @@ def loads_msgpack(header, payload):
     dumps_msgpack
     """
     if header:
-        header = msgpack.loads(header, encoding='utf8')
+        header = msgpack.loads(header, encoding='utf8', use_list=False)
     else:
         header = {}
 
@@ -172,4 +185,4 @@
         raise ValueError("Data is compressed as %s but we don't have this"
                          " installed" % str(header['compression']))
 
-    return msgpack.loads(payload, encoding='utf8')
+    return msgpack.loads(payload, encoding='utf8', use_list=False)
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index 4e397abee1a..27cb33abe28 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -63,7 +63,7 @@ def msgpack_dumps(x):
 
 
 def msgpack_loads(header, frames):
-    return msgpack.loads(b''.join(frames), encoding='utf8')
+    return msgpack.loads(b''.join(frames), encoding='utf8', use_list=False)
 
 
 def serialization_error_loads(header, frames):
@@ -441,7 +441,7 @@ def deserialize_bytes(b):
     frames = unpack_frames(b)
     header, frames = frames[0], frames[1:]
     if header:
-        header = msgpack.loads(header, encoding='utf8')
+        header = msgpack.loads(header, encoding='utf8', use_list=False)
     else:
         header = {}
     frames = decompress(header, frames)
diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py
index 0906d7687d3..509640e47fa 100644
--- a/distributed/protocol/tests/test_protocol.py
+++ b/distributed/protocol/tests/test_protocol.py
@@ -56,7 +56,7 @@ def test_small():
 
 
 def test_small_and_big():
-    d = {'x': [1, 2, 3], 'y': b'0' * 10000000}
+    d = {'x': (1, 2, 3), 'y': b'0' * 10000000}
     L = dumps(d)
     assert loads(L) == d
     # assert loads([small_header, small]) == {'x': [1, 2, 3]}
diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py
index bc75082c350..a1cedf5f2f7 100644
--- a/distributed/protocol/tests/test_serialize.py
+++ b/distributed/protocol/tests/test_serialize.py
@@ -170,6 +170,13 @@ def test_empty_loads():
     assert isinstance(e2[0], Empty)
 
 
+def test_empty_loads_deep():
+    from distributed.protocol import loads, dumps
+    e = Empty()
+    e2 = loads(dumps([[[to_serialize(e)]]]))
+    assert isinstance(e2[0][0][0], Empty)
+
+
 def test_serialize_bytes():
     for x in [1, 'abc', np.arange(5)]:
         b = serialize_bytes(x)
diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py
index eac2bbfbeff..1fcf026b68b 100644
--- a/distributed/tests/test_batched.py
+++ b/distributed/tests/test_batched.py
@@ -69,9 +69,9 @@ def test_BatchedSend():
     b.send('HELLO')
 
     result = yield comm.read()
-    assert result == ['hello', 'hello', 'world']
+    assert result == ('hello', 'hello', 'world')
     result = yield comm.read()
-    assert result == ['HELLO', 'HELLO']
+    assert result == ('HELLO', 'HELLO')
 
     assert b.byte_count > 1
 
@@ -88,7 +88,7 @@ def test_send_before_start():
 
     b.start(comm)
     result = yield comm.read()
-    assert result == ['hello', 'world']
+    assert result == ('hello', 'world')
 
 
 @gen_test()
@@ -104,7 +104,7 @@ def test_send_after_stream_start():
     result = yield comm.read()
     if len(result) < 2:
         result += yield comm.read()
-    assert result == ['hello', 'world']
+    assert result == ('hello', 'world')
 
 
 @gen_test()
@@ -295,7 +295,7 @@ def test_serializers():
     assert 'function' in value
 
     msg = yield comm.read()
-    assert msg == [{'x': 123}, {'x': 'hello'}]
+    assert list(msg) == [{'x': 123}, {'x': 'hello'}]
 
     with pytest.raises(gen.TimeoutError):
         msg = yield gen.with_timeout(timedelta(milliseconds=100), comm.read())
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index cc6f992e911..7cc9be216d1 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -3340,7 +3340,7 @@ def test_default_get():
 @gen_cluster(client=True)
 def test_get_processing(c, s, a, b):
     processing = yield c.processing()
-    assert processing == valmap(list, s.processing)
+    assert processing == valmap(tuple, s.processing)
 
     futures = c.map(slowinc, range(10), delay=0.1, workers=[a.address],
                     allow_other_workers=True)
@@ -3351,7 +3351,7 @@ def test_get_processing(c, s, a, b):
     assert set(x) == {a.address, b.address}
 
     x = yield c.processing(workers=[a.address])
-    assert isinstance(x[a.address], list)
+    assert isinstance(x[a.address], (list, tuple))
 
 
 @gen_cluster(client=True)
@@ -3384,6 +3384,14 @@ def test_get_foo(c, s, a, b):
     assert valmap(sorted, x) == {futures[0].key: sorted(s.who_has[futures[0].key])}
 
 
+def assert_dict_key_equal(expected, actual):
+    assert set(expected.keys()) == set(actual.keys())
+    for k in actual.keys():
+        ev = expected[k]
+        av = actual[k]
+        assert list(ev) == list(av)
+
+
 @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3)
 def test_get_foo_lost_keys(c, s, u, v, w):
     x = c.submit(inc, 1, workers=[u.address])
@@ -3393,27 +3401,27 @@ def test_get_foo_lost_keys(c, s, u, v, w):
     ua, va, wa = u.address, v.address, w.address
 
     d = yield c.scheduler.has_what()
-    assert d == {ua: [x.key], va: [y.key], wa: []}
+    assert_dict_key_equal(d, {ua: [x.key], va: [y.key], wa: []})
     d = yield c.scheduler.has_what(workers=[ua, va])
-    assert d == {ua: [x.key], va: [y.key]}
+    assert_dict_key_equal(d, {ua: [x.key], va: [y.key]})
     d = yield c.scheduler.who_has()
-    assert d == {x.key: [ua], y.key: [va]}
+    assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})
     d = yield c.scheduler.who_has(keys=[x.key, y.key])
-    assert d == {x.key: [ua], y.key: [va]}
+    assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})
 
     yield u._close()
     yield v._close()
 
     d = yield c.scheduler.has_what()
-    assert d == {wa: []}
+    assert_dict_key_equal(d, {wa: []})
     d = yield c.scheduler.has_what(workers=[ua, va])
-    assert d == {ua: [], va: []}
+    assert_dict_key_equal(d, {ua: [], va: []})
     # The scattered key cannot be recomputed so it is forgotten
     d = yield c.scheduler.who_has()
-    assert d == {x.key: []}
+    assert_dict_key_equal(d, {x.key: []})
     # ... but when passed explicitly, it is included in the result
     d = yield c.scheduler.who_has(keys=[x.key, y.key])
-    assert d == {x.key: [], y.key: []}
+    assert_dict_key_equal(d, {x.key: [], y.key: []})
 
 
 @slow
diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py
index bf6b2097d5b..a331b7b957a 100644
--- a/distributed/tests/test_publish.py
+++ b/distributed/tests/test_publish.py
@@ -28,10 +28,10 @@ def test_publish_simple(s, a, b):
     assert "data" in str(exc_info.value)
 
     result = yield c.scheduler.publish_list()
-    assert result == ['data']
+    assert result == ('data',)
 
     result = yield f.scheduler.publish_list()
-    assert result == ['data']
+    assert result == ('data',)
 
     yield c.close()
     yield f.close()
@@ -221,7 +221,7 @@ def test_pickle_safe(c, s, a, b):
     try:
         yield c2.publish_dataset(x=[1, 2, 3])
         result = yield c2.get_dataset('x')
-        assert result == [1, 2, 3]
+        assert result == (1, 2, 3)
 
         with pytest.raises(TypeError):
             yield c2.publish_dataset(y=lambda x: x)
diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py
index 715aa99d967..a5a8d63d8fa 100644
--- a/distributed/tests/test_queues.py
+++ b/distributed/tests/test_queues.py
@@ -48,10 +48,10 @@ def test_queue_with_data(c, s, a, b):
     xx = yield Queue('x')
     assert x.client is c
 
-    yield x.put([1, 'hello'])
+    yield x.put((1, 'hello'))
 
     data = yield xx.get()
-    assert data == [1, 'hello']
+    assert data == (1, 'hello')
 
     with pytest.raises(gen.TimeoutError):
         yield x.get(timeout=0.1)
diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py
index db529f12afe..6debaa80719 100644
--- a/distributed/tests/test_variable.py
+++ b/distributed/tests/test_variable.py
@@ -44,10 +44,10 @@ def test_queue_with_data(c, s, a, b):
     xx = Variable('x')
     assert x.client is c
 
-    yield x.set([1, 'hello'])
+    yield x.set((1, 'hello'))
 
     data = yield xx.get()
-    assert data == [1, 'hello']
+    assert data == (1, 'hello')
 
 
 def test_sync(loop):
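Note on the behavior this patch adapts to (an illustrative sketch, not part of the patch itself): msgpack's use_list=False option deserializes arrays as tuples rather than lists, which is why the isinstance checks above now accept both types and why the new put_in helper promotes tuples to lists before item assignment. A minimal demonstration, assuming only that the msgpack package is importable:

    import msgpack

    packed = msgpack.dumps([1, 2, 3])
    assert msgpack.loads(packed) == [1, 2, 3]                  # default: arrays come back as lists
    assert msgpack.loads(packed, use_list=False) == (1, 2, 3)  # use_list=False: arrays come back as tuples

    # Tuples are immutable, so splicing a deserialized value back into a nested
    # message requires promoting the enclosing tuple to a list first.
    holder = (None, None)
    holder = list(holder)
    holder[0] = 'value'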