diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 3cdc7177a411..132ff5334cba 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -144,6 +144,9 @@ def forward(self, hidden_states): query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op ) hidden_states = hidden_states.to(query_proj.dtype) + elif hasattr(F, "scaled_dot_product_attention"): + # use torch.nn.functional.scaled_dot_product_attention when torch 2.x is available + hidden_states = F.scaled_dot_product_attention(query_proj, key_proj, value_proj) else: attention_scores = torch.baddbmm( torch.empty(