diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f02dea8cdabd..7da6d86d95d7 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -169,47 +169,57 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None -def _AllgatherLinear(input_, weight, process_group): +def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only support one single tensor as output group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - weight_shape = weight.shape - - output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] # initialization of ring communication - input_shape[1] recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 - recv_tensor = input_.clone() - send_tensor = input_.clone() - - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([send_op, recv_op]) + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = v.clone() + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + tmp_tensor = send_tensors[k] + send_tensors[k] = recv_tensors[k] + recv_tensors[k] = tmp_tensor + + 
output_tensors = []
+
+    handles = communicate_step()
 
     # first round: special case, retrive from local tensor
-    output_tensors[0] = F.linear(input_, weight)
+    output_tensors.append(func(**input_to_gather, **input_local))
 
     for i in range(group_size - 2):
         for handle in handles:
             handle.wait()
 
-        tmp_tensor = send_tensor
-        send_tensor = recv_tensor
-        recv_tensor = tmp_tensor
+        switch_step()
 
-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()
 
         # actual computation
-        output_tensors[i + 1] = F.linear(send_tensor, weight)
+        output_tensors.append(func(**send_tensors, **input_local))
 
     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
-    output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
 
-    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1)
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
 
 
 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -249,6 +259,41 @@ def backward(ctx, grad_output):
         return output, None, None
 
 
+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whether to overlap the all_gather op and gradient computation in backward.
+ + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -260,7 +305,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -268,11 +313,27 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - if bias is not None: - input_parallel = _gather(input_, dim, process_group) - output = F.linear(input_parallel, weight, bias) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['weight'] = weight + + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + 
input_local=input_local, + process_group=process_group, + ) + + if bias is not None: + output += bias else: - output = _AllgatherLinear(input_, weight, process_group) + input_parallel = _gather(input_, dim, process_group) + if bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) return output @@ -376,34 +437,43 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None -def _ReduceScatterLinear(input_, weight, process_group): +def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1): + # currently only support one single tensor as output group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - # initialization of ring communication - # communicate(e.g.): 0->1->2->3 - # compute(e.g.): 3->2->1->0 - input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1)) - input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] - input_tensors.reverse() recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() - # first round: special case, no reduce operation - output_tensor = F.linear(input_tensors[0], weight) + output_tensor = 
func(**input_tensors[0], **input_local) recv_tensor = output_tensor.clone() send_tensor = output_tensor.clone() - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrive from local tensor for i in range(group_size - 2): # actual computation - output_tensor = F.linear(input_tensors[i + 1], weight) + output_tensor = func(**input_tensors[i + 1], **input_local) for handle in handles: handle.wait() @@ -413,12 +483,10 @@ def _ReduceScatterLinear(input_, weight, process_group): send_tensor = output_tensor output_tensor = tmp_tensor - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + handles = communicate_step() # final round: special case, no need to send/recv again - output_tensor = F.linear(input_tensors[group_size - 1], weight) + output_tensor = func(**input_tensors[-1], **input_local) for handle in handles: handle.wait() output_tensor += recv_tensor @@ -436,27 +504,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, dim): + def forward(ctx, input_, weight, bias, process_group, dim, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim - if bias is not None: - partial_output = F.linear(input_, weight, bias) + + if ring is True: + input_to_reducescatter = {} + 
input_local = {} + input_to_reducescatter['input'] = input_ + input_local['weight'] = weight + + if bias is not None: + input_to_reducescatter['bias'] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) else: - return _ReduceScatterLinear(input_, weight, process_group) + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) - output_shape = list(partial_output.shape) - assert ( - output_shape[dim] % dist.get_world_size(process_group) == 0 - ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " - output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). 
" + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) - output_list = [ - item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() - dist.reduce_scatter(output, output_list, group=process_group) + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) return output @@ -484,7 +569,7 @@ def backward(ctx, grad_output): grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None class _ReduceScatterForwardGatherBackward(torch.autograd.Function): @@ -533,7 +618,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -541,9 +626,24 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['other'] = weight + + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim + ) + + else: + input_parallel = _gather(input_, dim, 
process_group) - output = torch.matmul(input_parallel, weight) + output = torch.matmul(input_parallel, weight) if bias is not None: output = output + bias @@ -624,7 +724,7 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -877,10 +977,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) @@ -892,15 +992,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim): return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim) +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, 
weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 20a9f0328cfc..a773783b9f19 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -207,7 +207,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) if self.gather_output: @@ -429,7 +429,9 @@ def forward(self, input_: Tensor) -> Tensor: output = linear_reducescatter_forward_gather_backward( input_, self.weight, + process_group=self.process_group, dim=self.seq_parallel_dim, + ring=True, ) if not self.skip_bias_add: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6c5fb41494f0..a5d75db8a740 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -323,6 +323,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap ) + elif self.seq_parallel_mode == "2": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. 
@@ -528,10 +533,14 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = torch.matmul(input_, self.weight) if self.seq_parallel_mode is None: + output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + elif self.seq_parallel_mode == "2": + output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: @@ -702,7 +711,6 @@ def from_native_module( # process_group=process_group, # is_transposed=False) # linear_1d.bias.data.copy_(sharded_bias.data) - print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 13ac4aa9fa1e..e9fed06d7295 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -922,7 +922,12 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + if sp_mode in ["2"]: + input_ids = _gather(input_ids, 1, sp_group) + inputs_embeds = self.wte(input_ids) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + else: + inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2439de1befa0..a33bf80295c4 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -24,6 +24,8 @@ gather_forward_split_backward, 
reducescatter_forward_gather_backward, split_forward_gather_backward, + gather_forward_reducescatter_backward, + reducescatter_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -433,7 +435,6 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} - def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb @@ -498,7 +499,7 @@ def forward( if llama_version == 2: key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) @@ -519,8 +520,8 @@ def forward( ) flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - hidden_size = self.hidden_size // sp_size if sp_mode == "3" else self.hidden_size - + hidden_size = self.hidden_size // sp_size if sp_mode == '3' else self.hidden_size + attention = ColoAttention(embed_dim=hidden_size, num_heads=self.num_heads) attn_output = attention( query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type @@ -630,7 +631,7 @@ def forward( # TODO (linshengjie) Block attention with ring #### block_wise = False - seq_len = query_states[2] + seq_len = query_states.shape[2] seq_block = 1024 if block_wise and seq_len > seq_block: assert query_states.shape[2] % seq_block == 0 @@ -780,6 +781,7 @@ def _prepare_decoder_attention_mask_partial( return combined_attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -906,6 +908,7 @@ def custom_forward(*inputs): position_ids, ) else: + layer_outputs = decoder_layer( 
hidden_states, attention_mask=attention_mask, diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0926b0ccf27a..023e7e63c950 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -49,7 +49,7 @@ def module_policy(self): sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode in ["1", "2"] if sp_mode == "2": pass