From b213549122165889a0171c3f6489f059910f26c1 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Fri, 12 Jan 2024 15:51:39 +0800 Subject: [PATCH 1/8] [shardformer] add megatron sp to llama --- colossalai/shardformer/layer/_operation.py | 43 ++++- colossalai/shardformer/layer/linear.py | 14 +- colossalai/shardformer/modeling/llama.py | 182 ++++++++++++++++++++- colossalai/shardformer/policies/llama.py | 23 ++- 4 files changed, 244 insertions(+), 18 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 6a992c6f1acb..7f04d846b7c9 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -212,6 +212,41 @@ def _AllgatherLinear(input_, weight, process_group): return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) +class _GatherForwardReduceScatterBackward(torch.autograd.Function): + """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + + Args: + input_ (`torch.Tensor`): The input tensor from sequence parallel region. + process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication. + overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward. + + """ + + @staticmethod + def forward(ctx, input_, process_group, dim): + ctx.process_group = process_group + ctx.dim = dim + + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + dim = ctx.dim + process_group = ctx.process_group + + # do reduce-scatter + new_shape = list(grad_output.shape) + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) + grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)] + output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) + dist.reduce_scatter(output, grad_list, group=process_group) + + return output, None, None + + class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """Gather input from sequence parallel in forward and reduce-scatter gradient in backward @@ -389,7 +424,7 @@ def _ReduceScatterLinear(input_, weight, process_group): class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): - """Reduce-scatter input from sequence parallel in forward and gather gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -451,7 +486,7 @@ def backward(ctx, grad_output): class _ReduceScatterForwardGatherBackward(torch.autograd.Function): - """Gather input from sequence parallel in forward and reduce-scatter gradient in backward + """Reduce-scatter input from sequence parallel in forward and gather gradient in backward Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. 
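The pairing in `_GatherForwardReduceScatterBackward` — all-gather along `dim` in forward, reduce-scatter of the gradient in backward — can be sanity-checked without a real process group. The sketch below is illustrative only: Python lists stand in for ranks, `torch.cat`/`chunk` stand in for the collectives, and the per-rank "loss" tensors are hypothetical.

```python
import torch

# Single-process illustration of why an all-gather in forward needs a
# reduce-scatter in backward. Every rank consumes the *full* gathered tensor,
# so the gradient w.r.t. one rank's original shard is the sum over ranks of the
# matching slice of each rank's grad_output -- exactly what dist.reduce_scatter
# would deliver in the real implementation.
P, dim = 4, 1
shards = [torch.randn(2, 3, requires_grad=True) for _ in range(P)]
weights = [torch.randn(2, 3 * P) for _ in range(P)]      # a different "loss" weight per rank

gathered = torch.cat(shards, dim=dim)                     # forward: all-gather along `dim`
loss = sum((gathered * w).sum() for w in weights)         # every rank uses the full tensor
loss.backward()

for rank in range(P):
    # reduce-scatter of the per-rank grad_outputs: sum over ranks, then slice
    expected = sum(w.chunk(P, dim=dim)[rank] for w in weights)
    assert torch.allclose(shards[rank].grad, expected)
```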
@@ -812,6 +847,10 @@ def linear_gather_forward_reducescatter_backward( ) +def gather_forward_reducescatter_backward(input_, process_group, dim): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) + + def reducescatter_forward_gather_backward(input_, process_group, dim): return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index a908a862da88..3f90a6cb0a39 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,7 @@ from ._operation import ( gather_forward_split_backward, + reducescatter_forward_gather_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, linear_with_async_comm, @@ -200,7 +201,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - elif self.seq_parallel_mode in ["1", "2"]: + elif self.seq_parallel_mode == "1": + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) + elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap ) @@ -415,15 +418,8 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) - output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim - ) + output = linear_with_async_comm(input_, self.weight, None, None, False) elif self.seq_parallel_mode == "2": - # TODO how to maintain compatibility? 
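These two wrappers are the core of Megatron-style sequence parallelism: the all-reduce that normally follows a row-parallel linear is replaced by a reduce-scatter along the sequence dimension, and the identity before the column-parallel linear becomes an all-gather. A single-process sanity check of that equivalence (hypothetical sizes, a Python loop standing in for the process group):

```python
import torch

# Each "rank" holds a partial sum coming out of a row-parallel linear.
P, bsz, seq, hidden = 4, 2, 8, 16
partial_outputs = [torch.randn(bsz, seq, hidden) for _ in range(P)]

allreduced = sum(partial_outputs)                          # classic tensor parallelism: all-reduce
for rank in range(P):
    # Megatron SP: reduce-scatter along the sequence dim gives each rank its
    # sequence shard of the same reduced tensor, with 1/P the communication kept live.
    reduce_scattered = sum(p.chunk(P, dim=1)[rank] for p in partial_outputs)
    assert torch.allclose(reduce_scattered, allreduced.chunk(P, dim=1)[rank])
```

The same volume of data is reduced either way; the win is that everything between the reduce-scatter and the next all-gather (layernorm, dropout, residual adds) only ever touches a `1/P` sequence shard.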
- # output = reducescatter_forward_gather_backward( - # output_parallel, self.process_group, self.seq_parallel_dim - # ) output = linear_reducescatter_forward_gather_backward( input_, self.weight, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 433c541a9178..a759a497793f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -22,6 +22,8 @@ all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward, + gather_forward_reducescatter_backward, + reducescatter_forward_gather_backward, ) from colossalai.shardformer.shard import ShardConfig @@ -876,10 +878,10 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - if sp_mode == "2": + if sp_mode in ["1", "2"]: input_ids = _gather(input_ids, 1, None) inputs_embeds = self.embed_tokens(input_ids) - input_ids = input_ids.chunk(4, dim=1)[torch.distributed.get_rank()] + input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank()] inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, None) else: inputs_embeds = self.embed_tokens(input_ids) @@ -931,6 +933,7 @@ def custom_forward(*inputs): position_ids, ) else: + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -969,3 +972,178 @@ def custom_forward(*inputs): ) return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import LlamaForCausalLM + + def forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + if shard_config.enable_tensor_parallelism: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward + + +def get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
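When tensor parallelism is enabled, the causal-LM forward above routes the loss through `cross_entropy_1d` so the vocab-parallel logits never have to be gathered onto one rank. A single-process sketch of the math such a loss relies on (hypothetical sizes; `chunk` stands in for the vocab sharding, and the per-shard reductions stand in for the MAX/SUM all-reduces):

```python
import torch
import torch.nn.functional as F

tp_size, tokens, vocab = 4, 6, 32
logits = torch.randn(tokens, vocab)
labels = torch.randint(0, vocab, (tokens,))
shards = logits.chunk(tp_size, dim=-1)            # what each TP rank would hold

# log-sum-exp assembled from vocab shards: a MAX reduction, then a SUM reduction
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)
# the logit of the target token lives on exactly one shard per token
target_logit = logits.gather(1, labels[:, None]).squeeze(1)
loss = (global_max + sum_exp.log() - target_logit).mean()

assert torch.allclose(loss, F.cross_entropy(logits, labels), atol=1e-5)
```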
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + if sp_mode == "1": + hidden_states = gather_forward_reducescatter_backward(hidden_states, None, 1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + if sp_mode == "1": + hidden_states = reducescatter_forward_gather_backward(hidden_states, None, 1) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if sp_mode == "1": + hidden_states = gather_forward_reducescatter_backward(hidden_states, None, 1) + + hidden_states = self.mlp(hidden_states) + + if sp_mode == "1": + hidden_states = reducescatter_forward_gather_backward(hidden_states, None, 1) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + return forward diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1ff4666ce863..231ca70bcd8f 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -13,6 +13,7 @@ get_llama_flash_attention_forward, get_llama_seq_parallel_attention_forward, get_llama_seq_parallel_model_forward, + get_llama_decoder_seq_parallel_model_forward, get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -53,10 +54,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # sp_partial_derived = sp_mode in ["1"] # todo: Support SP for LlaMa model if sp_mode == "1": - self.shard_config.enable_sequence_parallelism = False - self.shard_config.sequence_parallelism_mode = None - sp_mode = None - warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + }, + policy=policy, + target_key=LlamaModel, + ) elif sp_mode == "2": self.append_or_create_method_replacement( description={ @@ -176,6 +180,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaDecoderLayer, ) + if sp_mode == "1": + self.append_or_create_method_replacement( + description={ + "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size), + }, + policy=policy, + target_key=LlamaDecoderLayer, + ) + self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="norm", @@ -311,7 +324,7 @@ def module_policy(self): setattr(self.shard_config, "causal_lm", True) - if self.shard_config.enable_tensor_parallelism: + if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription( From 8ae8f69206fb0736cbd3ab97377101588db6adb5 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Fri, 19 Jan 2024 10:01:42 +0800 Subject: [PATCH 2/8] support llama7B 128k with distributed attention --- 
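Before the next patch, it is worth summarizing how the decoder-layer forward added above places the mode "1" communications: activations stay sequence-sharded everywhere except inside the attention and MLP blocks. A shape walk-through of one layer (S = full sequence length, P = parallel group size, layout assumed from the code above):

```python
# hidden_states                           [B, S/P, H]   sequence-sharded between layers
# input_layernorm                         [B, S/P, H]   norms and residuals stay on the shard
# gather_forward_reducescatter_backward   [B, S,   H]   all-gather along seq before attention
# self_attn (tensor-parallel heads)       [B, S,   H]   o_proj leaves per-rank partial sums
# reducescatter_forward_gather_backward   [B, S/P, H]   sums the partials *and* re-shards the sequence
# + residual                              [B, S/P, H]
# post_attention_layernorm                [B, S/P, H]
# gather -> mlp -> reduce-scatter         same pattern around the MLP
# + residual                              [B, S/P, H]
```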
colossalai/shardformer/modeling/llama.py | 295 ++++++++++------------- 1 file changed, 133 insertions(+), 162 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a759a497793f..94e5a701c20d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -433,7 +433,6 @@ def llama_for_sequence_classification_forward( hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} - def get_llama_flash_attention_forward(shard_config): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb @@ -459,21 +458,20 @@ def forward( bsz, q_len, _ = hidden_states.size() sp_mode = shard_config.sequence_parallelism_mode sp_size = shard_config.sequence_parallel_size - sp_group = shard_config.sequence_parallel_process_group - + if sp_mode == "2": q_len *= shard_config.sequence_parallel_size # assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." - + query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states) + key_states = all_to_all_comm(key_states) + value_states = all_to_all_comm(value_states) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -497,7 +495,7 @@ def forward( if llama_version == 2: key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - + me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) @@ -506,14 +504,17 @@ def forward( flash_attention_mask = None attn_mask_type = AttnMaskType.causal if not getattr(shard_config, "causal_lm", False) and attention_mask != None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + if shard_config.enable_sequence_parallelism: + flash_attention_mask = attention_mask + else: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - hidden_size = self.hidden_size // sp_size if sp_mode == "3" else self.hidden_size - + hidden_size = self.hidden_size // sp_size if sp_mode == '3' else self.hidden_size + attention = ColoAttention(embed_dim=hidden_size, num_heads=self.num_heads) attn_output = attention( query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type @@ -521,120 +522,14 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + 
attn_output = all_to_all_comm(attn_output, None, scatter_dim=1, gather_dim=2) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value return forward - -def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): - from transformers import LlamaForCausalLM - - def forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
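Sequence-parallel mode "3" swaps which dimension is sharded around the attention: the all-to-all on q/k/v trades a sequence shard of all heads for the full sequence of a head shard, and the all-to-all on the output reverses it. A single-process sketch of that resharding (hypothetical sizes, Python lists standing in for ranks):

```python
import torch

P, bsz, seq, heads, hd = 4, 2, 16, 8, 4
hidden = heads * hd
full = torch.randn(bsz, seq, hidden)

# what each rank holds before the all-to-all: [B, S/P, H]
seq_shards = list(full.chunk(P, dim=1))

for rank in range(P):
    # all-to-all(scatter_dim=2, gather_dim=1): every rank sends the rank-th
    # hidden chunk of its sequence shard and receives the other ranks' shards,
    # ending with [B, S, H/P] -- the full sequence for 1/P of the heads, so the
    # attention itself can see every token.
    received = [seq_shards[src].chunk(P, dim=2)[rank] for src in range(P)]
    head_shard = torch.cat(received, dim=1)
    assert torch.equal(head_shard, full.chunk(P, dim=2)[rank])
```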
- ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - if shard_config.enable_tensor_parallelism: - new_vocab_size = logits.shape[-1] - shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d( - shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group - ) - else: - shift_logits = shift_logits.view(-1, self.config.vocab_size) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - return forward - - -def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -725,24 +620,71 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # TODO (linshengjie) Block attention with ring + #### + block_wise = False + if block_wise: + seq_block = query_states.shape[2] + assert query_states.shape[2] % seq_block == 0 + block_num = query_states.shape[2] // seq_block + #assert block_num == 1 + + query_states_chunks = query_states.chunk(block_num, dim=2) + if attention_mask is not None: + attention_mask_chunks = attention_mask.chunk(block_num, dim=2) + attn_output_chunks = [] + + + for i in range(block_num): + #attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + #if attention_mask is not None: + # attn_weights = attn_weights + attention_mask_chunks[i] + #attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + #attn_output_chunks.append(torch.matmul(attn_weights, value_states)) + def custom_forward(): + def 
block_attn(query_states, attention_mask, key_states, value_states, head_dim): + attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask_chunks[i] + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + return torch.matmul(attn_weights, value_states) + + def block_attn_forward(*inputs): + return block_attn(*inputs) + + return block_attn_forward + key_states_ref = key_states + value_states_ref = value_states + attn_output_chunks.append(torch.utils.checkpoint.checkpoint( + custom_forward(), + query_states_chunks[i], + attention_mask_chunks[i] if attention_mask is not None else None, + key_states_ref, + value_states_ref, + self.head_dim + )) + attn_output = torch.cat(attn_output_chunks, dim=2) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + #### if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -772,43 +714,62 @@ def forward( return forward +import torch.distributed as dist + def get_llama_seq_parallel_model_forward(sp_mode, sp_size): + + logger = logging.get_logger(__name__) + # Copied from transformers.models.bart.modeling_bart._make_causal_mask - def _make_causal_mask( + def _make_causal_mask_partial( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): """ Make causal mask used for bi-directional self-attention. 
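The block-wise path above chunks only the query (and the matching mask rows), so at most a `[block, kv_len]` score matrix is materialized at a time while keys and values stay whole. A minimal single-process sketch of that computation, checked against full attention (hypothetical sizes, no checkpointing):

```python
import math
import torch

bsz, heads, q_len, kv_len, hd, block = 2, 4, 64, 64, 8, 16
q = torch.randn(bsz, heads, q_len, hd)
k = torch.randn(bsz, heads, kv_len, hd)
v = torch.randn(bsz, heads, kv_len, hd)
mask = torch.zeros(bsz, 1, q_len, kv_len)          # additive mask, e.g. causal/padding

def full_attn(q, k, v, mask):
    w = q @ k.transpose(2, 3) / math.sqrt(hd) + mask
    return torch.softmax(w, dim=-1) @ v

chunks = []
for qc, mc in zip(q.chunk(q_len // block, dim=2), mask.chunk(q_len // block, dim=2)):
    # each step only builds a [block, kv_len] score matrix
    w = qc @ k.transpose(2, 3) / math.sqrt(hd) + mc
    chunks.append(torch.softmax(w, dim=-1) @ v)

assert torch.allclose(torch.cat(chunks, dim=2), full_attn(q, k, v, mask), atol=1e-6)
```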
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + world_size = dist.get_world_size() + tgt_len *= world_size + + mask = torch.full((tgt_len, tgt_len // world_size), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1) * world_size, device=device) + + block_size = tgt_len // world_size + idx = dist.get_rank() + off = idx * block_size + + mask.masked_fill_(mask_cond[off:off+block_size] < (mask_cond + 1).view(mask.size(-1) * world_size, 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + mask = torch.cat([torch.zeros(tgt_len // world_size, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, (tgt_len + past_key_values_length) // world_size) + # Copied from transformers.models.bart.modeling_bart._expand_mask - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - # inverted_mask = 1.0 - expanded_mask - inverted_mask = expanded_mask.mul_(-1).add_(1.0) - return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + world_size = dist.get_world_size() + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len * world_size, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( + combined_attention_mask = _make_causal_mask_partial( input_shape, inputs_embeds.dtype, device=inputs_embeds.device, @@ -817,7 +778,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + expanded_attn_mask = _expand_mask_partial(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( inputs_embeds.device ) combined_attention_mask = ( @@ -826,6 +787,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, return combined_attention_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -887,20 +849,30 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # embed positions - if attention_mask is None: + if sp_mode is None: + if attention_mask is None: + 
attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = _gather(attention_mask, 1, None) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + else: + world_size = dist.get_world_size() + assert seq_length_with_past % world_size == 0 attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, seq_length_with_past // world_size), dtype=torch.bool, device=inputs_embeds.device ) - - attention_mask = _gather(attention_mask, 1, None) - - attention_mask = _prepare_decoder_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) + attention_mask = _prepare_decoder_attention_mask_partial( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) + attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() + attention_mask = _gather(attention_mask, 1, None) hidden_states = inputs_embeds - if self.gradient_checkpointing and self.training: + if (self.gradient_checkpointing or sp_mode is not None) and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -917,8 +889,7 @@ def forward( all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - + if (self.gradient_checkpointing or sp_mode is not None) and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value From 579df22498c673cfa1d823f57fb66b88f540422a Mon Sep 17 00:00:00 2001 From: linsj20 Date: Mon, 22 Jan 2024 09:49:14 +0800 Subject: [PATCH 3/8] [shardformer] robustness enhancement --- colossalai/shardformer/modeling/llama.py | 57 +++++++++++--------- colossalai/shardformer/policies/llama.py | 8 +-- colossalai/shardformer/shard/shard_config.py | 11 ++-- tests/test_shardformer/test_model/_utils.py | 12 ++--- 4 files changed, 46 insertions(+), 42 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 94e5a701c20d..59545983aaad 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -461,7 +461,7 @@ def forward( if sp_mode == "2": q_len *= shard_config.sequence_parallel_size - # assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." + assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
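`_make_causal_mask_partial` above keeps only the column block of the causal mask that corresponds to the local rank's key positions, so no rank ever materializes the full `[S, S]` mask for a long sequence. A single-process check of that property (hypothetical sizes, a rank loop standing in for the process group):

```python
import torch

world_size, seq_local = 4, 8
S = world_size * seq_local
neg = torch.finfo(torch.float32).min

full = torch.full((S, S), neg)
full.masked_fill_(torch.arange(S)[None, :] <= torch.arange(S)[:, None], 0)   # full causal mask

for rank in range(world_size):
    cols = torch.arange(rank * seq_local, (rank + 1) * seq_local)
    # per-rank mask: all query rows, but only the local block of key columns
    partial = torch.full((S, seq_local), neg)
    partial.masked_fill_(cols[None, :] <= torch.arange(S)[:, None], 0)
    assert torch.equal(partial, full[:, cols])
```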
query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -503,8 +503,11 @@ def forward( flash_attention_mask = None attn_mask_type = AttnMaskType.causal + + # TODO Internal function + use_distributed_mask = False if not getattr(shard_config, "causal_lm", False) and attention_mask != None: - if shard_config.enable_sequence_parallelism: + if use_distributed_mask is True: flash_attention_mask = attention_mask else: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): @@ -529,7 +532,7 @@ def forward( return forward -def get_llama_seq_parallel_attention_forward(sp_mode, sp_size): +def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group): def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -716,13 +719,13 @@ def block_attn_forward(*inputs): import torch.distributed as dist -def get_llama_seq_parallel_model_forward(sp_mode, sp_size): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask_partial( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, sp_group = None ): """ Make causal mask used for bi-directional self-attention. @@ -747,7 +750,7 @@ def _make_causal_mask_partial( # Copied from transformers.models.bart.modeling_bart._expand_mask - def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, sp_group = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
""" @@ -764,7 +767,7 @@ def _expand_mask_partial(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_embeds, past_key_values_length, sp_group = None): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -774,11 +777,12 @@ def _prepare_decoder_attention_mask_partial(attention_mask, input_shape, inputs_ inputs_embeds.dtype, device=inputs_embeds.device, past_key_values_length=past_key_values_length, + sp_group=sp_group ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask_partial(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + expanded_attn_mask = _expand_mask_partial(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1], sp_group=sp_group).to( inputs_embeds.device ) combined_attention_mask = ( @@ -841,34 +845,37 @@ def forward( if inputs_embeds is None: if sp_mode in ["1", "2"]: - input_ids = _gather(input_ids, 1, None) + input_ids = _gather(input_ids, 1, sp_group) inputs_embeds = self.embed_tokens(input_ids) - input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank()] - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, None) + input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) else: inputs_embeds = self.embed_tokens(input_ids) + # TODO Internal function + use_distributed_mask = False + # embed positions - if sp_mode is None: + if sp_mode is None or use_distributed_mask is False: if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = _gather(attention_mask, 1, None) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length - ) + attention_mask = _gather(attention_mask, 1, sp_group) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + ) else: - world_size = dist.get_world_size() + world_size = dist.get_world_size(sp_group) assert seq_length_with_past % world_size == 0 attention_mask = torch.ones( (batch_size, seq_length_with_past // world_size), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = _prepare_decoder_attention_mask_partial( - attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length + attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length, sp_group ) attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - attention_mask = _gather(attention_mask, 1, None) + attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds @@ -925,7 +932,7 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) # Todo: Maybe this line can be optimized - hidden_states = gather_forward_split_backward(hidden_states, 1, None) + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) # add hidden states from the last decoder layer if output_hidden_states: @@ -1050,7 +1057,7 @@ def forward( return forward -def 
get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size): +def get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group): def forward( self, @@ -1080,7 +1087,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) if sp_mode == "1": - hidden_states = gather_forward_reducescatter_backward(hidden_states, None, 1) + hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( @@ -1093,19 +1100,19 @@ def forward( ) if sp_mode == "1": - hidden_states = reducescatter_forward_gather_backward(hidden_states, None, 1) + hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) if sp_mode == "1": - hidden_states = gather_forward_reducescatter_backward(hidden_states, None, 1) + hidden_states = gather_forward_reducescatter_backward(hidden_states, sp_group, 1) hidden_states = self.mlp(hidden_states) if sp_mode == "1": - hidden_states = reducescatter_forward_gather_backward(hidden_states, None, 1) + hidden_states = reducescatter_forward_gather_backward(hidden_states, sp_group, 1) hidden_states = residual + hidden_states outputs = (hidden_states,) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 231ca70bcd8f..4b8047c15cf8 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -56,7 +56,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if sp_mode == "1": self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaModel, @@ -71,7 +71,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaModel, @@ -98,7 +98,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaModel, @@ -183,7 +183,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if sp_mode == "1": self.append_or_create_method_replacement( description={ - "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size), + "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group), }, policy=policy, target_key=LlamaDecoderLayer, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 9363b6c64169..2e41074e1504 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -52,6 +52,10 @@ def sequence_parallel_size(self): return self._sequence_parallel_size def __post_init__(self): + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + if 
self.enable_sequence_parallelism: self.sequence_parallelism_mode = ( "1" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode @@ -94,10 +98,6 @@ def __post_init__(self): else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) - # turn on all optimization if all_optimization is set to True - if self.enable_all_optimization: - self._turn_on_all_optimization() - def _turn_on_all_optimization(self): """ Turn on all optimization. @@ -108,8 +108,9 @@ def _turn_on_all_optimization(self): self.enable_jit_fused = True self.enable_sequence_parallelism = True self.enable_sequence_overlap = True - # todo modify default sequence parallelism mode + # todo modify default sequence parallelism mode and process group self.sequence_parallelism_mode = "1" + self.sequence_parallel_process_group = self.tensor_parallel_process_group def _infer(self): """ diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index d31f62392baf..226a8e4731c3 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -157,11 +157,6 @@ def _criterion(outputs, inputs): return loss data = data_gen_fn() - # for k, v in data.items(): - # size = list(v.shape) - # tg_size = [1] * len(size) - # tg_size[1] = 64 * 2 - # data[k] = v.repeat(tg_size) if ( booster.plugin.shard_config.enable_sequence_parallelism @@ -182,9 +177,11 @@ def _criterion(outputs, inputs): shard_test_data[k] = data[k].clone() else: shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[dist.get_rank()] + torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=1)[ + dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) + ] if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] + and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2", "3"] else data[k].clone() ) unshard_test_data = {} @@ -224,7 +221,6 @@ def _criterion(outputs, inputs): org_loss.backward() return org_loss, org_output, sharded_loss, sharded_output - # return sharded_loss, sharded_output, sharded_loss, sharded_output def check_output_hidden_state( From 333cfde06589576baba17d6e80388a4486e59312 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Mon, 22 Jan 2024 10:16:51 +0800 Subject: [PATCH 4/8] add block attn --- colossalai/shardformer/modeling/llama.py | 38 +++++------------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 59545983aaad..9dc8bb5b6cca 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -626,11 +626,11 @@ def forward( # TODO (linshengjie) Block attention with ring #### block_wise = False - if block_wise: - seq_block = query_states.shape[2] + seq_len = query_states[2] + seq_block = 1024 + if block_wise and seq_len > seq_block: assert query_states.shape[2] % seq_block == 0 block_num = query_states.shape[2] // seq_block - #assert block_num == 1 query_states_chunks = query_states.chunk(block_num, dim=2) if attention_mask is not None: @@ -639,33 +639,11 @@ def forward( for i in range(block_num): - #attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - #if attention_mask is not None: - # attn_weights = attn_weights + 
attention_mask_chunks[i] - #attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - #attn_output_chunks.append(torch.matmul(attn_weights, value_states)) - def custom_forward(): - def block_attn(query_states, attention_mask, key_states, value_states, head_dim): - attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask_chunks[i] - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - return torch.matmul(attn_weights, value_states) - - def block_attn_forward(*inputs): - return block_attn(*inputs) - - return block_attn_forward - key_states_ref = key_states - value_states_ref = value_states - attn_output_chunks.append(torch.utils.checkpoint.checkpoint( - custom_forward(), - query_states_chunks[i], - attention_mask_chunks[i] if attention_mask is not None else None, - key_states_ref, - value_states_ref, - self.head_dim - )) + attn_weights = torch.matmul(query_states_chunks[i], key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask_chunks[i] + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output_chunks.append(torch.matmul(attn_weights, value_states)) attn_output = torch.cat(attn_output_chunks, dim=2) else: From 7aa4a2d1580219c50e0a93cd4d40f6763f0e8ced Mon Sep 17 00:00:00 2001 From: linsj20 Date: Mon, 22 Jan 2024 10:37:57 +0800 Subject: [PATCH 5/8] sp mode 1: keep input as a complete sequence --- colossalai/shardformer/modeling/llama.py | 12 ++++++++++++ tests/test_shardformer/test_model/_utils.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 9dc8bb5b6cca..478126188da6 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -790,6 +790,18 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if sp_mode == "1": + if input_ids is not None: + input_ids = split_forward_gather_backward(input_ids, dim=1, process_group=sp_group) + if attention_mask is not None: + attention_mask = split_forward_gather_backward(attention_mask, dim=1, process_group=sp_group) + if position_ids is not None: + position_ids = split_forward_gather_backward(position_ids, dim=1, process_group=sp_group) + if past_key_values is not None: + past_key_values = split_forward_gather_backward(past_key_values, dim=1, process_group=sp_group) + if inputs_embeds is not None: + inputs_embeds = split_forward_gather_backward(inputs_embeds, dim=1, process_group=sp_group) + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 226a8e4731c3..eca4764a8c49 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -181,7 +181,7 @@ def _criterion(outputs, inputs): dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) ] if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["1", "2", "3"] + and 
booster.plugin.shard_config.sequence_parallelism_mode in ["2", "3"] else data[k].clone() ) unshard_test_data = {} From be6aed26ebee73e2d6e8ca228e6dfc7b9d3f8b5e Mon Sep 17 00:00:00 2001 From: linsj20 Date: Tue, 23 Jan 2024 11:14:48 +0800 Subject: [PATCH 6/8] fix sp compatability --- colossalai/shardformer/layer/linear.py | 10 ++++-- colossalai/shardformer/modeling/llama.py | 36 +++++++++----------- colossalai/shardformer/policies/llama.py | 11 +++++- colossalai/shardformer/shard/shard_config.py | 11 +++--- 4 files changed, 40 insertions(+), 28 deletions(-) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 3f90a6cb0a39..20a9f0328cfc 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -24,6 +24,7 @@ from ._operation import ( gather_forward_split_backward, + gather_forward_reducescatter_backward, reducescatter_forward_gather_backward, linear_gather_forward_reducescatter_backward, linear_reducescatter_forward_gather_backward, @@ -202,6 +203,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) elif self.seq_parallel_mode == "1": + input_parallel = gather_forward_reducescatter_backward(input_parallel, self.process_group, self.seq_parallel_dim) output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( @@ -415,10 +417,14 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_, self.weight, None, None, False) + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": - output = linear_with_async_comm(input_, self.weight, None, None, False) + #output = linear_with_async_comm(input_, self.weight, None, None, False) + output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) + output = reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim + ) elif self.seq_parallel_mode == "2": output = linear_reducescatter_forward_gather_backward( input_, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 478126188da6..1421e5c337ed 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -459,14 +459,14 @@ def forward( sp_mode = shard_config.sequence_parallelism_mode sp_size = shard_config.sequence_parallel_size - if sp_mode == "2": + if sp_mode in["1", "2"]: q_len *= shard_config.sequence_parallel_size assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." 
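After this change, mode "1" keeps the hidden states sequence-sharded between blocks and fuses the communication into the linear layers themselves: `Linear1D_Col` all-gathers its input along the sequence dimension in forward (reduce-scattering the gradient in backward), and `Linear1D_Row` reduce-scatters its output. A single-process sketch of that data flow for a two-layer MLP, checked against the unsharded result (hypothetical sizes, `chunk`/`cat` standing in for the collectives):

```python
import torch

P, bsz, seq, hidden, inter = 4, 2, 16, 8, 32
x = torch.randn(bsz, seq, hidden)
w_col = torch.randn(inter, hidden)     # column-parallel weight, split on its output dim across ranks
w_row = torch.randn(hidden, inter)     # row-parallel weight, split on its input dim across ranks

reference = torch.relu(x @ w_col.t()) @ w_row.t()        # unsharded MLP

x_shards = list(x.chunk(P, dim=1))                       # what each rank keeps between blocks: [B, S/P, H]
partials = []
for rank in range(P):
    x_full = torch.cat(x_shards, dim=1)                  # Linear1D_Col mode "1": all-gather input along seq
    h = torch.relu(x_full @ w_col.chunk(P, dim=0)[rank].t())    # column-parallel matmul + activation
    partials.append(h @ w_row.chunk(P, dim=1)[rank].t())        # row-parallel matmul -> partial sum

for rank in range(P):
    rs = sum(p.chunk(P, dim=1)[rank] for p in partials)  # Linear1D_Row mode "1": reduce-scatter along seq
    assert torch.allclose(rs, reference.chunk(P, dim=1)[rank], atol=1e-5)
```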
query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - + # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "3": query_states = all_to_all_comm(query_states) @@ -474,6 +474,8 @@ def forward( value_states = all_to_all_comm(value_states) bsz, q_len, _ = query_states.size() + if shard_config.sequence_parallel_size < 4: + print(query_states.shape) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -571,7 +573,7 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is 2 - if sp_mode == "2": + if sp_mode in["1", "2"]: q_len *= sp_size if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp @@ -790,18 +792,6 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if sp_mode == "1": - if input_ids is not None: - input_ids = split_forward_gather_backward(input_ids, dim=1, process_group=sp_group) - if attention_mask is not None: - attention_mask = split_forward_gather_backward(attention_mask, dim=1, process_group=sp_group) - if position_ids is not None: - position_ids = split_forward_gather_backward(position_ids, dim=1, process_group=sp_group) - if past_key_values is not None: - past_key_values = split_forward_gather_backward(past_key_values, dim=1, process_group=sp_group) - if inputs_embeds is not None: - inputs_embeds = split_forward_gather_backward(inputs_embeds, dim=1, process_group=sp_group) - # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -813,7 +803,8 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") # sp: modify seq_length when using sequence parallel - seq_length *= sp_size + if sp_mode in ["2", "3"]: + seq_length *= sp_size seq_length_with_past = seq_length past_key_values_length = 0 @@ -834,7 +825,7 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - if sp_mode in ["1", "2"]: + if sp_mode == "2": input_ids = _gather(input_ids, 1, sp_group) inputs_embeds = self.embed_tokens(input_ids) input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] @@ -851,7 +842,10 @@ def forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = _gather(attention_mask, 1, sp_group) + + if sp_mode in ["2", "3"]: + attention_mask = _gather(attention_mask, 1, sp_group) + attention_mask = self._prepare_decoder_attention_mask( attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length ) @@ -868,8 +862,10 @@ def forward( attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds + if sp_mode == "1": + hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) - if (self.gradient_checkpointing or sp_mode is not None) and self.training: + if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: if 
use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." @@ -886,7 +882,7 @@ def forward( all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if (self.gradient_checkpointing or sp_mode is not None) and self.training: + if (self.gradient_checkpointing or sp_mode in ["2", "3"]) and self.training: def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 4b8047c15cf8..d6ba81847910 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -61,6 +61,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaModel, ) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group), + }, + policy=policy, + target_key=LlamaAttention, + ) elif sp_mode == "2": self.append_or_create_method_replacement( description={ @@ -180,7 +187,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: target_key=LlamaDecoderLayer, ) - if sp_mode == "1": + ''' + if sp_mode == "1" and False: self.append_or_create_method_replacement( description={ "forward": get_llama_decoder_seq_parallel_model_forward(sp_mode, sp_size, sp_group), @@ -188,6 +196,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=LlamaDecoderLayer, ) + ''' self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2e41074e1504..f8f02b960ccb 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -106,11 +106,12 @@ def _turn_on_all_optimization(self): self.enable_fused_normalization = True self.enable_flash_attention = True self.enable_jit_fused = True - self.enable_sequence_parallelism = True - self.enable_sequence_overlap = True - # todo modify default sequence parallelism mode and process group - self.sequence_parallelism_mode = "1" - self.sequence_parallel_process_group = self.tensor_parallel_process_group + if self.enable_tensor_parallelism: + self.enable_sequence_parallelism = True + self.enable_sequence_overlap = True + # todo modify default sequence parallelism mode and process group + self.sequence_parallelism_mode = "1" + self.sequence_parallel_process_group = self.tensor_parallel_process_group def _infer(self): """ From 7d219c1557e6e36d7e5340a6be8e32654e2d12e7 Mon Sep 17 00:00:00 2001 From: linsj20 Date: Thu, 25 Jan 2024 16:30:23 +0800 Subject: [PATCH 7/8] refactor ring implementation --- colossalai/shardformer/layer/_operation.py | 167 ++++++++++++++------- colossalai/shardformer/modeling/llama.py | 2 +- 2 files changed, 111 insertions(+), 58 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index f02dea8cdabd..7af480432a00 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -169,47 +169,57 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None -def _AllgatherLinear(input_, weight, process_group): +def _ring_as_gather(func, input_to_gather=None, 
input_local=None, process_group=None, gather_dim=1, keep_item=False): + # currently only supports a single output tensor group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - weight_shape = weight.shape - - output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] + #output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)] # initialization of ring communication - input_shape[1] recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 - recv_tensor = input_.clone() - send_tensor = input_.clone() - - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([send_op, recv_op]) + recv_tensors = {} + send_tensors = {} + for k, v in input_to_gather.items(): + recv_tensors[k] = v.clone() + send_tensors[k] = v.clone() + + def communicate_step(): + comm_ops = [] + for k in recv_tensors: + comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group)) + comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group)) + return dist.batch_isend_irecv(comm_ops) + + def switch_step(): + for k in recv_tensors: + tmp_tensor = send_tensors[k] + send_tensors[k] = recv_tensors[k] + recv_tensors[k] = tmp_tensor + + output_tensors = [] + + handles = communicate_step() # first round: special case, retrieve from local tensor - output_tensors[0] = F.linear(input_, weight) + output_tensors.append(func(**input_to_gather, **input_local)) for i in range(group_size - 2): for handle in handles: handle.wait() - tmp_tensor = send_tensor - send_tensor = recv_tensor - recv_tensor = tmp_tensor + switch_step() - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + handles = communicate_step() # actual computation - output_tensors[i + 1] = F.linear(send_tensor, weight) + output_tensors.append(func(**send_tensors, **input_local)) # final round: special case, no need to send/recv again for handle in handles: handle.wait() - output_tensors[group_size - 1] = F.linear(recv_tensor, weight) - return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=1) + output_tensors.append(func(**recv_tensors, **input_local)) + + return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) class _GatherForwardReduceScatterBackward(torch.autograd.Function): @@ -268,11 +278,28 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - if bias is not None: - input_parallel = _gather(input_, dim, process_group) - output = F.linear(input_parallel, weight, bias) + ring = True + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['weight'] = weight + + if bias is not None: + input_local['bias'] = bias + + output = _ring_as_gather( + F.linear, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + ) else: - output = _AllgatherLinear(input_, weight, process_group) + input_parallel = _gather(input_, dim, process_group) + if
bias is not None: + output = F.linear(input_parallel, weight, bias) + else: + output = F.linear(input_parallel, weight) return output @@ -379,31 +406,40 @@ def backward(ctx, grad_output): return output, grad_weight, grad_bias, None, None, None, None -def _ReduceScatterLinear(input_, weight, process_group): +def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1): + # currently only supports a single output tensor group_size = dist.get_world_size(process_group) cur_rank = dist.get_rank(process_group) - input_shape = input_.shape - # initialization of ring communication - # communicate(e.g.): 0->1->2->3 - # compute(e.g.): 3->2->1->0 - input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1)) - input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] - input_tensors.reverse() recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1 send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0 + input_tensors = [] + for _ in range(group_size): + input_tensors.append({}) + for k, v in input_to_reducescatter.items(): + input_shape = v.shape + assert input_shape[reducescatter_dim] % group_size == 0 + _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim)) + for i in range(group_size): + input_tensors[i][k] = _input_tensors[i] + input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank] + input_tensors.reverse() - # first round: special case, no reduce operation - output_tensor = F.linear(input_tensors[0], weight) + output_tensor = func(**input_tensors[0], **input_local) recv_tensor = output_tensor.clone() send_tensor = output_tensor.clone() - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + + def communicate_step(): + recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) + send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) + return dist.batch_isend_irecv([recv_op, send_op]) + + handles = communicate_step() + # first round: special case, retrieve from local tensor for i in range(group_size - 2): # actual computation - output_tensor = F.linear(input_tensors[i + 1], weight) + output_tensor = func(**input_tensors[i + 1], **input_local) for handle in handles: handle.wait() @@ -413,12 +449,10 @@ def _ReduceScatterLinear(input_, weight, process_group): send_tensor = output_tensor output_tensor = tmp_tensor - recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group) - send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group) - handles = dist.batch_isend_irecv([recv_op, send_op]) + handles = communicate_step() # final round: special case, no need to send/recv again - output_tensor = F.linear(input_tensors[group_size - 1], weight) + output_tensor = func(**input_tensors[-1], **input_local) for handle in handles: handle.wait() output_tensor += recv_tensor @@ -441,22 +475,41 @@ def forward(ctx, input_, weight, bias, process_group, dim): ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim - if bias is not None: - partial_output = F.linear(input_, weight, bias) + + ring = True + + if ring is True: + input_to_reducescatter = {} + input_local = {} + input_to_reducescatter['input'] = input_ + input_local['weight'] = weight + + if bias is not None: +
input_to_reducescatter['bias'] = bias + + output = _ring_as_reducescatter( + F.linear, + input_to_reducescatter=input_to_reducescatter, + input_local=input_local, + process_group=process_group, + ) else: - return _ReduceScatterLinear(input_, weight, process_group) + if bias is not None: + partial_output = F.linear(input_, weight, bias) + else: + partial_output = F.linear(input_, weight) - output_shape = list(partial_output.shape) - assert ( - output_shape[dim] % dist.get_world_size(process_group) == 0 - ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " - output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) + output_shape = list(partial_output.shape) + assert ( + output_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " + output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group) - output_list = [ - item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) - ] - output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() - dist.reduce_scatter(output, output_list, group=process_group) + output_list = [ + item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim) + ] + output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous() + dist.reduce_scatter(output, output_list, group=process_group) return output diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2439de1befa0..8523f6eca541 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -630,7 +630,7 @@ def forward( # TODO (linshengjie) Block attention with ring #### block_wise = False - seq_len = query_states[2] + seq_len = query_states.shape[2] seq_block = 1024 if block_wise and seq_len > seq_block: assert query_states.shape[2] % seq_block == 0 From 5f806d0b42a9fd7b679874b2b5ff34e586fe45ae Mon Sep 17 00:00:00 2001 From: linsj20 Date: Fri, 26 Jan 2024 15:54:53 +0800 Subject: [PATCH 8/8] support mode 2 sp in gpt2 --- colossalai/shardformer/layer/_operation.py | 52 ++++++++++++------- colossalai/shardformer/layer/linear.py | 4 +- .../shardformer/layer/qkv_fused_linear.py | 12 ++++- colossalai/shardformer/modeling/gpt2.py | 7 ++- colossalai/shardformer/policies/gpt2.py | 2 +- 5 files changed, 52 insertions(+), 25 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7af480432a00..8560f6463de4 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -270,7 +270,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -278,22 +278,21 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - ring = True if ring is True: input_to_gather = {} input_local = {} 
input_to_gather['input'] = input_ input_local['weight'] = weight - if bias is not None: - input_local['bias'] = bias - output = _ring_as_gather( F.linear, input_to_gather=input_to_gather, input_local=input_local, process_group=process_group, ) + + if bias is not None: + output += bias else: input_parallel = _gather(input_, dim, process_group) if bias is not None: @@ -403,7 +402,7 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1): @@ -470,14 +469,12 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, dim): + def forward(ctx, input_, weight, bias, process_group, dim, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.dim = dim - ring = True - if ring is True: input_to_reducescatter = {} input_local = {} @@ -537,7 +534,7 @@ def backward(ctx, grad_output): grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None class _ReduceScatterForwardGatherBackward(torch.autograd.Function): @@ -586,7 +583,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): + def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group @@ -594,9 +591,24 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, ctx.dim = dim ctx.overlap = overlap - input_parallel = _gather(input_, dim, process_group) + if ring is True: + input_to_gather = {} + input_local = {} + input_to_gather['input'] = input_ + input_local['other'] = weight + + output = _ring_as_gather( + torch.matmul, + input_to_gather=input_to_gather, + input_local=input_local, + process_group=process_group, + gather_dim=dim + ) + + else: + input_parallel = _gather(input_, dim, process_group) - output = torch.matmul(input_parallel, weight) + output = torch.matmul(input_parallel, weight) if bias is not None: output = output + bias @@ -677,7 +689,7 @@ def backward(ctx, grad_output): # wait until reduce-scatter finished reducescatter_handle.wait() - return output, grad_weight, grad_bias, None, None, None, None + return output, grad_weight, grad_bias, None, None, None, None, None class _SplitForwardGatherBackward(torch.autograd.Function): @@ -930,10 +942,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre def linear_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _LinearWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) @@ -945,15 +957,15 @@ def 
reducescatter_forward_gather_backward(input_, process_group, dim): return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1): - return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim) +def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False): + return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring) def matmul_gather_forward_reducescatter_backward( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False ): return _MatmulWithGatherForwardReduceScatterBackward.apply( - input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring ) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 20a9f0328cfc..a773783b9f19 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -207,7 +207,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False) elif self.seq_parallel_mode == "2": output_parallel = linear_gather_forward_reducescatter_backward( - input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True ) if self.gather_output: @@ -429,7 +429,9 @@ def forward(self, input_: Tensor) -> Tensor: output = linear_reducescatter_forward_gather_backward( input_, self.weight, + process_group=self.process_group, dim=self.seq_parallel_dim, + ring=True, ) if not self.skip_bias_add: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 6c5fb41494f0..a5d75db8a740 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -323,6 +323,11 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: output_parallel = matmul_gather_forward_reducescatter_backward( input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap ) + elif self.seq_parallel_mode == "2": + input_parallel = input_ + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True + ) if self.gather_output: # All-gather across the partitions. 
@@ -528,10 +533,14 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = torch.matmul(input_, self.weight) if self.seq_parallel_mode is None: + output_parallel = torch.matmul(input_, self.weight) output = reduce_forward(output_parallel, self.process_group) elif self.seq_parallel_mode == "1": + output_parallel = torch.matmul(input_, self.weight) + output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) + elif self.seq_parallel_mode == "2": + output_parallel = torch.matmul(input_, self.weight) output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1) if not self.skip_bias_add: @@ -702,7 +711,6 @@ def from_native_module( # process_group=process_group, # is_transposed=False) # linear_1d.bias.data.copy_(sharded_bias.data) - print(linear_1d.weight.shape) return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 13ac4aa9fa1e..e9fed06d7295 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -922,7 +922,12 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) + if sp_mode in ["2"]: + input_ids = _gather(input_ids, 1, sp_group) + inputs_embeds = self.wte(input_ids) + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + else: + inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0926b0ccf27a..023e7e63c950 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -49,7 +49,7 @@ def module_policy(self): sp_size = self.shard_config.sequence_parallel_size sp_group = self.shard_config.sequence_parallel_process_group overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode in ["1"] + sp_partial_derived = sp_mode in ["1", "2"] if sp_mode == "2": pass
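
As a sanity check for the output reordering in _ring_as_gather: a rank r receives from rank r + 1, so it applies func to the per-rank chunks in the order r, r+1, ..., r+G-1 (mod G), and the final torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim) rotates the partial results back into global chunk order. The following is a minimal, single-process sketch of that schedule, not part of the patch and not the distributed code path; the world size G, the chunk shapes, and the variable names are illustrative assumptions only.

import torch
import torch.nn.functional as F

G = 4                                                    # illustrative sequence-parallel world size
weight = torch.randn(8, 8)
chunks = [torch.randn(2, 3, 8) for _ in range(G)]        # one sequence chunk per hypothetical rank
reference = F.linear(torch.cat(chunks, dim=1), weight)   # plain all-gather followed by a linear

for rank in range(G):
    # rank r sees the chunks in ring order r, r+1, ..., r+G-1 (mod G)
    outputs = [F.linear(chunks[(rank + step) % G], weight) for step in range(G)]
    # same rotation as the final cat in _ring_as_gather: restores the global chunk order
    ring_result = torch.cat(outputs[G - rank:] + outputs[:G - rank], dim=1)
    assert torch.allclose(ring_result, reference)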