Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 85 additions & 48 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,18 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.

"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap

input_parallel = _gather(input_, dim, process_group)

Expand All @@ -175,42 +177,80 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

if not overlap:
# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

# TODO: overlap SP input with gradient computation
input_parallel = _gather(input_, dim, process_group)

total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
new_shape = list(input_parallel.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_reduce_scatter:
handle.wait()

# print(output, output.shape)
return output, grad_weight, grad_bias, None, None, None
else:
# create new stream for calculate the gradient
calculate_stream = torch.cuda.Stream()

# do all gather in default stream
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

# calculate gradient in calculate_stream
with torch.cuda.stream(calculate_stream):
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None

# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()

torch.cuda.synchronize()

reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
print(grad_output.shape, input_parallel.shape)
grad_weight = grad_output.t().matmul(input_parallel)

torch.cuda.synchronize()

Comment thread
ver217 marked this conversation as resolved.
return output, grad_weight, grad_bias, None, None, None, None


class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
Expand Down Expand Up @@ -294,14 +334,10 @@ def backward(ctx, grad_output):
# TODO: overlap SP input with gradient computation
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
new_shape = list(input_parallel.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
Expand Down Expand Up @@ -442,7 +478,7 @@ def _gather(input_, dim=-1, process_group=None):
return output


def _reduce_scatter(intput_, dim=1, process_group=None):
def _reduce_scatter(input_, dim=1, process_group=None):
""" Do reduce-scatter operation.

Args:
Expand All @@ -452,15 +488,15 @@ def _reduce_scatter(intput_, dim=1, process_group=None):
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return intput_
return input_

# reduce-scatter
new_shape = list(intput_.shape)
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device)
dist.reduce_scatter(output, intput_, group=process_group)
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)

return output

Expand All @@ -473,9 +509,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim)
async_grad_reduce_scatter, dim, overlap)


def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
Expand Down
5 changes: 4 additions & 1 deletion colossalai/shardformer/layer/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Linear1D_Col(ParallelModule):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
Expand All @@ -73,6 +74,7 @@ def __init__(self,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
Expand All @@ -85,6 +87,7 @@ def __init__(self,
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
Expand Down Expand Up @@ -179,7 +182,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1)
self.process_group, True, 1, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

Expand Down
17 changes: 10 additions & 7 deletions tests/test_shardformer/test_layer/test_linear_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_linear_1d_col(lazy_init: bool, seq_parallel: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel)
seq_parallel=seq_parallel,
overlap=overlap)

# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
Expand Down Expand Up @@ -111,7 +112,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)


def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()

linear_1 = nn.Linear(32, 128).cuda()
Expand All @@ -123,7 +124,8 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel)
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
Expand Down Expand Up @@ -166,10 +168,11 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):

@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel):
check_linear_1d_col(lazy_init, seq_parallel)
@parameterize('overlap', [False, True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)


def check_dist_linear(rank, world_size, port):
Expand Down