diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index fb7f4ab06760..dbb50caf401c 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -149,16 +149,18 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     Args:
         input_ (`torch.Tensor`): The input tensor from sequence parallel region.
         process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op with gradient computation in backward.

     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
         ctx.save_for_backward(input_, weight)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
+        ctx.overlap = overlap

         input_parallel = _gather(input_, dim, process_group)

@@ -175,42 +177,79 @@ def backward(ctx, grad_output):
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
+        overlap = ctx.overlap
+
+        if not overlap:
+            # TODO: overlap SP input with gradient computation
+            input_parallel = _gather(input_, dim, process_group)
+
+            total_input = input_parallel
+            grad_input = grad_output.matmul(weight)
+            grad_output = grad_output.contiguous()
+            # Convert the tensor shapes to 2D for execution compatibility
+            if len(grad_output.shape) > 2:
+                grad_output = grad_output.view(-1, grad_output.shape[-1])
+                total_input = total_input.view(-1, total_input.shape[-1])
+
+            # TODO: overlap SP input with gradient computation
+            if ctx.async_grad_reduce_scatter:
+                # Asynchronous reduce-scatter
+                input_list = [
+                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+                ]
+                output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+                                     device=input_parallel.device).contiguous()
+                handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+                # Delay the start of weight gradient computation shortly (3us) to have
+                # reduce-scatter scheduled first and have GPU resources allocated
+                _ = torch.empty(1, device=grad_output.device) + 1
+
+            grad_weight = grad_output.t().matmul(total_input)
+            grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+            if ctx.async_grad_reduce_scatter:
+                handle.wait()

-        # TODO: overlap SP input with gradient computation
-        input_parallel = _gather(input_, dim, process_group)
-
-        total_input = input_parallel
-        grad_input = grad_output.matmul(weight)
-        grad_output = grad_output.contiguous()
-        # Convert the tensor shapes to 2D for execution compatibility
-        if len(grad_output.shape) > 2:
-            grad_output = grad_output.view(-1, grad_output.shape[-1])
-            total_input = total_input.view(-1, total_input.shape[-1])
-
-        # TODO: overlap SP input with gradient computation
-        if ctx.async_grad_reduce_scatter:
-            # Asynchronous reduce-scatter
-            new_shape = list(input_parallel.shape)
-            assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
-                f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
-            new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
-            input_list = [
-                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
-            ]
-            output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
-            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce-scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
-
-        grad_weight = grad_output.t().matmul(total_input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-        if ctx.async_grad_reduce_scatter:
-            handle.wait()
-
-        # print(output, output.shape)
-        return output, grad_weight, grad_bias, None, None, None
+        else:
+            # create a new stream for the gradient computation
+            calculate_stream = torch.cuda.Stream()
+
+            # launch the input all-gather on the default stream
+            input_ = input_.contiguous()
+            world_size = dist.get_world_size(process_group)
+            rank = dist.get_rank(process_group)
+            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+            tensor_list[rank] = input_
+            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+
+            # compute gradients on calculate_stream while the all-gather is in flight
+            with torch.cuda.stream(calculate_stream):
+                # compute the input gradient
+                grad_input = grad_output.matmul(weight)
+                grad_output = grad_output.contiguous()
+                # Convert the tensor shapes to 2D for execution compatibility
+                if len(grad_output.shape) > 2:
+                    grad_output = grad_output.view(-1, grad_output.shape[-1])
+                grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+                # prepare the chunks of grad_input for the reduce-scatter
+                input_list = [
+                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+                ]
+                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+
+            torch.cuda.synchronize()
+
+            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+            with torch.cuda.stream(calculate_stream):
+                input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+                if len(input_parallel.shape) > 2:
+                    input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+                grad_weight = grad_output.t().matmul(input_parallel)
+
+            torch.cuda.synchronize()
+
+            return output, grad_weight, grad_bias, None, None, None, None


 class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -294,14 +333,10 @@ def backward(ctx, grad_output):
         # TODO: overlap SP input with gradient computation
         if ctx.async_grad_reduce_scatter:
             # Asynchronous reduce-scatter
-            new_shape = list(input_parallel.shape)
-            assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
-                f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
-            new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
             input_list = [
                 item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
             ]
-            output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
+            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
             # Delay the start of weight gradient computation shortly (3us) to have
             # reduce-scatter scheduled first and have GPU resources allocated
@@ -442,7 +477,7 @@ def _gather(input_, dim=-1, process_group=None):
     return output


-def _reduce_scatter(intput_, dim=1, process_group=None):
+def _reduce_scatter(input_, dim=1, process_group=None):
     """
     Do reduce-scatter operation.
     Args:
@@ -452,15 +487,15 @@
     """
     world_size = dist.get_world_size(process_group)
     if world_size == 1:
-        return intput_
+        return input_

     # reduce-scatter
-    new_shape = list(intput_.shape)
+    new_shape = list(input_.shape)
     assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
         f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
     new_shape[dim] = new_shape[dim] // world_size
-    output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device)
-    dist.reduce_scatter(output, intput_, group=process_group)
+    output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
+    dist.reduce_scatter(output, input_, group=process_group)
     return output


@@ -473,9 +508,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
     return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)


-def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim):
+def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+                                                 overlap):
     return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
-                                                               async_grad_reduce_scatter, dim)
+                                                               async_grad_reduce_scatter, dim, overlap)


 def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 0305529addb9..6b3b55987c6c 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -53,6 +53,7 @@ class Linear1D_Col(ParallelModule):
             to all GPUs, otherwise, every GPU will have its output
             which is :math:`Y_i = XA_i`, defaults to False
         seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
         weight_initializer (`typing.Callable`):
             The initializer of weight, defaults to kaiming uniform initializer.
@@ -73,6 +74,7 @@ def __init__(self,
                  process_group: ProcessGroup = None,
                  gather_output: bool = False,
                  seq_parallel: bool = False,
+                 overlap: bool = False,
                  skip_bias_add: bool = False,
                  weight: Optional[Parameter] = None,
                  bias_: Optional[Parameter] = None,
@@ -85,6 +87,7 @@ def __init__(self,
         self.out_features = out_features
         self.gather_output = gather_output
         self.seq_parallel = seq_parallel
+        self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
@@ -179,7 +182,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
         bias = self.bias if not self.skip_bias_add else None
         if self.seq_parallel:
             output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
-                                                                           self.process_group, True, 1)
+                                                                           self.process_group, True, 1, self.overlap)
         else:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)

diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index 182c5eb36392..3ad8f14b99e6 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -12,7 +12,7 @@
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
@@ -20,7 +20,8 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool):
     linear_col = Linear1D_Col.from_native_module(linear_copy,
                                                  process_group=None,
                                                  gather_output=True,
-                                                 seq_parallel=seq_parallel)
+                                                 seq_parallel=seq_parallel,
+                                                 overlap=overlap)

     # ensure that the parameters are distributed
     assert is_distributed_tensor(linear_col.weight)
@@ -111,7 +112,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear_1 = nn.Linear(32, 128).cuda()
@@ -123,7 +124,8 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):
     linear_col = Linear1D_Col.from_native_module(linear_1_copy,
                                                  process_group=None,
                                                  gather_output=False,
-                                                 seq_parallel=seq_parallel)
+                                                 seq_parallel=seq_parallel,
+                                                 overlap=overlap)
     linear_row = Linear1D_Row.from_native_module(linear_2_copy,
                                                  process_group=None,
                                                  parallel_input=True,
@@ -166,10 +168,11 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool):

 @parameterize('lazy_init', [False, True])
 @parameterize('seq_parallel', [False, True])
-def run_dist_linear_test(lazy_init, seq_parallel):
-    check_linear_1d_col(lazy_init, seq_parallel)
+@parameterize('overlap', [False, True])
+def run_dist_linear_test(lazy_init, seq_parallel, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel, overlap)
     check_linear_1d_row(lazy_init, seq_parallel)
-    check_linear_col_plus_row(lazy_init, seq_parallel)
+    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)


 def check_dist_linear(rank, world_size, port):
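
Usage note (illustrative, not part of the patch): a minimal sketch of driving the new `overlap` flag from user code, mirroring `check_linear_1d_col` above. The flag is only consulted on the sequence-parallel path (`seq_parallel=True`); otherwise forward dispatches to `linear_with_async_comm` and `overlap` is ignored. The sketch assumes the distributed environment is already initialized (e.g. via `colossalai.launch` or `torch.distributed.init_process_group`).

    import torch.nn as nn

    from colossalai.shardformer.layer import Linear1D_Col

    # shard a plain nn.Linear column-wise, with sequence parallelism enabled
    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear,
                                                 process_group=None,
                                                 gather_output=True,
                                                 seq_parallel=True,
                                                 overlap=True)

With `overlap=True`, the backward pass launches the input all-gather asynchronously on the default stream while `grad_input` and `grad_bias` are computed on a separate CUDA stream; `grad_weight` is then computed from the gathered activations while the reduce-scatter of `grad_input` is still in flight.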