diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 0e299632902..d09b5e435ad 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -100,6 +100,14 @@ def test_dumps_serialize_numpy(x): np.testing.assert_equal(x, y) +def test_dumps_numpy_writable(): + a1 = np.arange(1000) + a1.flags.writeable = False + (a2,) = loads(dumps([to_serialize(a1)])) + assert (a1 == a2).all() + assert a2.flags.writeable + + @pytest.mark.parametrize( "x", [ diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index fa020dae909..c287949bc53 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -1,7 +1,7 @@ import struct import msgpack -from ..utils import ensure_bytes, nbytes +from ..utils import ensure_bytearray, nbytes BIG_BYTES_SHARD_SIZE = 2 ** 26 @@ -62,7 +62,7 @@ def merge_frames(header, frames): assert sum(lengths) == sum(map(nbytes, frames)) if all(len(f) == l for f, l in zip(frames, lengths)): - return frames + return list(map(ensure_bytearray, frames)) frames = frames[::-1] lengths = lengths[::-1] @@ -82,9 +82,9 @@ def merge_frames(header, frames): frames.append(mv[l:]) l = 0 if len(L) == 1: # no work necessary - out.extend(L) + out.append(ensure_bytearray(L[0])) else: - out.append(b"".join(map(ensure_bytes, L))) + out.append(bytearray().join(L)) return out diff --git a/distributed/utils.py b/distributed/utils.py index dec1b6b79d3..14249acea23 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -944,6 +944,48 @@ def ensure_bytes(s): ) from e +def ensure_bytearray(s): + """Attempt to turn `s` into `bytearray`. + + Parameters + ---------- + s : Any + The object to be converted. Will correctly handled + + * str + * bytes + * objects implementing the buffer protocol (memoryview, ndarray, etc.) + + Returns + ------- + b : bytes + + Raises + ------ + TypeError + When `s` cannot be converted + + Examples + -------- + + >>> ensure_bytearray('123') + bytearray(b'123') + >>> ensure_bytearray(b'123') + bytearray(b'123') + """ + if isinstance(s, bytearray): + return s + elif hasattr(s, "encode"): + return bytearray(s.encode()) + else: + try: + return bytearray(s) + except Exception as e: + raise TypeError( + "Object %s is neither a bytes object nor has an encode method" % s + ) from e + + def divide_n_among_bins(n, bins): """