118 changes: 35 additions & 83 deletions colossalai/communication/p2p.py
@@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None,
dtype=None):
"""
Adapted from megatron.p2p_communication.
@@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)

if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)

# rank = dist.get_rank()
rank = gpc.get_global_rank()

ops = []
if tensor_send_prev is not None:
send_prev_op = dist.broadcast(tensor_send_prev,
src=rank,
group=up_group,
async_op=True)
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.broadcast(tensor_recv_prev,
src=prev_rank,
group=up_group,
async_op=True)
recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
recv_next_op = dist.broadcast(tensor_recv_next,
src=next_rank,
group=down_group,
async_op=True)
recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.broadcast(tensor_send_next,
src=rank,
group=down_group,
async_op=True)
send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
ops.append(send_next_op)
for req in ops:
req.wait()
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next
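
For context, the broadcast-over-a-two-rank-group trick removed above is replaced by true point-to-point primitives. Below is a minimal standalone sketch of the P2POp/batch_isend_irecv pattern, assuming an already-initialized NCCL process group of at least two ranks; the ring layout and tensor are illustrative, not part of this PR:

import torch
import torch.distributed as dist

def ring_exchange(tensor):
    # Queue one isend to the next rank and one irecv from the previous
    # rank, then launch both in a single batch so NCCL can pair them up
    # without deadlocking on issue order.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    recv_buf = torch.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world_size),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    # Same race-condition guard as the code above.
    torch.cuda.synchronize()
    return recv_buf
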


def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
def recv_forward(input_tensor_shape, prev_rank=None):
"""Receives the input tensor from the previous member in pipeline.

:param input_tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The input tensor in forward step
:rtype: Tensor
"""
@@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)
return input_tensor


def recv_backward(output_grad_shape, next_rank=None, down_group=None):
def recv_backward(output_grad_shape, next_rank=None):
"""Receives the grad tensor from the next member in pipeline.

:param output_grad_shape: The shape of the tensor to be received
:param next_rank: The rank of the source of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step
:rtype: Tensor
"""
@@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad


def send_forward(output_tensor,
next_rank=None,
down_group=None):
def send_forward(output_tensor, next_rank=None):
"""Sends the input tensor to the next member in pipeline.

:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor
:type next_rank: int, optional
:type down_group: ProcessGroup, optional
"""
if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)


def send_backward(input_tensor_grad,
prev_rank=None,
up_group=None):
def send_backward(input_tensor_grad, prev_rank=None):
"""Sends the grad tensor to the previous member in pipeline.

:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
"""
if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)


def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the grad tensor from the
next member in pipeline.

:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type output_tensor: Tensor
@@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
up_group=None):
prev_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the input tensor from the
previous member in pipeline.

:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
@@ -233,22 +198,19 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
up_group=up_group)
prev_rank=prev_rank)
return input_tensor
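
In the 1F1B steady state, these paired helpers let each stage fuse its send with the matching receive in one batched exchange. A hedged sketch of one steady-state iteration built on them, assuming gpc and the tensor shapes are already set up; backward_step is a hypothetical stand-in for the scheduler's backward pass:

def steady_state_step(model_chunk, backward_step, input_tensor,
                      input_tensor_shape, output_grad_shape):
    # Forward on the current microbatch, then swap its activations for
    # the corresponding gradients from the next stage in one exchange.
    output_tensor = model_chunk(input_tensor)
    output_tensor_grad = send_forward_recv_backward(output_tensor,
                                                    output_grad_shape)
    # Backward for this microbatch, then swap its input gradient for the
    # next microbatch's activations from the previous stage.
    input_tensor_grad = backward_step(input_tensor, output_tensor,
                                      output_tensor_grad)
    return send_backward_recv_forward(input_tensor_grad, input_tensor_shape)
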


def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while receiving the input tensor from the
previous member in pipeline.

:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be received
:type output_tensor: Tensor
@@ -260,23 +222,19 @@ def send_forward_recv_forward(output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return input_tensor


def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while receiving the grad tensor from the
next member in pipeline.

:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be received
:type input_tensor_grad: Tensor
@@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return output_tensor_grad


@@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
up_group=None,
down_group=None):
next_rank=None):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while receiving the grad tensor from the
next and the input tensor from the previous.

:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor received from the previous
@@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
up_group=up_group,
down_group=down_group)
next_rank=next_rank)
return input_tensor, output_tensor_grad
37 changes: 17 additions & 20 deletions colossalai/communication/utils.py
@@ -6,67 +6,64 @@
from colossalai.utils import get_current_device


def send_tensor_meta(tensor, need_meta=True, down_group=None):
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.

:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param down_group: Communication group including the next member in pipeline parallel group
:param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type down_group: ProcessGroup, optional
:type next_rank: int, optional
:return: False
:rtype: bool
"""
if need_meta:
rank = gpc.get_global_rank()

if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}

send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)

dist.broadcast(send_ndims, src=rank, group=down_group)
dist.broadcast(send_shape, src=rank, group=down_group)
ops = [
dist.P2POp(dist.isend, send_ndims, next_rank),
dist.P2POp(dist.isend, send_shape, next_rank)
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()

return False
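
Both ends of this handshake exchange two fixed-size messages: the dimensionality first, so the receiver can size a buffer for the shape itself (the receiving end is recv_tensor_meta below). A minimal sketch of the same protocol using blocking dist.send/dist.recv for clarity, whereas the PR batches the send side with batch_isend_irecv; the function names here are illustrative:

import torch
import torch.distributed as dist

def send_meta(tensor, next_rank):
    kwargs = {'dtype': torch.long, 'device': tensor.device}
    # Message 1: number of dimensions; message 2: the shape itself.
    dist.send(torch.tensor(tensor.dim(), **kwargs), next_rank)
    dist.send(torch.tensor(tensor.size(), **kwargs), next_rank)

def recv_meta(prev_rank, device):
    kwargs = {'dtype': torch.long, 'device': device}
    ndims = torch.empty((), **kwargs)
    dist.recv(ndims, prev_rank)
    # Size the shape buffer from message 1 before receiving message 2.
    shape = torch.empty(int(ndims.item()), **kwargs)
    dist.recv(shape, prev_rank)
    return torch.Size(shape.tolist())
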


def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
synchronizes with :func:`send_tensor_meta`.

:param tensor_shape: The shape of the tensor to be received
:param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The shape of the tensor to be received
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}

recv_ndims = torch.empty((), **tensor_kwargs)
dist.broadcast(recv_ndims, src=prev_rank, group=up_group)

dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.broadcast(recv_shape, src=prev_rank, group=up_group)
dist.recv(recv_shape, prev_rank)

tensor_shape = torch.Size(recv_shape)

1 change: 1 addition & 0 deletions colossalai/engine/schedule/_no_pipeline.py
@@ -14,6 +14,7 @@
from typing import Iterable

import torch

import torch.nn as nn
from torch.optim import Optimizer
