From 719b9713ba0cf30aee3b93fb16209ecbc47105bb Mon Sep 17 00:00:00 2001
From: caiqi
Date: Sat, 25 Feb 2023 00:24:50 +0800
Subject: [PATCH] adapt attention.py to torch 2.0

---
 src/diffusers/models/attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 3cdc7177a411..132ff5334cba 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -144,6 +144,9 @@ def forward(self, hidden_states):
                 query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
             )
             hidden_states = hidden_states.to(query_proj.dtype)
+        elif hasattr(F, "scaled_dot_product_attention"):
+            # Use torch.nn.functional.scaled_dot_product_attention, available in torch >= 2.0
+            hidden_states = F.scaled_dot_product_attention(query_proj, key_proj, value_proj)
         else:
             attention_scores = torch.baddbmm(
                 torch.empty(