diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..fcd43bd857a4 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -141,6 +141,105 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + new_shape = list(input_parallel.shape) + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + # print(output, output.shape) + return output, grad_weight, grad_bias, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +299,26 @@ def backward(ctx, grad_output): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,6 +354,7 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather + input_ = input_.contiguous() rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ @@ -246,24 +366,25 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(intput_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return intput_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(intput_.shape) + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device) + dist.reduce_scatter(output, intput_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +395,15 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_allreduce, dim): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_allreduce, dim) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 26ba5883c64f..14511b7451d6 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -18,6 +18,8 @@ from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -63,6 +65,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): @@ -72,6 +75,7 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -153,7 +157,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -194,6 +202,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), @@ -209,6 +218,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -326,7 +336,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index da3bdc1d78d3..ea5ea6966249 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -6,12 +6,15 @@ import colossalai from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.tensor.d_tensor import is_distributed_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_linear_1d_col(): +def check_linear_1d_col(seq_parallel): linear = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -26,10 +29,11 @@ def check_linear_1d_col(): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -47,18 +51,24 @@ def check_linear_1d_col(): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -def check_linear_1d_row(): +def check_linear_1d_row(seq_parallel): linear = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -67,7 +77,8 @@ def check_linear_1d_row(): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -83,23 +94,31 @@ def check_linear_1d_row(): assert_close(x_for_unshard.grad, x_for_shard.grad) -def check_linear_col_plus_row(): +def check_linear_col_plus_row(seq_parallel): linear_1 = nn.Linear(32, 128).cuda() linear_2 = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel) + linear_row = Linear1D_Row.from_native_module(linear_2, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -112,19 +131,26 @@ def check_linear_col_plus_row(): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) + + +@parameterize('seq_parallel', [True, False]) +def run_dist_linear_test(seq_parallel): + check_linear_1d_col(seq_parallel) + check_linear_1d_row(seq_parallel) + check_linear_col_plus_row(seq_parallel) -def run_dist(rank, world_size, port): +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__':