From 1571285dc7bf82172858f75fbb16ec218694bb65 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 12 Jun 2024 08:23:19 +0000
Subject: [PATCH 1/3] upgrade transformers for mistral

---
 colossalai/shardformer/modeling/mistral.py | 33 ++++++++++++++++++----
 colossalai/shardformer/policies/mistral.py |  2 ++
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index 5f96ebe3d5cd..ba92381da866 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -4,7 +4,10 @@
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -77,7 +80,7 @@ def mistral_model_forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -97,15 +100,24 @@ def mistral_model_forward(
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
                     attention_mask,
                     (batch_size, seq_length),
-                    hidden_states,
+                    inputs_embeds,
                     past_key_values_length,
                     sliding_window=self.config.sliding_window,
                 )
@@ -462,7 +474,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -481,9 +493,18 @@ def forward(
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index 621982f29058..c5a0277a5783 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -42,11 +42,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             MistralDecoderLayer,
             MistralFlashAttention2,
             MistralModel,
+            MistralSdpaAttention,
         )
 
         ATTN_IMPLEMENTATION = {
             "eager": MistralAttention,
             "flash_attention_2": MistralFlashAttention2,
+            "sdpa": MistralSdpaAttention,
         }
 
         policy = {}

From c863428717c6a7daa481537ed6036836c2480e7e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 12 Jun 2024 08:24:34 +0000
Subject: [PATCH 2/3] fix

---
 requirements/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index d30b26dbc787..5d8eef26f6ec 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,7 +16,7 @@ ray
 sentencepiece
 google
 protobuf
-transformers==4.36.2
+transformers==4.39.3
 peft>=0.7.1
 bitsandbytes>=0.39.0
 rpyc==6.0.0

From 5e85ac87285f99544a23b2433c1ac1a8ac56268e Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 12 Jun 2024 08:30:25 +0000
Subject: [PATCH 3/3] fix

---
 colossalai/shardformer/modeling/mistral.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index ba92381da866..310c2d8e233a 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -117,7 +117,7 @@ def mistral_model_forward(
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             (batch_size, seq_length),
-            inputs_embeds,
+            hidden_states,
             past_key_values_length,
             sliding_window=self.config.sliding_window,
         )
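
Illustrative note (not part of the patches): the new "sdpa" branch in PATCH 1/3 prepares its mask with _prepare_4d_causal_attention_mask_for_sdpa instead of the eager helper. The sketch below, assuming transformers==4.39.3 as pinned in PATCH 2/3 and arbitrary tensor sizes, shows the practical difference between the two helpers.

# Sketch of the two mask helpers used in the patched forward; sizes are arbitrary.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden_size = 2, 8, 16
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # no padding
past_key_values_length = 0

# Eager path: always materializes a (batch, 1, q_len, kv_len) float mask.
eager_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(eager_mask.shape)  # torch.Size([2, 1, 8, 8])

# SDPA path: for fully visible (unpadded) inputs it can return None, letting
# torch.nn.functional.scaled_dot_product_attention take its fused causal path.
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(sdpa_mask)  # typically None on this unpadded path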
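
The MistralSdpaAttention entry added to ATTN_IMPLEMENTATION in colossalai/shardformer/policies/mistral.py is looked up by the same key that transformers stores on the model config. A minimal sketch of that relationship, again assuming transformers==4.39.3 and a PyTorch build with SDPA support; the tiny config values are arbitrary and only keep the example light.

# Sketch: config._attn_implementation is the key the policy's ATTN_IMPLEMENTATION
# table uses; transformers instantiates the matching attention class per layer.
from transformers import AutoModelForCausalLM, MistralConfig

config = MistralConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")

print(model.config._attn_implementation)               # "sdpa"
print(type(model.model.layers[0].self_attn).__name__)  # MistralSdpaAttention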