From 1f8d83a2bde3925df6d558c2831d0586219aa5a9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Dec 2023 08:20:41 +0100 Subject: [PATCH 01/15] some nits --- .../models/mixtral/modeling_mixtral.py | 135 +++++++++++++++++- 1 file changed, 129 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c07346c6de19..8599cd203608 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -33,6 +33,7 @@ from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -661,6 +662,125 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +class MixtralSdpaAttention(MixtralAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
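+                # (SDPA's fused kernels never materialize the attention-probability matrix,
+                # so there is nothing to return for `output_attentions`; hence the eager
+                # fallback below.)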
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
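+            # Conversely, when an explicit 4D mask is passed, causality is already encoded in
+            # that mask, so `is_causal` stays False to avoid masking twice.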
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, +} + + class MixtralBLockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() @@ -679,12 +799,6 @@ def forward(self, hidden_states): return current_hidden_states -MISTRAL_ATTENTION_CLASSES = { - "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, -} - - class MixtralSparseMoeBlock(nn.Module): """ This implementation is @@ -1053,6 +1167,15 @@ def forward( if self._use_flash_attention_2: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( From 8f8d0249fe81bc86c114d119058ab4bd18e41d95 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Dec 2023 08:25:22 +0100 Subject: [PATCH 02/15] update test --- tests/models/mixtral/test_modeling_mixtral.py | 59 +++++++++++++++++-- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index a2d5af00237b..98bea7bfa1a4 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -25,6 +25,7 @@ require_flash_attn, require_torch, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -38,11 +39,7 @@ if is_torch_available(): import torch - from transformers import ( - MixtralForCausalLM, - MixtralForSequenceClassification, - MixtralModel, - ) + from transformers import AutoTokenizer, MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel class MixtralModelTester: @@ -461,6 +458,58 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_padding_right(self): self.skipTest("Mixtral flash attention does not support right padding") + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + max_new_tokens = 30 + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + model_sdpa = MixtralForCausalLM.from_pretrained( + "hf-internal-testing/Mixtral-tiny", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = MixtralForCausalLM.from_pretrained( + "hf-internal-testing/Mixtral-tiny", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA 
attention layers")
+
+        has_sdpa = False
+        for name, submodule in model_sdpa.named_modules():
+            if "SdpaAttention" in submodule.__class__.__name__:
+                has_sdpa = True
+                break
+        if not has_sdpa:
+            raise ValueError("The SDPA model should have SDPA attention layers")
+
+        texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"]
+
+        for padding_side in ["left", "right"]:
+            tokenizer.padding_side = padding_side
+            tokenizer.pad_token = tokenizer.eos_token
+
+            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
+
+            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+
+            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
+            self.assertTrue(torch.allclose(res_eager, res_sdpa))
+
     # Ignore copy
     def test_load_balancing_loss(self):
         r"""

From c3905b421cf86495b98782a64a262559843fb7a2 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 19 Dec 2023 08:28:42 +0100
Subject: [PATCH 03/15] add support for sdpa

---
 src/transformers/models/mixtral/modeling_mixtral.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 8599cd203608..98fd484886d1 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -976,6 +976,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["MixtralDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True

     def _init_weights(self, module):

From 69f6f9d29f55821d8996e51aa6e12804f53988e6 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Tue, 19 Dec 2023 08:30:55 +0100
Subject: [PATCH 04/15] remove some dummy inputs

---
 src/transformers/models/mixtral/modeling_mixtral.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 98fd484886d1..7a1d99bbc9de 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1080,7 +1080,7 @@ def __init__(self, config: MixtralConfig):
         self.layers = nn.ModuleList(
             [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._attn_implementation = "flash_attention_2"
         self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
@@ -1156,7 +1156,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2"and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
                     "You are attempting to perform batched generation with padding_side='right'"
                     " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: + elif self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( From a4c67b2d4630418a2dd471822eb348edeb7f45ab Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Dec 2023 08:32:10 +0100 Subject: [PATCH 05/15] all good --- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- tests/models/mixtral/test_modeling_mixtral.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 7a1d99bbc9de..e27228d4f927 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1080,7 +1080,7 @@ def __init__(self, config: MixtralConfig): self.layers = nn.ModuleList( [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = "flash_attention_2" + self._attn_implementation = config._attn_implementation self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 98bea7bfa1a4..6ec8bf1c4d34 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -470,7 +470,7 @@ def test_eager_matches_sdpa_generate(self): model_sdpa = MixtralForCausalLM.from_pretrained( "hf-internal-testing/Mixtral-tiny", - torch_dtype=torch.float16, + # torch_dtype=torch.float16, low_cpu_mem_usage=True, ).to(torch_device) @@ -478,7 +478,7 @@ def test_eager_matches_sdpa_generate(self): model_eager = MixtralForCausalLM.from_pretrained( "hf-internal-testing/Mixtral-tiny", - torch_dtype=torch.float16, + # torch_dtype=torch.float16, low_cpu_mem_usage=True, attn_implementation="eager", ).to(torch_device) From 2e4fc182ab3fc219b7abda0e4dc41166cc098d5a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 19 Dec 2023 08:32:21 +0100 Subject: [PATCH 06/15] style --- src/transformers/models/mixtral/modeling_mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e27228d4f927..8cd52c5e0cc4 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1156,7 +1156,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2"and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -1165,7 +1165,7 @@ def forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" ) - if self._attn_implementation == "flash_attention_2": + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._attn_implementation == "sdpa" and not output_attentions: From c22839ec7209dee368cac0cca444fc4074b8de8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 11:52:08 +0100 Subject: [PATCH 07/15] nits --- .../models/mistral/modeling_mistral.py | 133 +++++++++++++++++- .../models/mixtral/modeling_mixtral.py | 5 +- tests/models/mixtral/test_modeling_mixtral.py | 55 +------- 3 files changed, 132 insertions(+), 61 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ee51bcea794e..a8a5492e0293 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -31,7 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -610,9 +610,122 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) +class MixtralSdpaAttention(MistralAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + MISTRAL_ATTENTION_CLASSES = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, + "sdpa": MixtralSdpaAttention, } @@ -715,6 +828,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): @@ -820,7 +934,7 @@ def __init__(self, config: MistralConfig): self.layers = nn.ModuleList( [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -891,18 +1005,27 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8cd52c5e0cc4..5d05797dc69a 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -662,7 +662,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class MixtralSdpaAttention(MixtralAttention): +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention2 with Mistral->Mixtral +class MixtralSdpaAttention2(MixtralAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `LlamaAttention` as the weights of the module stays untouched. 
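Inheriting rather than re-implementing means existing checkpoints load into either class unchanged.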
The only changes are on the forward pass to adapt to @@ -874,7 +875,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 6ec8bf1c4d34..efc2321e6998 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -25,7 +25,6 @@ require_flash_attn, require_torch, require_torch_gpu, - require_torch_sdpa, slow, torch_device, ) @@ -39,7 +38,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel + from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel class MixtralModelTester: @@ -458,58 +457,6 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_padding_right(self): self.skipTest("Mixtral flash attention does not support right padding") - @require_torch_sdpa - @slow - def test_eager_matches_sdpa_generate(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - max_new_tokens = 30 - - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - model_sdpa = MixtralForCausalLM.from_pretrained( - "hf-internal-testing/Mixtral-tiny", - # torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).to(torch_device) - - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - - model_eager = MixtralForCausalLM.from_pretrained( - "hf-internal-testing/Mixtral-tiny", - # torch_dtype=torch.float16, - low_cpu_mem_usage=True, - attn_implementation="eager", - ).to(torch_device) - - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - - texts = ["hi", "Hello this is a very long sentence my friend", "Today I am in Paris and"] - - for padding_side in ["left", "right"]: - tokenizer.padding_side = padding_side - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) - - res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - - res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - self.assertTrue(torch.allclose(res_eager, res_sdpa)) - # Ignore copy def test_load_balancing_loss(self): r""" From cc724db0fe37a23b20323c09eeeaf66de4908496 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 11:52:52 +0100 Subject: [PATCH 08/15] fixes --- src/transformers/models/mistral/modeling_mistral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 
a8a5492e0293..f0fe03abccc7 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -610,7 +610,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -class MixtralSdpaAttention(MistralAttention): +class MistralSdpaAttention(MistralAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to @@ -725,7 +725,7 @@ def forward( MISTRAL_ATTENTION_CLASSES = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, - "sdpa": MixtralSdpaAttention, + "sdpa": MistralSdpaAttention, } From b48e0645672e5bcc870cd7ac4f46529c24f966a9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 11:55:04 +0100 Subject: [PATCH 09/15] fix more copies --- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 37 +++---------------- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f0fe03abccc7..a5c4cb3bfe34 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -609,7 +609,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) - +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral class MistralSdpaAttention(MistralAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5d05797dc69a..87327043bd5d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -662,15 +662,15 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) -# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention2 with Mistral->Mixtral -class MixtralSdpaAttention2(MixtralAttention): +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral +class MixtralSdpaAttention(MixtralAttention): """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ - # Adapted from LlamaAttention.forward + # Adapted from MixtralAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -683,7 +683,7 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( @@ -713,31 +713,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) From 373cf1650ee015de96f7f07df1907c7604091959 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 11:58:12 +0100 Subject: [PATCH 10/15] nits --- .../models/mistral/modeling_mistral.py | 33 +++---------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a5c4cb3bfe34..386a40b90aed 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -612,12 +612,12 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral class MistralSdpaAttention(MistralAttention): """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ - # Adapted from LlamaAttention.forward + # Adapted from MistralAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -630,7 +630,7 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( @@ -660,31 +660,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) From d49cec1257db187f97c91249630644b63133e2c0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 12:01:05 +0100 Subject: [PATCH 11/15] styling --- src/transformers/models/mistral/modeling_mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 386a40b90aed..6a84d94169a8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -609,6 +609,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) + # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral class MistralSdpaAttention(MistralAttention): """ From b6e69291201a364fafc2d9f09a88e5ae4ba324f8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 12:01:36 +0100 Subject: [PATCH 12/15] fix --- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 87327043bd5d..3e32a43d94dd 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -750,7 +750,7 @@ def forward( return attn_output, None, past_key_value -MISTRAL_ATTENTION_CLASSES = { +MIXTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, "flash_attention_2": MixtralFlashAttention2, "sdpa": MixtralSdpaAttention, From eeb456b8bd4bf05d94aaa0577c36ddc87bba3895 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 20 Dec 2023 
19:33:04 +0100 Subject: [PATCH 13/15] Update src/transformers/models/mistral/modeling_mistral.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6a84d94169a8..5127a4df32f7 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -986,7 +986,7 @@ def forward( if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) From 8475dc1d6d8bed068866d8c588924cfcd8b78ee4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 20 Dec 2023 19:48:13 +0100 Subject: [PATCH 14/15] add a slow test just to be sure --- tests/models/mistral/test_modeling_mistral.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 35a2341b4e69..41ee08948471 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -28,6 +28,7 @@ require_flash_attn, require_torch, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -527,3 +528,45 @@ def test_model_7b_long_prompt(self): del model backend_empty_cache(torch_device) gc.collect() + + @slow + @require_torch_sdpa + def test_model_7b_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", + device_map="auto", + attn_implementation="sdpa", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + + backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. 
I’m not a big""" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() From c932e141f3e51a3e5d2b610b1e3f908f6c028383 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 21 Dec 2023 08:51:46 +0100 Subject: [PATCH 15/15] fixup --- tests/models/mistral/test_modeling_mistral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 1a5028944dee..5e91e70ecd5b 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -564,6 +564,8 @@ def test_model_7b_long_prompt_sdpa(self): # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow def test_speculative_generation(self):
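
For reference, a minimal sketch of what this series enables end to end. It assumes a CUDA device and borrows the tiny checkpoint and (gated) Llama tokenizer pairing from the removed tiny-model test above, so it illustrates the API rather than a recommended setup:

import torch
from transformers import AutoTokenizer, MixtralForCausalLM

# With this series applied, "sdpa" is an accepted `attn_implementation` for
# Mistral/Mixtral (and should be picked automatically on recent torch versions
# when nothing else is requested).
model = MixtralForCausalLM.from_pretrained(
    "hf-internal-testing/Mixtral-tiny",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")
assert model.config._attn_implementation == "sdpa"

# Checkpoint/tokenizer pairing taken from the tiny-model test in PATCH 02.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Today I am in Paris and"], return_tensors="pt", padding=True).to("cuda")

# Greedy decoding: SDPA and eager should agree token for token, as the tests assert.
out = model.generate(**inputs, max_new_tokens=30, do_sample=False)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])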