From 3f2aa433cc63e83f75bb5443aa8d94d4698d1563 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 13 May 2025 15:30:56 +0800 Subject: [PATCH 1/3] upgrade qwen2 --- colossalai/shardformer/modeling/qwen2.py | 113 ++++++++++++------ colossalai/shardformer/policies/qwen2.py | 19 +-- .../test_model/test_shard_qwen2.py | 2 +- 3 files changed, 88 insertions(+), 46 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a459c5..3886bfaf8101 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -9,6 +9,7 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) +from transformers.cache_utils import DynamicCache try: from transformers.modeling_attn_mask_utils import ( @@ -57,6 +58,7 @@ def qwen2_model_forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, @@ -131,14 +133,6 @@ def qwen2_model_forward( else: position_ids = position_ids.view(-1, seq_length).long() - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: @@ -152,16 +146,16 @@ def qwen2_model_forward( is_causal=True, ) else: - if self._attn_implementation == "flash_attention_2": + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: + elif self.config._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), - inputs_embeds, + hidden_states, past_key_values_length, ) else: @@ -195,6 +189,8 @@ def qwen2_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = None + position_embeddings = self.rotary_emb(hidden_states, position_ids) + start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: @@ -216,24 +212,41 @@ def qwen2_model_forward( past_key_value = past_key_values[idx] if past_key_values is not None else None + if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( + # hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_value, + # output_attentions=output_attentions, + # use_cache=use_cache, hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] @@ -491,11 +504,10 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s def forward( self: Qwen2Attention, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if sp_mode is not None: @@ -519,9 +531,11 @@ def forward( value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + print("value_states, value_states", value_states.shape) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -533,9 +547,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute @@ -563,7 +578,7 @@ def forward( attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -605,11 +620,11 @@ def forward( attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication ) else: - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None return forward @@ -627,6 +642,7 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, force_sp_output_gather: bool = True, ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -648,6 +664,10 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + seq_length_with_past = seq_length past_key_values_length = 0 @@ -664,12 +684,12 @@ def forward( else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) # embed positions hidden_states = inputs_embeds + print("replace replace") + if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length, seq_length_with_past) @@ -681,6 +701,9 @@ def forward( is_causal=True, ) else: + # attention_mask = self._update_causal_mask( + # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + # ) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), @@ -689,6 +712,7 @@ def forward( sliding_window=self.config.sliding_window, ) + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: logger.warning_once( @@ -700,6 +724,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None + position_embeddings = self.rotary_emb(hidden_states, position_ids) if sp_mode in ["ring", "split_gather"]: hidden_states = split_forward_gather_backward( @@ -717,27 +742,41 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, hidden_states, attention_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( + # hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_values, + # output_attentions=output_attentions, + # use_cache=use_cache, hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdbd99..8a797fc59a09 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -65,15 +65,15 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - ATTN_IMPLEMENTATION = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, - } + # ATTN_IMPLEMENTATION = { + # "eager": Qwen2Attention, + # "flash_attention_2": Qwen2FlashAttention2, + # "sdpa": Qwen2SdpaAttention, + # } policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] + # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -93,7 +93,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if getattr(self.model.config, "num_key_value_heads", False): decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size - policy[attn_cls] = ModulePolicyDescription( + policy[Qwen2Attention] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) @@ -301,12 +301,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + print("self.shard_config.enable_flash_attention", self.shard_config.enable_flash_attention) self.append_or_create_method_replacement( description={ "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, - target_key=attn_cls, + target_key=Qwen2Attention, ) if self.pipeline_stage_manager is None: # replace qwen2 model forward method @@ -319,6 +320,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=Qwen2Model, ) + print("policy", policy) return policy @@ -370,6 +372,7 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index 865563adc625..2bf98afe7ec7 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -193,7 +193,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", "enable_flash_attention": False, - "use_lazy_init": True, + "use_lazy_init": False, "precision": "fp16", "initial_scale": 1, }, From 74f038f412cc058a3d6d4fc48135e36e800d2886 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 13 May 2025 15:43:54 +0800 Subject: [PATCH 2/3] fix --- colossalai/shardformer/modeling/qwen2.py | 31 ------------------- colossalai/shardformer/policies/qwen2.py | 7 ----- .../test_model/test_shard_qwen2.py | 2 +- 3 files changed, 1 insertion(+), 39 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 3886bfaf8101..957b6bb3ee4d 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -216,12 +216,6 @@ def qwen2_model_forward( if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, - # hidden_states, - # attention_mask, - # position_ids, - # past_key_values, - # output_attentions, - # use_cache, hidden_states, attention_mask, position_ids, @@ -233,12 +227,6 @@ def qwen2_model_forward( ) else: layer_outputs = decoder_layer( - # hidden_states, - # attention_mask=attention_mask, - # position_ids=position_ids, - # past_key_value=past_key_value, - # output_attentions=output_attentions, - # use_cache=use_cache, hidden_states, attention_mask, position_ids, @@ -535,7 +523,6 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - print("value_states, value_states", value_states.shape) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -547,8 +534,6 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. - # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -688,7 +673,6 @@ def forward( # embed positions hidden_states = inputs_embeds - print("replace replace") if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor @@ -701,9 +685,6 @@ def forward( is_causal=True, ) else: - # attention_mask = self._update_causal_mask( - # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - # ) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), @@ -742,12 +723,6 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, - # hidden_states, - # attention_mask, - # position_ids, - # past_key_values, - # output_attentions, - # use_cache, hidden_states, attention_mask, position_ids, @@ -759,12 +734,6 @@ def forward( ) else: layer_outputs = decoder_layer( - # hidden_states, - # attention_mask=attention_mask, - # position_ids=position_ids, - # past_key_value=past_key_values, - # output_attentions=output_attentions, - # use_cache=use_cache, hidden_states, attention_mask, position_ids, diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 8a797fc59a09..7f8a35e46bbe 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -65,15 +65,9 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - # ATTN_IMPLEMENTATION = { - # "eager": Qwen2Attention, - # "flash_attention_2": Qwen2FlashAttention2, - # "sdpa": Qwen2SdpaAttention, - # } policy = {} - # attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement] embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D @@ -320,7 +314,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key=Qwen2Model, ) - print("policy", policy) return policy diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index 2bf98afe7ec7..865563adc625 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -193,7 +193,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "split_gather", "enable_flash_attention": False, - "use_lazy_init": False, + "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, }, From de47d22fbc6e129c81c82c9110f5f1aa0b13c8bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 May 2025 07:45:24 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/qwen2.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 957b6bb3ee4d..7bdf1e65f527 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -9,7 +9,6 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.cache_utils import DynamicCache try: from transformers.modeling_attn_mask_utils import ( @@ -210,8 +209,7 @@ def qwen2_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - + past_key_values[idx] if past_key_values is not None else None if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( @@ -523,7 +521,6 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -649,7 +646,6 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -669,11 +665,9 @@ def forward( else: position_ids = position_ids.view(-1, seq_length).long() - # embed positions hidden_states = inputs_embeds - if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length, seq_length_with_past) @@ -693,7 +687,6 @@ def forward( sliding_window=self.config.sliding_window, ) - if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: logger.warning_once( @@ -746,7 +739,6 @@ def forward( hidden_states = layer_outputs[0] - if output_attentions: all_self_attns += (layer_outputs[1],)