From c23bb45bcd5b625eafc74cf7971f1c4251608ce6 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 27 Jul 2023 10:47:28 +0800 Subject: [PATCH 01/10] [sequence parallel] add sequence parallel linear col/row support (#4336) * add sequence parallel linear col/row support * add annotation * add annotation --- colossalai/shardformer/layer/_operation.py | 156 ++++++++++++++++-- colossalai/shardformer/layer/linear.py | 17 +- .../test_layer/test_linear_1d.py | 72 +++++--- 3 files changed, 206 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..fcd43bd857a4 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -141,6 +141,105 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None +class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) + + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + new_shape = list(input_parallel.shape) + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + # print(output, output.shape) + return output, grad_weight, grad_bias, None, None, None + + +class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.dim = dim + ctx.process_group = process_group + + # do reduce-scatter + new_shape = list(input_.shape) + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) + dist.reduce_scatter(output, input_list, group=process_group) + + return output + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + return _gather(grad_output, dim, process_group), None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -200,6 +299,26 @@ def backward(ctx, grad_output): return _reduce(grad_output, ctx.process_group), None +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -235,6 +354,7 @@ def _gather(input_, dim=-1, process_group=None): return input_ # all gather + input_ = input_.contiguous() rank = dist.get_rank(process_group) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ @@ -246,24 +366,25 @@ def _gather(input_, dim=-1, process_group=None): return output -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. +def _reduce_scatter(intput_, dim=1, process_group=None): + """ Do reduce-scatter operation. Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + dim (int): The dimension to perform reduce-scatter. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. """ + world_size = dist.get_world_size(process_group) + if world_size == 1: + return intput_ - @staticmethod - def forward(ctx, input_, dim, process_group): - ctx.process_group = process_group - ctx.dim = dim - return _gather(input_, dim, process_group) + # reduce-scatter + new_shape = list(intput_.shape) + new_shape[dim] = new_shape[dim] // world_size + output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device) + dist.reduce_scatter(output, intput_, group=process_group) - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return output def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): @@ -274,6 +395,15 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_allreduce, dim): + return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_allreduce, dim) + + +def linear_reducescatter_forward_gather_backward(input_, process_group, dim): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index bb36854bd772..fc758cf8f2cf 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,8 @@ from ._operation import ( gather_forward_split_backward, + linear_gather_forward_reducescatter_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, reduce_forward, split_forward_gather_backward, @@ -69,6 +71,7 @@ def __init__(self, device: torch.device = None, process_group: ProcessGroup = None, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, weight: Optional[Parameter] = None, bias_: Optional[Parameter] = None, @@ -80,6 +83,7 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group @@ -172,7 +176,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + if self.seq_parallel: + output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, True, 1) + else: + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) if self.gather_output: # All-gather across the partitions. @@ -213,6 +221,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -230,6 +239,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -357,7 +367,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index aa75879e0313..182c5eb36392 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -12,13 +12,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) -def check_linear_1d_col(lazy_init: bool): +def check_linear_1d_col(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True) + linear_col = Linear1D_Col.from_native_module(linear_copy, + process_group=None, + gather_output=True, + seq_parallel=seq_parallel) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -35,10 +37,11 @@ def check_linear_1d_col(lazy_init: bool): linear_col.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) out = linear(x_for_unshard) @@ -56,17 +59,21 @@ def check_linear_1d_col(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_1d_row(lazy_init: bool): +def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = Linear1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -77,7 +84,8 @@ def check_linear_1d_row(lazy_init: bool): linear_row.load_state_dict(linear.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) x_for_shard = x.expand_as(x.clone()) @@ -86,7 +94,8 @@ def check_linear_1d_row(lazy_init: bool): # run forward out = linear(x_for_unshard) gather_out = linear_row(x_for_shard) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -102,8 +111,7 @@ def check_linear_1d_row(lazy_init: bool): assert_close(x_for_unshard.grad, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_col_plus_row(lazy_init: bool): +def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear_1 = nn.Linear(32, 128).cuda() @@ -112,8 +120,14 @@ def check_linear_col_plus_row(lazy_init: bool): with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True) + linear_col = Linear1D_Col.from_native_module(linear_1_copy, + process_group=None, + gather_output=False, + seq_parallel=seq_parallel) + linear_row = Linear1D_Row.from_native_module(linear_2_copy, + process_group=None, + parallel_input=True, + seq_parallel=seq_parallel) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -121,16 +135,18 @@ def check_linear_col_plus_row(lazy_init: bool): linear_row.load_state_dict(linear_2.state_dict()) # check computation correctness - x = torch.rand(4, 32).cuda() + # [batch_size, seq_len, hidden_size] + x = torch.rand(2, 4, 32).cuda() x_for_unshard = x.expand_as(x.clone()) x_for_unshard.requires_grad_(True) - x_for_shard = x.expand_as(x.clone()) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] x_for_shard.requires_grad_(True) # run forward unshard_out = linear_2(linear_1(x_for_unshard)) shard_out = linear_row(linear_col(x_for_shard)) - assert_close(unshard_out, shard_out) + target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, shard_out) # check backward correctness unshard_out.sum().backward() @@ -143,19 +159,27 @@ def check_linear_col_plus_row(lazy_init: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - assert_close(x_for_unshard.grad, x_for_shard.grad) + target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( + x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_unshard_gard, x_for_shard.grad) + + +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +def run_dist_linear_test(lazy_init, seq_parallel): + check_linear_1d_col(lazy_init, seq_parallel) + check_linear_1d_row(lazy_init, seq_parallel) + check_linear_col_plus_row(lazy_init, seq_parallel) -def run_dist(rank, world_size, port): +def check_dist_linear(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_linear_1d_col() - check_linear_1d_row() - check_linear_col_plus_row() + run_dist_linear_test() @rerun_if_address_is_in_use() def test_linear(): - spawn(run_dist, nprocs=2) + spawn(check_dist_linear, nprocs=2) if __name__ == '__main__': From bc4adde6460233b3c9bfba77549847fec9fd2bbd Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 4 Aug 2023 10:32:40 +0800 Subject: [PATCH 02/10] add support for gpt2 fused qkv linear layer --- colossalai/shardformer/layer/_operation.py | 79 ++++++++++++++++++- .../shardformer/layer/qkv_fused_linear.py | 26 ++++-- tests/kit/model_zoo/__init__.py | 2 +- .../test_gpt2_qkv_fused_linear_1d.py | 34 +++++--- 4 files changed, 120 insertions(+), 21 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index fcd43bd857a4..cfebdafd6e29 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch.distributed as dist import torch.nn.functional as F @@ -240,6 +242,74 @@ def backward(ctx, grad_output): return _gather(grad_output, dim, process_group), None, None +class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): + """ + This class is designed for matmul operation with gather forward and reduce-scatter backward. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_reduce_scatter = async_grad_reduce_scatter + ctx.dim = dim + + input_parallel = _gather(input_, dim, process_group) + + output = torch.matmul(input_parallel, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensors + use_bias = ctx.use_bias + dim = ctx.dim + process_group = ctx.process_group + + # TODO: overlap SP input with gradient computation + input_parallel = _gather(input_, dim, process_group) + + total_input = input_parallel + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + # TODO: overlap SP input with gradient computation + if ctx.async_grad_reduce_scatter: + # Asynchronous reduce-scatter + new_shape = list(input_parallel.shape) + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + input_list = [ + item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(new_shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous() + handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # reduce-scatter scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_reduce_scatter: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. @@ -395,15 +465,20 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) -def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_allreduce, dim): +def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_allreduce, dim) + async_grad_reduce_scatter, dim) def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) +def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim): + return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, + async_grad_reduce_scatter, dim) + + def gather_forward_split_backward(input_, dim, process_group): return _GatherForwardSplitBackward.apply(input_, dim, process_group) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 42417f8bcc43..91402bcdd91d 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -25,7 +25,9 @@ from ._operation import ( gather_forward_split_backward, + linear_reducescatter_forward_gather_backward, linear_with_async_comm, + matmul_gather_forward_reducescatter_backward, matmul_with_async_comm, reduce_backward, reduce_forward, @@ -173,6 +175,7 @@ def __init__(self, process_group: ProcessGroup = None, async_communication: bool = False, gather_output: bool = False, + seq_parallel: bool = False, skip_bias_add: bool = False, n_fused: int = 3, weight: Optional[Parameter] = None, @@ -185,6 +188,7 @@ def __init__(self, self.in_features = in_features self.out_features = out_features self.gather_output = gather_output + self.seq_parallel = seq_parallel self.skip_bias_add = skip_bias_add self.device = device self.n_fused = n_fused @@ -288,15 +292,20 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_backward(input_, self.process_group) - # input_parallel = input_ # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + if self.seq_parallel: + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, + self.process_group, self.async_communication, + 1) + else: + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -338,6 +347,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + seq_parallel: bool = True, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, @@ -355,6 +365,7 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.seq_parallel = seq_parallel self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: @@ -483,7 +494,10 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = torch.matmul(input_, self.weight) - output = reduce_forward(output_parallel, self.process_group) + if self.seq_parallel: + output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + else: + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 466a2a558829..82a61626b6c2 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ -from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers +from . import diffusers, timm, torchaudio, torchvision, transformers from .registry import model_zoo __all__ = ['model_zoo'] diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index b45cd172c3ca..ae6a1dc90dc5 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_col(lazy_init: bool): +def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: @@ -62,6 +61,7 @@ def check_linear_conv_1d_col(lazy_init: bool): linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True, + seq_parallel=seq_parallel, n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) @@ -76,10 +76,11 @@ def check_linear_conv_1d_col(lazy_init: bool): linear.load_state_dict(linear_conv_col.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) - gather_out = linear_conv_col(x) - assert_close(rearrange(out, 1), gather_out) + x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()] + gather_out = linear_conv_col(x_for_shard) + assert_close(rearrange(out, -1), gather_out) # check backward correctness out.sum().backward() @@ -89,14 +90,16 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) -def check_linear_conv_1d_row(lazy_init: bool): +def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, + process_group=None, + parallel_input=False, + seq_parallel=seq_parallel) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -109,10 +112,11 @@ def check_linear_conv_1d_row(lazy_init: bool): linear.load_state_dict(linear_row.state_dict()) # check computation correctness - x = torch.rand(4, 48).cuda() + x = torch.rand(1, 4, 48).cuda() out = linear(x) gather_out = linear_row(x) - assert_close(out, gather_out) + target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()] + assert_close(target_out, gather_out) # check backward correctness out.sum().backward() @@ -123,12 +127,18 @@ def check_linear_conv_1d_row(lazy_init: bool): assert_close(target_grad, linear_row.weight.grad) +@parameterize('lazy_init', [False, True]) +@parameterize('seq_parallel', [False, True]) +def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool): + check_linear_conv_1d_col(lazy_init, seq_parallel) + check_linear_conv_1d_row(lazy_init, seq_parallel) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # test for linear conv - check_linear_conv_1d_col() - check_linear_conv_1d_row() + check_gpt2_qkv_fused_linear_1d() @rerun_if_address_is_in_use() From e0b5a329ac88bc647e802ab5599a7aeff8b43f03 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 4 Aug 2023 17:46:19 +0800 Subject: [PATCH 03/10] support sequence parallel in GPT2 --- .../booster/plugin/hybrid_parallel_plugin.py | 5 +- colossalai/shardformer/layer/_operation.py | 2 +- .../shardformer/layer/qkv_fused_linear.py | 5 +- colossalai/shardformer/modeling/gpt2_seq.py | 221 ++++++++++++++++++ .../shardformer/policies/base_policy.py | 26 ++- colossalai/shardformer/policies/gpt2.py | 96 ++++---- colossalai/shardformer/shard/shard_config.py | 1 + tests/kit/model_zoo/transformers/gpt.py | 2 +- tests/test_shardformer/test_model/_utils.py | 20 +- .../test_model/test_shard_gpt2.py | 8 +- 10 files changed, 331 insertions(+), 55 deletions(-) create mode 100644 colossalai/shardformer/modeling/gpt2_seq.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a22bdb7199bb..0a6d3ba0021e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -147,6 +147,7 @@ def __init__( zero_stage: int = 0, cpu_offload: bool = False, enable_fused_normalization: bool = False, + enable_sequence_parallelism: bool = False, num_microbatches: Optional[int] = None, initial_scale: float = 2**16, min_scale: float = 1, @@ -170,6 +171,7 @@ def __init__( self.zero_stage = zero_stage self.cpu_offload = cpu_offload self.enable_fused_normalization = enable_fused_normalization + self.enable_sequence_parallelism = enable_sequence_parallelism self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -184,7 +186,8 @@ def __init__( self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, - enable_fused_normalization=self.enable_fused_normalization) + enable_fused_normalization=self.enable_fused_normalization, + enable_sequence_parallelism=enable_sequence_parallelism) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index cfebdafd6e29..b6fa698a3131 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -307,7 +307,7 @@ def backward(ctx, grad_output): if ctx.async_grad_reduce_scatter: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return output, grad_weight, grad_bias, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 91402bcdd91d..09662efeacd8 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -299,8 +299,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel: input_parallel = input_ output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, self.async_communication, - 1) + self.process_group, True, 1) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) @@ -347,7 +346,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, - seq_parallel: bool = True, + seq_parallel: bool = False, parallel_input: bool = True, skip_bias_add: bool = False, weight: Optional[Parameter] = None, diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py new file mode 100644 index 000000000000..2538412a1098 --- /dev/null +++ b/colossalai/shardformer/modeling/gpt2_seq.py @@ -0,0 +1,221 @@ +# this code is modified from transformers.models.gpt2.modeling_gpt2 +# https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L670 + +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import logging + +from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward +from colossalai.shardformer.shard import ShardConfig + +logger = logging.get_logger(__name__) + + +def seq_forward_fn(shard_config: ShardConfig): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + + # split the input tensor along sequence dimension + # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + # When sequence parallelism done, gather the output tensor in forward and split it in backward + hidden_states = gather_forward_split_backward(hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 69493bfb6007..7022a1cfd7a2 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -11,17 +11,12 @@ from colossalai.pipeline.stage_manager import PipelineStageManager +from ..layer.parallel_module import ParallelModule from ..shard.shard_config import ShardConfig __all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] -class ParallelModule(): - - def __init__(self): - pass - - @dataclass class SubModuleReplacementDescription: r""" @@ -231,3 +226,22 @@ def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]: end_idx = num_layers_per_stage_accumulated[stage + 1] return [start_idx, end_idx] + + def append_seq_parallel_to_policy( + self, + suffix_list: List[str], + module_policy_description: ModulePolicyDescription, + ): + r""" + Append the sequence parallel policy to the policy for the given key. + + Args: + suffix_list (List[str]): the suffix list of the module to be parallelized + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + """ + + for sub_description in module_policy_description.sub_module_replacement: + if (sub_description.suffix in suffix_list): + if sub_description.kwargs is None: + sub_description.kwargs = {} + sub_description.kwargs["seq_parallel"] = True diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 6d734b063036..72d7143a2c74 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,6 +6,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.gpt2 import GPT2PipelineForwards +from ..modeling.gpt2_seq import seq_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -48,47 +49,57 @@ def module_policy(self): target_module=col_nn.DropoutForParallelInput, ), ]) + if self.shard_config.enable_sequence_parallelism: + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ], + method_replacement={"forward": seq_forward_fn(self.shard_config)}) + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: @@ -96,8 +107,8 @@ def module_policy(self): suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - policy=policy, - target_key=GPT2Model) + policy=policy, + target_key=GPT2Model) self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( @@ -112,8 +123,13 @@ def module_policy(self): target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Block) + + if self.shard_config.enable_sequence_parallelism: + suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"] + self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block]) + return policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 75fad4eb7431..eb36946a8f1e 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -26,6 +26,7 @@ class ShardConfig: enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False + enable_sequence_parallelism: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 73c210221e61..fbe45b1f3b44 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -67,7 +67,7 @@ def data_gen_for_sequence_classification(): config = transformers.GPT2Config(n_layer=2, n_head=4, - vocab_size=50258, + vocab_size=50260, attn_pdrop=0, embd_pdrop=0, resid_pdrop=0, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 46b262d0a8cd..592538cc5702 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -1,4 +1,5 @@ import copy +import math from contextlib import nullcontext from typing import Any, Callable, Dict, List, Optional @@ -19,7 +20,11 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): +def build_model(model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_sequence_parallelism=False, + use_lazy_init: bool = False): ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -29,7 +34,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle ctx.materialize(org_model) # shard model shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism) + enable_tensor_parallelism=enable_tensor_parallelism, + enable_sequence_parallelism=enable_sequence_parallelism) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() @@ -127,6 +133,16 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() + + if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: + seq_len = data['input_ids'].shape[1] + lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) + times = lcm // seq_len + input_shape = data['input_ids'].shape + for k, v in data.items(): + if v.shape == input_shape: + data[k] = v.repeat(1, times) + sharded_model.train() if booster.plugin.stage_manager is not None: data = { diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index cebb40bd16fe..9e5dc3802be6 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_gpt2 = sharded_model.unwrap().transformer col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + row_layer_for_check = ['h[0].mlp.c_proj', 'wte'] # check grad if stage_manager is None or stage_manager.is_first_stage(): @@ -85,6 +85,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, 'pp_size': 1, 'enable_fused_normalization': True, 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': True, + 'enable_sequence_parallelism': True }]) @clear_cache_before_run() def run_gpt2_test(test_config): From 314513f3ac69194a2a79e153bb6428059dfe19dd Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 10:35:49 +0800 Subject: [PATCH 04/10] add docstring and note --- colossalai/shardformer/layer/linear.py | 3 +++ colossalai/shardformer/layer/qkv_fused_linear.py | 2 ++ tests/kit/model_zoo/__init__.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index fc758cf8f2cf..0305529addb9 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -52,6 +52,7 @@ class Linear1D_Col(ParallelModule): gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (`typing.Callable`): @@ -203,6 +204,8 @@ class Linear1D_Row(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 09662efeacd8..54131dc7cd20 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -152,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): device (`torch.device`): The device of parameters, defaults to None. n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is :math:`Y_i = XA_i`, defaults to False @@ -329,6 +330,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): dtype (`torch.dtype`): The dtype of parameters, defaults to None. parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False. which is preserved for kernel fusion, defaults to False weight_initializer (:class:`typing.Callable`, optional): The initializer of weight, defaults to kaiming uniform initializer. diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 82a61626b6c2..466a2a558829 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ -from . import diffusers, timm, torchaudio, torchvision, transformers +from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers from .registry import model_zoo __all__ = ['model_zoo'] From 4e902b36e8875d0a1ec787b609683f99192d174c Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 10:53:28 +0800 Subject: [PATCH 05/10] add requirments --- requirements/requirements-test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 6f8a72e3962f..89e998f9812a 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,3 +17,5 @@ git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece datasets +ninja +flash-attn From 91e5389c0e6370734a34db01ba598d0ec03fefe4 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 11:13:48 +0800 Subject: [PATCH 06/10] remove unused flash-attb --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 89e998f9812a..fa797f26a4ca 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -13,7 +13,7 @@ torchrec==0.2.0 contexttimer einops triton==2.0.0.dev20221202 -git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn +# git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece datasets From 9616af927df139dd527b2768d1c3c3c33c982596 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 12:26:02 +0800 Subject: [PATCH 07/10] modify flash attn test --- tests/test_utils/test_flash_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 7a28b0157384..925cc7f92b30 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -24,7 +24,7 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -45,7 +45,7 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -69,7 +69,7 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD From 6660516005191c0e4f034577ef91a00d01fe955c Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 12:26:48 +0800 Subject: [PATCH 08/10] modify flash attn setting --- tests/test_utils/test_flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 925cc7f92b30..3e21171c8f1f 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -89,7 +89,7 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) +@parameterize('B, S, T, H, D_HEAD', [(1, 128, 8, 4, 16)]) def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD From a6094486007f3cb3a744eeaf17b2e8f16b4d2b1f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 14:34:56 +0800 Subject: [PATCH 09/10] modify flash attn code --- .../kernel/cuda_native/flash_attention.py | 26 ++++++++++++++----- .../kernel/cuda_native/scaled_softmax.py | 5 ++-- tests/test_utils/test_flash_attention.py | 26 ++++++++++++------- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py index 3db7374509a0..91bef0908dbb 100644 --- a/colossalai/kernel/cuda_native/flash_attention.py +++ b/colossalai/kernel/cuda_native/flash_attention.py @@ -6,6 +6,7 @@ import math import os import subprocess +import warnings import torch @@ -14,7 +15,7 @@ HAS_MEM_EFF_ATTN = True except ImportError: HAS_MEM_EFF_ATTN = False - print('please install xformers from https://github.com/facebookresearch/xformers') + warnings.warn(f'please install xformers from https://github.com/facebookresearch/xformers') if HAS_MEM_EFF_ATTN: @@ -22,7 +23,12 @@ from einops import rearrange from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp - from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) from .scaled_softmax import AttnMaskType @@ -86,11 +92,14 @@ def backward(ctx, grad_output): class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): super().__init__() assert embed_dim % num_heads == 0, \ f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - self.scale = 1 / math.sqrt(embed_dim // num_heads) + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) self.dropout = dropout @staticmethod @@ -116,7 +125,7 @@ def forward(self, bias: Optional[torch.Tensor] = None): batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] attn_bias = None - if attn_mask_type == AttnMaskType.padding: # bert style + if attn_mask_type and attn_mask_type.value % 2 == 1: # bert style assert attn_mask is not None, \ f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." assert attn_mask.dim() == 2, \ @@ -134,7 +143,10 @@ def forward(self, if batch_size > 1: query = rearrange(query, "b s ... -> c (b s) ...", c=1) key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2) - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + if attn_mask_type == AttnMaskType.padding: + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + elif attn_mask_type == AttnMaskType.paddedcausal: + attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen) elif attn_mask_type == AttnMaskType.causal: # gpt style attn_bias = LowerTriangularMask() @@ -146,7 +158,7 @@ def forward(self, out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale) - if attn_mask_type == AttnMaskType.padding and batch_size > 1: + if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1: out = self.repad(out, q_indices, batch_size, tgt_len) out = rearrange(out, 'b s h d -> b s (h d)') diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 24e458bb3ea5..41cd4b20faa1 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -19,6 +19,7 @@ class AttnMaskType(enum.Enum): padding = 1 causal = 2 + paddedcausal = 3 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: if attn_batches % batch_per_block == 0: return True else: @@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask): b, np, sq, sk = input.size() scale = self.scale if self.scale is not None else 1.0 - if self.attn_mask_type == AttnMaskType.causal: + if self.attn_mask_type.value > 1: assert sq == sk, "causal mask is only for self attention" # input is 3D tensor (attn_batches, sq, sk) diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 3e21171c8f1f..938f85b410e1 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -24,8 +24,9 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) -def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(1, 128, 4, 16)]) +def test_attention_gpt(proj_shape, dtype=torch.float16): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -35,7 +36,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): qkv = c_attn(x) q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H) - y = attn(q, k, v, attn_mask_type=AttnMaskType.causal) + + mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)] + mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True) + + y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal) assert list(y.shape) == [B, S, D] @@ -45,8 +50,9 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) -def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(1, 128, 4, 16)]) +def test_attention_bert(proj_shape, dtype=torch.float16): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -69,8 +75,9 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, H, D_HEAD', [(1, 128, 4, 16)]) -def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(6, 128, 4, 16)]) +def test_attention_no_mask(proj_shape, dtype=torch.float16): + (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda") @@ -89,8 +96,9 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") @clear_cache_before_run() -@parameterize('B, S, T, H, D_HEAD', [(1, 128, 8, 4, 16)]) -def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): +@parameterize('proj_shape', [(6, 128, 256, 4, 16)]) +def test_cross_attention(proj_shape, dtype=torch.float16): + (B, S, T, H, D_HEAD) = proj_shape D = H * D_HEAD q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda") From ce13b2e100af3d4129d97d5857c812e3d4da6ef7 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 7 Aug 2023 17:03:27 +0800 Subject: [PATCH 10/10] add assert before divide, rename forward function --- colossalai/shardformer/layer/_operation.py | 8 ++++++++ colossalai/shardformer/modeling/gpt2_seq.py | 2 +- colossalai/shardformer/policies/gpt2.py | 4 ++-- tests/test_shardformer/test_model/test_shard_gpt2.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index b6fa698a3131..fb7f4ab06760 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -191,6 +191,8 @@ def backward(ctx, grad_output): if ctx.async_grad_reduce_scatter: # Asynchronous reduce-scatter new_shape = list(input_parallel.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) @@ -227,6 +229,8 @@ def forward(ctx, input_, process_group, dim): # do reduce-scatter new_shape = list(input_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) @@ -291,6 +295,8 @@ def backward(ctx, grad_output): if ctx.async_grad_reduce_scatter: # Asynchronous reduce-scatter new_shape = list(input_parallel.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) @@ -450,6 +456,8 @@ def _reduce_scatter(intput_, dim=1, process_group=None): # reduce-scatter new_shape = list(intput_.shape) + assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ + f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' new_shape[dim] = new_shape[dim] // world_size output = torch.empty(new_shape, dtype=intput_.dtype, device=intput_.device) dist.reduce_scatter(output, intput_, group=process_group) diff --git a/colossalai/shardformer/modeling/gpt2_seq.py b/colossalai/shardformer/modeling/gpt2_seq.py index 2538412a1098..07e2b84b6577 100644 --- a/colossalai/shardformer/modeling/gpt2_seq.py +++ b/colossalai/shardformer/modeling/gpt2_seq.py @@ -14,7 +14,7 @@ logger = logging.get_logger(__name__) -def seq_forward_fn(shard_config: ShardConfig): +def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 72d7143a2c74..2055b75c88ec 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -6,7 +6,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.gpt2 import GPT2PipelineForwards -from ..modeling.gpt2_seq import seq_forward_fn +from ..modeling.gpt2_seq import gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -57,7 +57,7 @@ def module_policy(self): target_module=col_nn.VocabParallelEmbedding1D, ), ], - method_replacement={"forward": seq_forward_fn(self.shard_config)}) + method_replacement={"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}) policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 9e5dc3802be6..c77729076dda 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_gpt2 = sharded_model.unwrap().transformer col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['h[0].mlp.c_proj', 'wte'] + row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] # check grad if stage_manager is None or stage_manager.is_first_stage():