diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8ceee2d1d45a..b6da01b7ef3f 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -955,8 +955,7 @@ def __init__(self, config: LlamaConfig):
         self.layers = nn.ModuleList(
             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self._use_sdpa = config._attn_implementation == "sdpa"
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._attn_implementation = config._attn_implementation
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
@@ -1024,10 +1023,10 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if self._use_flash_attention_2:
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
+        elif self._attn_implementation == "sdpa" and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
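
For context, here is a minimal, self-contained sketch of the dispatch pattern this diff introduces: a single `_attn_implementation` string replaces the two boolean flags when deciding how to prepare the attention mask. The helpers `build_causal_4d_mask` and `select_attention_mask` below are illustrative stand-ins, not the transformers implementation; the real `_prepare_4d_causal_attention_mask_for_sdpa` handles additional cases such as past key/value lengths.

```python
import torch


def build_causal_4d_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Expand a (batch, seq_len) padding mask into a (batch, 1, seq_len, seq_len)
    additive causal mask. Simplified stand-in: no past key/value (cache) handling."""
    _, seq_len = attention_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    padding = attention_mask[:, None, None, :].bool()
    allowed = causal[None, None, :, :] & padding
    zero = torch.zeros((), dtype=dtype)
    neg_inf = torch.full((), torch.finfo(dtype).min, dtype=dtype)
    return torch.where(allowed, zero, neg_inf)


def select_attention_mask(attn_implementation: str, attention_mask, output_attentions: bool = False):
    if attn_implementation == "flash_attention_2":
        # Flash Attention 2 consumes the raw 2D mask; pass None when there is
        # no padding so the kernel can skip masking entirely.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif attn_implementation == "sdpa" and not output_attentions:
        # SDPA cannot return attention weights, so it is only used when
        # output_attentions is False; it expects a 4D additive causal mask.
        return build_causal_4d_mask(attention_mask)
    # Eager attention (or SDPA with output_attentions=True) falls back to the
    # manual implementation, which also takes the 4D causal mask.
    return build_causal_4d_mask(attention_mask)


mask = torch.tensor([[1, 1, 1, 0]])  # batch of 1, last position is padding
print(select_attention_mask("flash_attention_2", mask).shape)  # torch.Size([1, 4])
print(select_attention_mask("sdpa", mask).shape)               # torch.Size([1, 1, 4, 4])
```

Collapsing the two booleans into one string keeps the branch structure identical while making it trivial to add further implementations (e.g. `"eager"`) without introducing another `_use_*` attribute per backend.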