distributed/client.py (10 additions & 9 deletions)
@@ -960,8 +960,8 @@ def _handle_report(self):
                             continue
                         else:
                             break
-                    if not isinstance(msgs, list):
-                        msgs = [msgs]
+                    if not isinstance(msgs, (list, tuple)):
+                        msgs = (msgs,)
 
                     breakout = False
                     for msg in msgs:
@@ -2665,7 +2665,7 @@ def ncores(self, workers=None, **kwargs):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.ncores, workers=workers, **kwargs)
 
@@ -2731,7 +2731,7 @@ def has_what(self, workers=None, **kwargs):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.has_what, workers=workers, **kwargs)
 
@@ -2760,7 +2760,7 @@ def processing(self, workers=None):
         if (isinstance(workers, tuple)
                 and all(isinstance(i, (str, tuple)) for i in workers)):
             workers = list(workers)
-        if workers is not None and not isinstance(workers, (list, set)):
+        if workers is not None and not isinstance(workers, (tuple, list, set)):
             workers = [workers]
         return self.sync(self.scheduler.processing, workers=workers)
 
@@ -2910,8 +2910,8 @@ def get_metadata(self, keys, default=no_default):
         --------
         Client.set_metadata
         """
-        if not isinstance(keys, list):
-            keys = [keys]
+        if not isinstance(keys, (list, tuple)):
+            keys = (keys,)
         return self.sync(self.scheduler.get_metadata, keys=keys,
                          default=default)
 
@@ -3012,7 +3012,7 @@ def set_metadata(self, key, value):
         get_metadata
         """
         if not isinstance(key, list):
-            key = [key]
+            key = (key,)
         return self.sync(self.scheduler.set_metadata, keys=key, value=value)
 
     def get_versions(self, check=False):
@@ -3040,7 +3040,8 @@ def get_versions(self, check=False):
         if check:
             # we care about the required & optional packages matching
             def to_packages(d):
-                return dict(sum(d['packages'].values(), []))
+                L = list(d['packages'].values())
+                return dict(sum(L, type(L[0])()))
             client_versions = to_packages(result['client'])
             versions = [('scheduler', to_packages(result['scheduler']))]
             versions.extend((w, to_packages(d))
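Note: a minimal sketch (not part of the diff) of why to_packages now seeds
sum() with type(L[0])() instead of []. With msgpack decoding arrays as tuples
(the use_list=False change in distributed/protocol/core.py below), the
per-category package lists may arrive as tuples, and sum() only concatenates
sequences onto a start value of a matching type. The package data here is
invented for illustration:

    packages = {'required': (('dask', '0.18.0'),),
                'optional': (('lz4', '1.1.0'),)}
    L = list(packages.values())
    dict(sum(L, type(L[0])()))  # {'dask': '0.18.0', 'lz4': '1.1.0'}
    # dict(sum(L, [])) would raise:
    # TypeError: can only concatenate list (not "tuple") to list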
distributed/core.py (2 additions & 2 deletions)
@@ -359,8 +359,8 @@ def handle_stream(self, comm, extra=None, every_cycle=[]):
         try:
             while not closed:
                 msgs = yield comm.read()
-                if not isinstance(msgs, list):
-                    msgs = [msgs]
+                if not isinstance(msgs, (tuple, list)):
+                    msgs = (msgs,)
 
                 if not comm.closed():
                     for msg in msgs:
distributed/diagnostics/tests/test_eventstream.py (1 addition & 1 deletion)
@@ -57,7 +57,7 @@ def test_eventstream_remote(c, s, a, b):
     total = []
     while len(total) < 10:
         msgs = yield comm.read()
-        assert isinstance(msgs, list)
+        assert isinstance(msgs, tuple)
         total.extend(msgs)
         assert time() < start + 5
 
distributed/protocol/core.py (18 additions & 5 deletions)
@@ -1,13 +1,14 @@
 from __future__ import print_function, division, absolute_import
 
 import logging
+import operator
 
 import msgpack
 
 try:
-    from cytoolz import get_in
+    from cytoolz import reduce
 except ImportError:
-    from toolz import get_in
+    from toolz import reduce
 
 from .compression import compressions, maybe_compress, decompress
 from .serialize import (serialize, deserialize, Serialize, Serialized,
@@ -122,7 +123,19 @@ def loads(frames, deserialize=True, deserializers=None):
             else:
                 value = Serialized(head, fs)
 
-            get_in(key[:-1], msg)[key[-1]] = value
+            def put_in(keys, coll, val):
+                """Inverse of get_in, but does type promotion in the case of lists"""
+                if keys:
+                    holder = reduce(operator.getitem, keys[:-1], coll)
+                    if isinstance(holder, tuple):
+                        holder = list(holder)
+                        coll = put_in(keys[:-1], coll, holder)
+                    holder[keys[-1]] = val
+                else:
+                    coll = val
+                return coll
+
+            msg = put_in(key, msg, value)
 
         return msg
     except Exception:
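Note: a short sketch (not part of the diff, data invented) of the tuple
promotion that put_in performs. Tuples produced by use_list=False do not
support item assignment, so any tuple along the target path is copied to a
list and written back before the final assignment:

    msg = {'headers': {}, 'keys': ('a', 'b')}
    msg = put_in(['keys', 1], msg, 'B')
    # the tuple at msg['keys'] was promoted to a list:
    # {'headers': {}, 'keys': ['a', 'B']}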
@@ -160,7 +173,7 @@ def loads_msgpack(header, payload):
     dumps_msgpack
     """
     if header:
-        header = msgpack.loads(header, encoding='utf8')
+        header = msgpack.loads(header, encoding='utf8', use_list=False)
     else:
         header = {}
 
@@ -172,4 +185,4 @@ def loads_msgpack(header, payload):
             raise ValueError("Data is compressed as %s but we don't have this"
                              " installed" % str(header['compression']))
 
-    return msgpack.loads(payload, encoding='utf8')
+    return msgpack.loads(payload, encoding='utf8', use_list=False)
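Note: use_list is a standard msgpack-python option. By default decoded msgpack
arrays come back as Python lists; with use_list=False they come back as
tuples, which are cheaper to build and hashable. This is why the tests below
now expect tuples. A quick round trip:

    import msgpack
    msgpack.loads(msgpack.dumps([1, 2, 3]))                  # [1, 2, 3]
    msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=False)  # (1, 2, 3)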
distributed/protocol/serialize.py (2 additions & 2 deletions)
@@ -63,7 +63,7 @@ def msgpack_dumps(x):
 
 
 def msgpack_loads(header, frames):
-    return msgpack.loads(b''.join(frames), encoding='utf8')
+    return msgpack.loads(b''.join(frames), encoding='utf8', use_list=False)
 
 
 def serialization_error_loads(header, frames):
@@ -441,7 +441,7 @@ def deserialize_bytes(b):
     frames = unpack_frames(b)
     header, frames = frames[0], frames[1:]
     if header:
-        header = msgpack.loads(header, encoding='utf8')
+        header = msgpack.loads(header, encoding='utf8', use_list=False)
     else:
         header = {}
     frames = decompress(header, frames)
distributed/protocol/tests/test_protocol.py (1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def test_small():
 
 
 def test_small_and_big():
-    d = {'x': [1, 2, 3], 'y': b'0' * 10000000}
+    d = {'x': (1, 2, 3), 'y': b'0' * 10000000}
     L = dumps(d)
     assert loads(L) == d
     # assert loads([small_header, small]) == {'x': [1, 2, 3]}
distributed/protocol/tests/test_serialize.py (7 additions & 0 deletions)
@@ -170,6 +170,13 @@ def test_empty_loads():
     assert isinstance(e2[0], Empty)
 
 
+def test_empty_loads_deep():
+    from distributed.protocol import loads, dumps
+    e = Empty()
+    e2 = loads(dumps([[[to_serialize(e)]]]))
+    assert isinstance(e2[0][0][0], Empty)
+
+
 def test_serialize_bytes():
     for x in [1, 'abc', np.arange(5)]:
         b = serialize_bytes(x)
distributed/tests/test_batched.py (5 additions & 5 deletions)
@@ -69,9 +69,9 @@ def test_BatchedSend():
     b.send('HELLO')
 
     result = yield comm.read()
-    assert result == ['hello', 'hello', 'world']
+    assert result == ('hello', 'hello', 'world')
     result = yield comm.read()
-    assert result == ['HELLO', 'HELLO']
+    assert result == ('HELLO', 'HELLO')
 
     assert b.byte_count > 1
 
@@ -88,7 +88,7 @@ def test_send_before_start():
 
     b.start(comm)
     result = yield comm.read()
-    assert result == ['hello', 'world']
+    assert result == ('hello', 'world')
 
 
 @gen_test()
@@ -104,7 +104,7 @@ def test_send_after_stream_start():
     result = yield comm.read()
     if len(result) < 2:
         result += yield comm.read()
-    assert result == ['hello', 'world']
+    assert result == ('hello', 'world')
 
 
 @gen_test()
@@ -295,7 +295,7 @@ def test_serializers():
     assert 'function' in value
 
     msg = yield comm.read()
-    assert msg == [{'x': 123}, {'x': 'hello'}]
+    assert list(msg) == [{'x': 123}, {'x': 'hello'}]
 
     with pytest.raises(gen.TimeoutError):
         msg = yield gen.with_timeout(timedelta(milliseconds=100), comm.read())
distributed/tests/test_client.py (18 additions & 10 deletions)
@@ -3340,7 +3340,7 @@ def test_default_get():
 @gen_cluster(client=True)
 def test_get_processing(c, s, a, b):
     processing = yield c.processing()
-    assert processing == valmap(list, s.processing)
+    assert processing == valmap(tuple, s.processing)
 
     futures = c.map(slowinc, range(10), delay=0.1, workers=[a.address],
                     allow_other_workers=True)
@@ -3351,7 +3351,7 @@ def test_get_processing(c, s, a, b):
     assert set(x) == {a.address, b.address}
 
     x = yield c.processing(workers=[a.address])
-    assert isinstance(x[a.address], list)
+    assert isinstance(x[a.address], (list, tuple))
 
 
 @gen_cluster(client=True)
@@ -3384,6 +3384,14 @@ def test_get_foo(c, s, a, b):
     assert valmap(sorted, x) == {futures[0].key: sorted(s.who_has[futures[0].key])}
 
 
+def assert_dict_key_equal(expected, actual):
+    assert set(expected.keys()) == set(actual.keys())
+    for k in actual.keys():
+        ev = expected[k]
+        av = actual[k]
+        assert list(ev) == list(av)
+
+
 @gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 3)
 def test_get_foo_lost_keys(c, s, u, v, w):
     x = c.submit(inc, 1, workers=[u.address])
@@ -3393,27 +3401,27 @@ def test_get_foo_lost_keys(c, s, u, v, w):
     ua, va, wa = u.address, v.address, w.address
 
     d = yield c.scheduler.has_what()
-    assert d == {ua: [x.key], va: [y.key], wa: []}
+    assert_dict_key_equal(d, {ua: [x.key], va: [y.key], wa: []})
     d = yield c.scheduler.has_what(workers=[ua, va])
-    assert d == {ua: [x.key], va: [y.key]}
+    assert_dict_key_equal(d, {ua: [x.key], va: [y.key]})
     d = yield c.scheduler.who_has()
-    assert d == {x.key: [ua], y.key: [va]}
+    assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})
     d = yield c.scheduler.who_has(keys=[x.key, y.key])
-    assert d == {x.key: [ua], y.key: [va]}
+    assert_dict_key_equal(d, {x.key: [ua], y.key: [va]})
 
     yield u._close()
     yield v._close()
 
     d = yield c.scheduler.has_what()
-    assert d == {wa: []}
+    assert_dict_key_equal(d, {wa: []})
     d = yield c.scheduler.has_what(workers=[ua, va])
-    assert d == {ua: [], va: []}
+    assert_dict_key_equal(d, {ua: [], va: []})
     # The scattered key cannot be recomputed so it is forgotten
     d = yield c.scheduler.who_has()
-    assert d == {x.key: []}
+    assert_dict_key_equal(d, {x.key: []})
     # ... but when passed explicitly, it is included in the result
     d = yield c.scheduler.who_has(keys=[x.key, y.key])
-    assert d == {x.key: [], y.key: []}
+    assert_dict_key_equal(d, {x.key: [], y.key: []})
 
 
 @slow
distributed/tests/test_publish.py (3 additions & 3 deletions)
@@ -28,10 +28,10 @@ def test_publish_simple(s, a, b):
     assert "data" in str(exc_info.value)
 
     result = yield c.scheduler.publish_list()
-    assert result == ['data']
+    assert result == ('data',)
 
     result = yield f.scheduler.publish_list()
-    assert result == ['data']
+    assert result == ('data',)
 
     yield c.close()
     yield f.close()
@@ -221,7 +221,7 @@ def test_pickle_safe(c, s, a, b):
     try:
         yield c2.publish_dataset(x=[1, 2, 3])
         result = yield c2.get_dataset('x')
-        assert result == [1, 2, 3]
+        assert result == (1, 2, 3)
 
         with pytest.raises(TypeError):
             yield c2.publish_dataset(y=lambda x: x)
distributed/tests/test_queues.py (2 additions & 2 deletions)
@@ -48,10 +48,10 @@ def test_queue_with_data(c, s, a, b):
     xx = yield Queue('x')
     assert x.client is c
 
-    yield x.put([1, 'hello'])
+    yield x.put((1, 'hello'))
     data = yield xx.get()
 
-    assert data == [1, 'hello']
+    assert data == (1, 'hello')
 
     with pytest.raises(gen.TimeoutError):
         yield x.get(timeout=0.1)
distributed/tests/test_variable.py (2 additions & 2 deletions)
@@ -44,10 +44,10 @@ def test_queue_with_data(c, s, a, b):
     xx = Variable('x')
     assert x.client is c
 
-    yield x.set([1, 'hello'])
+    yield x.set((1, 'hello'))
     data = yield xx.get()
 
-    assert data == [1, 'hello']
+    assert data == (1, 'hello')
 
 
 def test_sync(loop):