224 changes: 224 additions & 0 deletions colossalai/pipeline/p2p.py
@@ -0,0 +1,224 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import io
import pickle
from typing import Any, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d

from .stage_manager import PipelineStageManager

_unpickler = pickle.Unpickler


def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
"""transform tensor to object with unpickle.
Info of the device in bytes stream will be modified into current device before unpickling

Args:
tensor (:class:`torch.tensor`): tensor to be unpickled
tensor_size (:class:`torch.Size`): Size of the real info in bytes

Returns:
Any: object after unpickled
"""
    buf = tensor.numpy().tobytes()[:tensor_size]
    if b'cuda' in buf:
        buf_array = bytearray(buf)
        device_index = torch.cuda.current_device()
        # overwrite the device digit in place, e.g. b'cuda:0' -> b'cuda:1';
        # 48 is ord('0'), so this assumes a single-digit device index
        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
        buf = bytes(buf_array)

io_bytes = io.BytesIO(buf)
byte_pickler = _unpickler(io_bytes)
unpickle = byte_pickler.load()

return unpickle
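
A minimal standalone sketch (not part of this diff) of the byte patch above: a pickled string stands in for a pickled CUDA tensor's byte stream, and target_index is an assumed stand-in for torch.cuda.current_device().

import pickle

payload = pickle.dumps('cuda:0')    # stand-in for a byte stream that embeds a device string
buf_array = bytearray(payload)
target_index = 1                    # assumed stand-in for torch.cuda.current_device()
buf_array[buf_array.find(b'cuda') + 5] = 48 + target_index    # offset +5 is the device digit
assert pickle.loads(bytes(buf_array)) == 'cuda:1'

Since only one byte is rewritten, the patch covers device indices 0-9; a multi-digit index would need a larger rewrite.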


def _broadcast_object_list(object_list: List[Any],
src: int,
group: ProcessGroup,
device: Optional[Union[torch.device, str, int]] = None):
"""This is a modified version of the broadcast_object_list in torch.distribution
The only difference is that object will be move to correct device after unpickled.
If local_rank = src, then object list will be sent to rank src. Otherwise, object list will
be updated with data sent from rank src.

Args:
object_list (List[Any]): list of object to broadcast
src (int): source rank to broadcast
dst (int): dst rank to broadcast
device (:class:`torch.device`): device to do broadcast. current device in default

"""

if c10d._rank_not_in_group(group):
c10d._warn_not_in_group("broadcast_object_list")
return

my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)

is_nccl_backend = c10d._check_for_nccl_backend(group)
current_device = None

    if device is not None:
        device = torch.device(device)    # the signature also accepts str/int
        if is_nccl_backend and device.type != "cuda":
            raise ValueError("device type must be cuda for nccl backend")
        current_device = device
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
if is_nccl_backend:
object_sizes_tensor = object_sizes_tensor.to(current_device)

# Broadcast object sizes
c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)

# Concatenate and broadcast serialized object tensors
if my_rank == src:
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty( # type: ignore[call-overload]
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
dtype=torch.uint8,
)

if is_nccl_backend:
object_tensor = object_tensor.to(current_device)

c10d.broadcast(object_tensor, src=src, group=group, async_op=False)

# Deserialize objects using their stored sizes.
offset = 0

if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset:offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
offset += obj_size
# unpickle
unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

            # the tensor may have been pickled on a different device; move it to the local one
if isinstance(unpickle_object,
torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()

object_list[i] = unpickle_object
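
A usage sketch (assumed, not in this diff) of the in-place semantics, given an already-initialized process group `group` containing ranks 0 and 1:

# rank 0 (source): the list contents are sent as-is
objs = [{'step': 1, 'hidden': torch.ones(1)}]
_broadcast_object_list(objs, src=0, group=group)

# rank 1 (receiver): placeholders are overwritten in place
objs = [None]
_broadcast_object_list(objs, src=0, group=group)
# objs[0] now equals {'step': 1, 'hidden': torch.ones(1)}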


def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
"""send anything to dst rank

Args:
object (Any): object needed to be sent
dst (int): rank of the destination

Returns:
None
"""
# then broadcast safely
_broadcast_object_list([object], src, group)


def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
"""recv anything from src

Args:
src (int): source rank of data. local rank will receive data from src rank.

Returns:
Any: Object received from src.
"""
object_list = [None]
_broadcast_object_list(object_list, src, group)

return object_list[0]
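
How the two helpers pair up, as a sketch under the same assumed two-rank group `p2p_group`; note that `dst` only documents intent, since the group itself determines the peer:

# on rank 0 (the sender)
_send_object({'hidden': torch.ones(1)}, src=0, dst=1, group=p2p_group)

# on rank 1 (the receiver)
payload = _recv_object(src=0, dst=1, group=p2p_group)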


class PipelineP2PCommunication:

def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager

    def recv_forward(self, prev_rank: Optional[int] = None) -> Any:
        """Copy the forward output from the previous stage in the pipeline as the input tensor of this stage.

Args:
prev_rank (int, optional): The rank of the source of the tensor.

Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))

return input_tensor

    def recv_backward(self, next_rank: Optional[int] = None) -> Any:
        """Copy the gradient tensor from the next stage in the pipeline as the input gradient of this stage.

Args:
next_rank (int, optional): The rank of the source of the tensor.

Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))

return output_tensor_grad

    def send_forward(self, output_object: Any, next_rank: Optional[int] = None) -> None:
        """Sends the forward output of this stage to the next stage in the pipeline.

Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None) -> None:
        """Sends the gradient tensor to the previous stage in the pipeline.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
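
A hypothetical sketch (not part of this PR) of how a pipeline schedule might drive these methods; `model_chunk` and the initialized stage manager are assumed from the caller's context:

def forward_step(p2p: PipelineP2PCommunication, model_chunk):
    input_tensor = p2p.recv_forward()    # None on the first stage
    output_tensor = model_chunk(input_tensor)
    p2p.send_forward(output_tensor)      # no-op on the last stage
    return output_tensor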
59 changes: 59 additions & 0 deletions tests/test_pipeline/test_p2p_communication.py
@@ -0,0 +1,59 @@
import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


def check_p2p_communication():
pg_mesh = ProcessGroupMesh(2)
stage_manager = PipelineStageManager(pg_mesh, 0)
p2p = PipelineP2PCommunication(stage_manager)

rank = dist.get_rank()

tensor = torch.ones(1, device=get_current_device())

if rank == 0:
p2p.send_forward(tensor)
p2p.send_forward([tensor])
p2p.send_forward({'tensor': tensor})
else:
obj = p2p.recv_forward()
assert torch.equal(obj, tensor)
obj = p2p.recv_forward()
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
obj = p2p.recv_forward()
assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor)

if rank == 1:
p2p.send_backward(tensor)
p2p.send_backward([tensor])
p2p.send_backward({'tensor': tensor})
else:
obj = p2p.recv_backward()
assert torch.equal(obj, tensor)
obj = p2p.recv_backward()
assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor)
obj = p2p.recv_backward()
assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor)


def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_p2p_communication()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pipeline_p2p():
spawn(run_dist, 2)


if __name__ == '__main__':
test_pipeline_p2p()
7 changes: 4 additions & 3 deletions tests/test_pipeline/test_stage_manager.py
@@ -4,7 +4,7 @@
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_stage_manager():
@@ -78,9 +78,10 @@ def run_dist(rank, world_size, port):


@pytest.mark.dist
-def test_process_group_mesh():
+@rerun_if_address_is_in_use()
+def test_pipeline_stage_manager():
spawn(run_dist, 4)


if __name__ == '__main__':
-    test_process_group_mesh()
+    test_pipeline_stage_manager()