86 changes: 81 additions & 5 deletions colossalai/shardformer/layer/_operation.py
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F

try:
import fused_mix_prec_layer_norm_cuda
@@ -46,7 +47,7 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
Matmul execution with asynchronous communication in backprop.
"""
@@ -58,11 +59,59 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce

output = torch.matmul(input_, weight.t())
output = torch.matmul(input_, weight)

if bias is not None:
output = output + bias
return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None
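
A quick aside on the new function: with `async_grad_allreduce=False` the all-reduce branch is skipped entirely, so a single-process sanity check needs no process group. A minimal sketch (not part of the PR) verifying the custom backward against plain `torch.matmul`:

```python
import torch

# Hypothetical single-rank check: process_group=None is safe here because
# async_grad_allreduce=False never touches torch.distributed.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)  # (in_features, out_features) layout

out = MatmulWithAsyncCommunication.apply(x, w, None, None, False)
out.sum().backward()

x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
torch.matmul(x_ref, w_ref).sum().backward()

assert torch.allclose(x.grad, x_ref.grad)
assert torch.allclose(w.grad, w_ref.grad)
```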


class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce

if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
@@ -114,7 +163,7 @@ def backward(ctx, grad_output):
return _gather(grad_output, ctx.dim, ctx.process_group), None, None


class _ReduceInput(torch.autograd.Function):
class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.

@@ -132,6 +181,25 @@ def backward(ctx, grad_output):
return grad_output, None


class _ReduceBackward(torch.autograd.Function):
"""
All-reduce the gradient from the model parallel region in the backward pass (the forward pass is the identity).

Args:
input_: input matrix.
process_group: process group for the all-reduce.
"""

@staticmethod
def forward(ctx, input_, process_group):
ctx.process_group = process_group
return input_

@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None


def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@@ -198,6 +266,10 @@ def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None


def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
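
One subtlety worth noting between the two wrappers: `matmul_with_async_comm` multiplies by `weight` directly, so it expects an `(in_features, out_features)` layout, whereas `linear_with_async_comm` keeps `F.linear` semantics with an `(out_features, in_features)` weight. A small sketch of the equivalence, with hypothetical shapes:

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)  # F.linear layout: (out_features, in_features)

y_linear = linear_with_async_comm(x, w, None, None, False)
y_matmul = matmul_with_async_comm(x, w.t(), None, None, False)
assert torch.allclose(y_linear, y_matmul)
```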

@@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group):
return _SplitForwardGatherBackward.apply(input_, dim, process_group)


def reduce_input(input_, process_group):
return _ReduceInput.apply(input_, process_group)
def reduce_forward(input_, process_group):
return _ReduceForward.apply(input_, process_group)


def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
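
Taken together, the renamed primitives form the usual Megatron-style pair: `reduce_forward` all-reduces activations and passes gradients through untouched, while the new `reduce_backward` is its mirror image (identity forward, all-reduced gradient). A sketch of how they compose around a column-parallel/row-parallel pair, assuming already-sharded weights and an initialized process group (`w_col_shard` and `w_row_shard` are hypothetical names):

```python
import torch.nn.functional as F

def column_then_row_parallel(x, w_col_shard, w_row_shard, pg):
    # Column-parallel layer: the input is replicated, so its gradient must be
    # summed across ranks; reduce_backward defers that to the backward pass.
    x = reduce_backward(x, pg)
    h = F.linear(x, w_col_shard)  # each rank computes a slice of the hidden dim
    # Row-parallel layer: each rank now holds a partial sum of the full output,
    # so the activations are all-reduced in the forward pass.
    y_partial = F.linear(h, w_row_shard)
    return reduce_forward(y_partial, pg)
```
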
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/embedding.py
@@ -15,7 +15,7 @@
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param

from ._operation import gather_forward_split_backward, reduce_input
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset

@@ -276,5 +276,5 @@ def forward(self, input_: Tensor) -> Tensor:
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group)
return output
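
For context, `reduce_forward` is the right primitive here because each rank only embeds the tokens that fall inside its vocabulary shard; the masked rows are zero, so the all-reduce sums the partial lookups into the full embedding. A condensed sketch of the surrounding forward, assuming each rank owns rows `[vocab_start, vocab_end)` of the table (names hypothetical):

```python
import torch.nn.functional as F

def vocab_parallel_embed(input_ids, weight_shard, vocab_start, vocab_end, pg):
    # Tokens outside this rank's shard are clamped to row 0 and masked out.
    input_mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    masked_ids = input_ids.clone() - vocab_start
    masked_ids[input_mask] = 0
    output_parallel = F.embedding(masked_ids, weight_shard)
    output_parallel[input_mask, :] = 0.
    # Summing across ranks restores the full embedding for every token.
    return reduce_forward(output_parallel, pg)
```
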
16 changes: 5 additions & 11 deletions colossalai/shardformer/layer/linear.py
@@ -15,12 +15,11 @@
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device

from ._operation import (
gather_forward_split_backward,
linear_with_async_comm,
reduce_input,
reduce_forward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
@@ -148,9 +147,10 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])

# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_

# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
@@ -209,17 +209,14 @@ def __init__(self,
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)

if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')

# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()

factory_kwargs = {'device': device, 'dtype': dtype}

weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_colwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
@@ -327,14 +324,11 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group)

if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
return output, self.bias
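
A usage note on the preserved `skip_bias_add` branch: returning the bias unapplied lets the caller combine the bias-add with a later operation (e.g. a fused bias+GeLU kernel). A hypothetical caller pattern, not from this PR:

```python
import torch.nn.functional as F

# row_linear = Linear1D_Row(..., skip_bias_add=True)
output, bias = row_linear(x)
output = F.gelu(output + bias)  # bias-add folded into the activation step
```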