Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions deepspeed/runtime/pipe/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# DeepSpeed Team

import pickle
import msgpack
import typing

import torch
Expand Down Expand Up @@ -96,7 +96,7 @@ def wait():
def send_obj(msg: typing.Any, dest: int):
"""Send an arbitrary python object to ``dest``.

Note: ``msg`` must be pickleable.
Note: ``msg`` must be serializable by msgpack.

WARN: This incurs a CPU -> GPU transfer and should be used sparingly
for performance reasons.
Expand All @@ -106,7 +106,7 @@ def send_obj(msg: typing.Any, dest: int):
dest (int): Destination rank.
"""
# serialize the message
msg = pickle.dumps(msg)
msg = msgpack.packb(msg)
# construct a tensor to send
msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name())

Expand All @@ -133,7 +133,7 @@ def recv_obj(sender: int) -> typing.Any:
msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name())
dist.recv(msg, src=sender)

msg = pickle.loads(msg.cpu().numpy().tobytes())
msg = msgpack.unpackb(msg.cpu().numpy().tobytes())

def _to(x):
"""Recursively move to the current device."""
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
hjson
msgpack
ninja
numpy
packaging>=20.0
Expand Down