From 3dbcdd04517d5a1bf6f17b8d20e14e8d230056ee Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 May 2025 17:40:21 +0800 Subject: [PATCH 1/3] upgrade mistral --- colossalai/shardformer/modeling/mistral.py | 235 ++++++++---------- colossalai/shardformer/policies/mistral.py | 19 +- .../test_model/test_shard_mistral.py | 3 +- 3 files changed, 109 insertions(+), 148 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 7fc6a1062037..5014281e8960 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -36,7 +36,7 @@ def mistral_model_forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -50,8 +50,6 @@ def mistral_model_forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -71,16 +69,19 @@ def mistral_model_forward( past_key_values_length = 0 + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = cache_position.unsqueeze(0) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -100,26 +101,8 @@ def mistral_model_forward( is_causal=True, ) else: - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - hidden_states, - past_key_values_length, - sliding_window=self.config.sliding_window, + attention_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions ) if self.gradient_checkpointing and self.training: @@ -133,6 +116,8 @@ def mistral_model_forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + position_embeddings = self.rotary_emb(hidden_states, position_ids) + start_idx, end_idx = stage_index[0], stage_index[1] num_ckpt_layers = 0 if self.gradient_checkpointing and self.training: @@ -155,21 +140,25 @@ def mistral_model_forward( if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] @@ -189,8 +178,6 @@ def mistral_model_forward( next_cache = None if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -212,7 +199,8 @@ def mistral_for_causal_lm_forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -248,7 +236,6 @@ def mistral_for_causal_lm_forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = MistralForwards.mistral_model_forward( @@ -261,7 +248,7 @@ def mistral_for_causal_lm_forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + cache_position=cache_position, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -278,9 +265,6 @@ def mistral_for_causal_lm_forward( if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, @@ -305,7 +289,6 @@ def mistral_for_sequence_classification_forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -317,7 +300,6 @@ def mistral_for_sequence_classification_forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = MistralForwards.mistral_model_forward( self.model, @@ -329,7 +311,6 @@ def mistral_for_sequence_classification_forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, @@ -383,9 +364,6 @@ def mistral_for_sequence_classification_forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output else: hidden_states = transformer_outputs.get("hidden_states") return {"hidden_states": hidden_states} @@ -413,7 +391,8 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -421,8 +400,6 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -433,27 +410,22 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - past_key_values_length = 0 + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -471,31 +443,11 @@ def forward( q_padding_mask=attention_mask, is_causal=True, ) - else: - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -506,37 +458,49 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + # for decoder_layer in self.layers: + # if output_hidden_states: + # all_hidden_states += (hidden_states,) + + # if self.gradient_checkpointing and self.training: + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # attention_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # ) + # else: + # layer_outputs = decoder_layer( + # hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_values, + # output_attentions=output_attentions, + # use_cache=use_cache, + # ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -546,15 +510,12 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) @@ -568,11 +529,10 @@ def get_mistral_flash_attention_forward(shard_config: ShardConfig): def forward( self: MistralAttention, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: @@ -585,9 +545,9 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -598,11 +558,12 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -613,11 +574,11 @@ def forward( attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None return forward diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index f9c9a9404e72..63c87f31354a 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -41,20 +41,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - MistralFlashAttention2, + # MistralFlashAttention2, MistralModel, - MistralSdpaAttention, + # MistralSdpaAttention, ) - ATTN_IMPLEMENTATION = { - "eager": MistralAttention, - "flash_attention_2": MistralFlashAttention2, - "sdpa": MistralSdpaAttention, - } + # ATTN_IMPLEMENTATION = { + # "eager": MistralAttention, + # "flash_attention_2": MistralFlashAttention2, + # "sdpa": MistralSdpaAttention, + # } policy = {} - attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] + # attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] embedding_cls = None if self.shard_config.enable_tensor_parallelism: @@ -258,7 +258,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "forward": get_mistral_flash_attention_forward(self.shard_config), }, policy=policy, - target_key=attn_cls, + target_key=MistralAttention, ) if self.pipeline_stage_manager is None: # replace llama model forward method @@ -316,6 +316,7 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.layers)) diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index deced9d56507..3c3b26777532 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -22,7 +22,7 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" - +@clear_cache_before_run() def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config @@ -176,7 +176,6 @@ def check_mistral(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() -@clear_cache_before_run() def test_mistral(): spawn(check_mistral, 4) From 7c36b5bd05a95d7ce070cc4e33166e870c4b3b4b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 May 2025 17:42:25 +0800 Subject: [PATCH 2/3] fix --- colossalai/shardformer/modeling/mistral.py | 23 ---------------------- colossalai/shardformer/policies/mistral.py | 10 ---------- 2 files changed, 33 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 5014281e8960..68468d5e0b64 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -459,29 +459,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - # for decoder_layer in self.layers: - # if output_hidden_states: - # all_hidden_states += (hidden_states,) - - # if self.gradient_checkpointing and self.training: - # layer_outputs = self._gradient_checkpointing_func( - # decoder_layer.__call__, - # hidden_states, - # attention_mask, - # position_ids, - # past_key_values, - # output_attentions, - # use_cache, - # ) - # else: - # layer_outputs = decoder_layer( - # hidden_states, - # attention_mask=attention_mask, - # position_ids=position_ids, - # past_key_value=past_key_values, - # output_attentions=output_attentions, - # use_cache=use_cache, - # ) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 63c87f31354a..ad49b1ad966d 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -41,21 +41,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - # MistralFlashAttention2, MistralModel, - # MistralSdpaAttention, ) - # ATTN_IMPLEMENTATION = { - # "eager": MistralAttention, - # "flash_attention_2": MistralFlashAttention2, - # "sdpa": MistralSdpaAttention, - # } - policy = {} - # attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation] - embedding_cls = None if self.shard_config.enable_tensor_parallelism: embedding_cls = VocabParallelEmbedding1D From 7a4d3685a055c1b0e2b7b2c11112331ae1b6b2db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 May 2025 09:49:04 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/mistral.py | 38 ++++++++----------- colossalai/shardformer/policies/mistral.py | 6 +-- .../test_model/test_shard_mistral.py | 1 + 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 68468d5e0b64..510cea066050 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -4,10 +4,6 @@ import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -65,18 +61,18 @@ def mistral_model_forward( else: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape - device = hidden_states.device + hidden_states.device past_key_values_length = 0 if use_cache and past_key_values is None: - past_key_values = DynamicCache() + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device - ) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -103,7 +99,7 @@ def mistral_model_forward( else: attention_mask = self._update_causal_mask( attention_mask, hidden_states, cache_position, past_key_values, output_attentions - ) + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -151,14 +147,14 @@ def mistral_model_forward( ) else: layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] @@ -265,7 +261,6 @@ def mistral_for_causal_lm_forward( if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype) - return CausalLMOutputWithPast( loss=loss, logits=logits, @@ -459,7 +454,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -487,7 +481,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - # add hidden states from the last decoder layer + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index ad49b1ad966d..b154723a2788 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -38,11 +38,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from transformers.models.mistral.modeling_mistral import ( - MistralAttention, - MistralDecoderLayer, - MistralModel, - ) + from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel policy = {} diff --git a/tests/test_shardformer/test_model/test_shard_mistral.py b/tests/test_shardformer/test_model/test_shard_mistral.py index 3c3b26777532..11b61721cc14 100644 --- a/tests/test_shardformer/test_model/test_shard_mistral.py +++ b/tests/test_shardformer/test_model/test_shard_mistral.py @@ -22,6 +22,7 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + @clear_cache_before_run() def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(