colossalai/shardformer/layer/_operation.py (143 additions, 13 deletions)
@@ -141,6 +141,105 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None


class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward

Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.

"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim

input_parallel = _gather(input_, dim, process_group)

if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)

return output

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group

# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

        # Reduce-scatter the input gradient along the sequence dimension
        new_shape = list(input_parallel.shape)
        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
        input_list = [
            item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
        ]
        output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device)
        if ctx.async_grad_reduce_scatter:
            # Asynchronous reduce-scatter, overlapped with the weight gradient computation below
            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce-scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
        else:
            # Synchronous fallback; without this branch `output` would be undefined below
            dist.reduce_scatter(output, input_list, group=process_group)

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

return output, grad_weight, grad_bias, None, None, None
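
For intuition, the two collectives mirror each other: the forward all-gather along the sequence dimension is undone in backward by a reduce-scatter that both sums the per-rank input gradients and re-shards them. A minimal sketch of that pairing outside autograd (illustrative helper names, not part of this PR; assumes an initialized torch.distributed process group):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def sp_linear_forward(x_shard, weight, group, dim=1):
    # All-gather the sequence shards so every rank sees the full sequence,
    # then apply the local (column-sharded) linear layer.
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(shards, x_shard.contiguous(), group=group)
    return F.linear(torch.cat(shards, dim=dim), weight)

def sp_input_grad(grad_full_input, group, dim=1):
    # Each rank holds a gradient w.r.t. the *full* input; reduce-scatter sums
    # them across ranks and keeps only this rank's sequence shard, matching
    # the sharded input that entered the forward pass.
    world_size = dist.get_world_size(group)
    chunks = [c.contiguous() for c in grad_full_input.chunk(world_size, dim=dim)]
    out = torch.empty_like(chunks[dist.get_rank(group)])
    dist.reduce_scatter(out, chunks, group=group)
    return out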


class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward

Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.

"""

@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group

# do reduce-scatter
new_shape = list(input_.shape)
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)

return output

@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group

return _gather(grad_output, dim, process_group), None, None
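
This op is what lets the row-parallel linear fuse its all-reduce with sequence re-sharding: each rank's partial matmul output is summed and sequence-sharded in a single reduce-scatter. A single-process sanity check of the underlying math (a hedged illustration emulating world_size = 2, not part of this PR):

import torch

torch.manual_seed(0)
x = torch.randn(4, 6)                  # (seq_len, in_features)
w = torch.randn(3, 6)                  # (out_features, in_features)

x0, x1 = x.chunk(2, dim=1)             # per-"rank" feature shards of the input
w0, w1 = w.chunk(2, dim=1)             # matching column shards of the weight
partial = x0 @ w0.t() + x1 @ w1.t()    # sum of the per-rank partial outputs
assert torch.allclose(partial, x @ w.t(), atol=1e-5)

# A reduce-scatter along the sequence dim would hand rank r the r-th chunk:
shard0, shard1 = partial.chunk(2, dim=0)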


class _SplitForwardGatherBackward(torch.autograd.Function):
"""
    Split the input and keep only the chunk corresponding to the rank.
@@ -200,6 +299,26 @@ def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None


class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.

Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""

@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)

@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None


def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
Expand Down Expand Up @@ -235,6 +354,7 @@ def _gather(input_, dim=-1, process_group=None):
return input_

# all gather
input_ = input_.contiguous()
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
@@ -246,24 +366,25 @@ def _gather(input_, dim=-1, process_group=None):
return output


def _reduce_scatter(input_, dim=1, process_group=None):
    """Do the reduce-scatter operation.

    Args:
        input_ (`torch.Tensor`): The input tensor from the sequence parallel region.
        dim (int): The dimension along which to perform reduce-scatter.
        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
    """
    world_size = dist.get_world_size(process_group)
    if world_size == 1:
        return input_

    # reduce-scatter: each rank receives one summed chunk of the input
    new_shape = list(input_.shape)
    new_shape[dim] = new_shape[dim] // world_size
    input_list = [item.contiguous() for item in torch.chunk(input_, world_size, dim=dim)]
    output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
    dist.reduce_scatter(output, input_list, group=process_group)

    return output
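
To make the semantics concrete: the output keeps the input's shape except along dim, which shrinks by the world size, and every element is summed across ranks. A hedged two-rank demo (hypothetical script, run under torchrun --nproc_per_node=2):

import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()

# Each rank holds a (2, 4) tensor filled with rank + 1.
x = torch.full((2, 4), float(rank + 1))

# Output is (2, 2); every element is the sum over ranks, 1 + 2 = 3.
out = _reduce_scatter(x, dim=1, process_group=None)
print(rank, tuple(out.shape), out[0, 0].item())  # (2, 2) 3.0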


def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
@@ -274,6 +395,15 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
    return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
                                                               async_grad_reduce_scatter, dim)


def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)

colossalai/shardformer/layer/linear.py (15 additions, 2 deletions)
@@ -18,6 +18,8 @@

from ._operation import (
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
@@ -63,6 +65,7 @@ def __init__(self,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
@@ -72,6 +75,7 @@
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@@ -153,7 +157,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
        if self.seq_parallel:
            output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
                                                                           self.process_group, True, 1)
        else:
            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

if self.gather_output:
# All-gather across the partitions.
@@ -194,6 +202,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
@@ -209,6 +218,7 @@
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)

if skip_bias_add and not bias:
@@ -326,7 +336,10 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
            if self.seq_parallel:
                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
            else:
                output = reduce_forward(output_parallel, self.process_group)

if not self.skip_bias_add:
if self.bias is not None:
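
Taken together, the two code paths implement the standard sequence-parallel linear pattern: the column-parallel layer all-gathers along the sequence dimension before its matmul, and the row-parallel layer reduce-scatters after its matmul, so activations between blocks stay sequence-sharded. A hedged usage sketch, assuming two GPUs, that the enclosing classes are shardformer's Linear1D_Col and Linear1D_Row exported from colossalai.shardformer.layer, and the constructor arguments shown in this diff:

import torch
import torch.distributed as dist
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Column-parallel up-projection and row-parallel down-projection of an MLP
# block, both in sequence parallel mode, over the default (world) group.
col = Linear1D_Col(256, 1024, process_group=None, seq_parallel=True, device="cuda")
row = Linear1D_Row(1024, 256, process_group=None, seq_parallel=True, device="cuda")

# Each of the 2 ranks holds half the sequence: (batch, seq // 2, hidden).
x = torch.randn(8, 64, 256, device="cuda")

h = col(x)    # gathered to seq 128 inside; output features sharded to 512 per rank
y = row(h)    # partial sums reduce-scattered back to 64 positions per rank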