From 5abb60ba1e1f63e1480691478f87e10767ff4072 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 17 Sep 2024 16:29:18 +0000 Subject: [PATCH] use msgpack for p2p comm --- deepspeed/runtime/pipe/p2p.py | 8 ++++---- requirements/requirements.txt | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py index 2b12a9573c4b..ed6d80b8d4fb 100644 --- a/deepspeed/runtime/pipe/p2p.py +++ b/deepspeed/runtime/pipe/p2p.py @@ -3,7 +3,7 @@ # DeepSpeed Team -import pickle +import msgpack import typing import torch @@ -96,7 +96,7 @@ def wait(): def send_obj(msg: typing.Any, dest: int): """Send an arbitrary python object to ``dest``. - Note: ``msg`` must be pickleable. + Note: ``msg`` must be serializable by msgpack. WARN: This incurs a CPU -> GPU transfer and should be used sparingly for performance reasons. @@ -106,7 +106,7 @@ def send_obj(msg: typing.Any, dest: int): dest (int): Destination rank. """ # serialize the message - msg = pickle.dumps(msg) + msg = msgpack.packb(msg) # construct a tensor to send msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name()) @@ -133,7 +133,7 @@ def recv_obj(sender: int) -> typing.Any: msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name()) dist.recv(msg, src=sender) - msg = pickle.loads(msg.cpu().numpy().tobytes()) + msg = msgpack.unpackb(msg.cpu().numpy().tobytes()) def _to(x): """Recursively move to the current device.""" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 70c94a745435..296398f680cc 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,4 +1,5 @@ hjson +msgpack ninja numpy packaging>=20.0