diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 33fea9a0183a..bcc9ac3e0746 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -592,6 +592,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -600,8 +605,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index b516a97187b8..033dc6ba6664 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -428,9 +428,11 @@ def forward(
         key_layer = key_layer.contiguous()
         value_layer = value_layer.contiguous()
 
-        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
-        # mask in case tgt_len == 1.
-        is_causal = self.is_decoder and attention_mask is None and tgt_len > 1
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_layer,
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index d96131d7705e..f85b3d32c1d5 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -588,8 +588,8 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
-        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index e77bc728ab36..8837c278c509 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -788,6 +788,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -796,8 +801,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 76ca4110e818..6359e67bfb86 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -441,16 +441,19 @@ def forward(
 
         if alibi is None:
             if self._use_sdpa and not output_attentions:
-                attn_output = F.scaled_dot_product_attention(
+                # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+                # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+                # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+                # create a causal mask in case query_length == 1.
+                is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
                     query_layer,
                     key_layer,
                     value_layer,
-                    attention_mask,
-                    0.0,
-                    # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
-                    is_causal=self.is_causal and attention_mask is None and query_length > 1,
+                    attn_mask=attention_mask,
+                    dropout_p=0.0,
+                    is_causal=is_causal,
                 )
-
                 attention_scores = None
             else:
                 attention_scores = query_layer @ key_layer.transpose(-1, -2)
@@ -473,13 +476,16 @@ def forward(
 
         else:
             if self._use_sdpa and not output_attentions and head_mask is None:
-                attn_output = F.scaled_dot_product_attention(
+                # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+                # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+                is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
                     query_layer,
                     key_layer,
                     value_layer,
                     attn_mask=attention_mask,
                     dropout_p=self.attention_dropout.p if self.training else 0.0,
-                    is_causal=self.is_causal and attention_mask is None and query_length > 1,
+                    is_causal=is_causal,
                 )
                 attn_output = attn_output.transpose(1, 2)
                 attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 8f7893704780..2a32f7b6a16a 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -568,8 +568,8 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
-        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 37ed2aba6208..7b7bfaf1d423 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -549,14 +549,19 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             key = key.contiguous()
             value = value.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
+        # create a causal mask in case query_length == 1.
+        is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
+
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
             query,
             key,
             value,
             attn_mask=attention_mask,
             dropout_p=self.attn_pdrop if self.training else 0.0,
-            # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
-            is_causal=self.is_causal and attention_mask is None and query_length > 1,
+            is_causal=is_causal,
             scale=scale,
         )
 
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 3d1d0884c6ae..83d90f5ded6b 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -852,6 +852,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -860,8 +865,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 622e336fe403..68ec6ab8a5f0 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -660,14 +660,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        attn_output = nn.functional.scaled_dot_product_attention(
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
-            dropout_p=self.dropout,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index d840b03faf71..12c3d36af690 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -666,8 +666,8 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
-        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 665e95a8fd78..ddfe544ebd75 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -685,14 +685,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index e5a81c4c9083..2bd4527ac622 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -762,14 +762,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 9d1cf6e568f6..90ce7c1f06e8 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -618,6 +618,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -626,8 +631,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 63fc638f1644..fdce73b6148a 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -634,6 +634,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -642,8 +647,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 5009ac84be2e..ff4261fcd210 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -641,8 +641,8 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
-        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         is_causal = True if causal_mask is None and q_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 795ff18e5bcd..1f82b09a2570 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -709,13 +709,17 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index db88d607a365..2ba460234690 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -789,14 +789,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index b5a1370ae1fc..660cf63eefe4 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -690,14 +690,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 838425505b3b..df6db40e1691 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -768,14 +768,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index 60cbb777c795..199d0e543a8b 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -852,6 +852,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -860,8 +865,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index bc133ffb3d73..a8869d1634d4 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -482,14 +482,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout.p if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 61e8518d659c..ac1ed8ecc92f 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -669,14 +669,18 @@ def forward(
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=is_causal,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 5c1557fb1f18..aa4bb7827eb7 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -888,6 +888,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -896,8 +901,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 853c521e5e9c..4663fc05d8ba 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -905,6 +905,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -913,8 +918,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index ec928762b587..5fb64c1f2cf4 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -953,6 +953,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -961,8 +966,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index c0db404e5c88..e4fda437bfa4 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -686,6 +686,11 @@ def forward(
 
         query_states = self._shape(query_states, tgt_len, bsz)
 
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
         # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
         # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
         attn_output = torch.nn.functional.scaled_dot_product_attention(
@@ -694,8 +699,7 @@ def forward(
             value_states,
             attn_mask=attention_mask,
             dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+            is_causal=is_causal,
         )
 
         if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index df585f4afc65..d98c2e632b82 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3963,6 +3963,47 @@ def test_sdpa_can_dispatch_on_flash(self):
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 _ = model(**inputs_dict)
 
+    @require_torch_sdpa
+    @require_torch_gpu
+    @slow
+    def test_sdpa_can_compile_dynamic(self):
+        compute_capability = torch.cuda.get_device_capability()
+        major, _ = compute_capability
+
+        if not torch.version.cuda or major < 8:
+            self.skipTest("This test requires an NVIDIA GPU with compute capability >= 8.0")
+
+        for model_class in self.all_model_classes:
+            if not model_class._supports_sdpa:
+                self.skipTest(f"{model_class.__name__} does not support SDPA")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            if config.model_type in ["dbrx"]:
+                self.skipTest(
+                    "DBRX (transformers==4.40) requires a modification to support dynamic shapes with compile."
+                )
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")
+                model.to(torch_device)
+
+                # For PyTorch 2.1 - 2.3.0 set `dynamic=True`. In the future setting `dynamic=None` and using `torch._dynamo.mark_dynamic()`
+                # on input tensors will be required. `mark_dynamic` currently raises inconsistent shape errors.
+                model = torch.compile(model, dynamic=True)
+
+                inputs_dict.pop("attention_mask", None)
+                inputs_dict.pop("decoder_attention_mask", None)
+                for name, inp in inputs_dict.items():
+                    if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
+                        inputs_dict[name] = inp.to(torch.float16)
+
+                # use no_grad to save some memory
+                with torch.no_grad():
+                    _ = model(**inputs_dict)
+
     @require_torch_sdpa
     @slow
     def test_eager_matches_sdpa_generate(self):
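
For reviewers, here is a minimal standalone sketch of the pattern every hunk above applies: compute `is_causal` as a plain Python bool before the `scaled_dot_product_attention` call rather than inline in the argument list, since (per the new comments) an inline conditional prevents torch.compile's dynamic shapes from compiling. The free function `sdpa_attention` and the tensor shapes are illustrative only and are not part of this PR.

```python
# Hedged sketch of the hoisted `is_causal` dispatch; not code from the diff itself.
import torch
import torch.nn.functional as F


def sdpa_attention(query, key, value, attention_mask=None):
    # q_len > 1 mirrors AttentionMaskConverter.to_causal_4d, which builds no causal
    # mask when the query length is 1 (e.g. single-token decoding with a KV cache).
    q_len = query.shape[2]
    # Hoisted boolean, computed before the SDPA call instead of inline in it.
    is_causal = True if attention_mask is None and q_len > 1 else False
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=is_causal,
    )


# With the boolean hoisted, both dynamic shapes and full-graph capture compile,
# which is what the new test_sdpa_can_compile_dynamic exercises on whole models.
compiled = torch.compile(sdpa_attention, dynamic=True, fullgraph=True)
q = k = v = torch.randn(2, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
print(compiled(q, k, v).shape)  # torch.Size([2, 8, 16, 64])
```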