diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index a25d499ea0f..a4ab8435646 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -8,7 +8,6 @@ from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads -from distributed.protocol.serialize import pickle_dumps if sys.version_info < (3, 8): try: @@ -19,6 +18,17 @@ import pickle +class MemoryviewHolder: + def __init__(self, mv): + self.mv = memoryview(mv) + + def __reduce_ex__(self, protocol): + if protocol >= 5: + return MemoryviewHolder, (pickle.PickleBuffer(self.mv),) + else: + return MemoryviewHolder, (self.mv.tobytes(),) + + def test_pickle_data(): data = [1, b"123", "123", [123], {}, set()] for d in data: @@ -27,16 +37,6 @@ def test_pickle_data(): def test_pickle_out_of_band(): - class MemoryviewHolder: - def __init__(self, mv): - self.mv = memoryview(mv) - - def __reduce_ex__(self, protocol): - if protocol >= 5: - return MemoryviewHolder, (pickle.PickleBuffer(self.mv),) - else: - return MemoryviewHolder, (self.mv.tobytes(),) - mv = memoryview(b"123") mvh = MemoryviewHolder(mv) @@ -73,13 +73,29 @@ def __reduce_ex__(self, protocol): def test_pickle_empty(): - np = pytest.importorskip("numpy") - x = np.arange(2)[0:0] # Empty view - header, frames = pickle_dumps(x) - header["writeable"] = [False] * len(frames) + x = MemoryviewHolder(bytearray()) # Empty view + header, frames = serialize(x, serializers=("pickle",)) + + assert header["serializer"] == "pickle" + assert len(frames) >= 1 + assert isinstance(frames[0], bytes) + + if HIGHEST_PROTOCOL >= 5: + assert len(frames) == 2 + assert len(header["writeable"]) == 1 + + header["writeable"] = (False,) * len(frames) + else: + assert len(frames) == 1 + assert len(header["writeable"]) == 0 + y = deserialize(header, frames) - assert memoryview(y).nbytes == 0 - assert memoryview(y).readonly + + assert isinstance(y, MemoryviewHolder) + assert isinstance(y.mv, memoryview) + assert y.mv == x.mv + assert y.mv.nbytes == 0 + assert y.mv.readonly def test_pickle_numpy():