From 7b6d3f406b135427d7a1b519095a37bbdeecb61e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 28 Apr 2023 12:58:56 +0200
Subject: [PATCH 1/3] Allow disabling torch 2_0 attention

---
 src/diffusers/models/attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 8e537c6f3680..55ed921e3c35 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -71,6 +71,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
@@ -142,8 +143,9 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers and self._use_2_0_attn
         use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and not self._use_memory_efficient_attention_xformers
+            hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
         )
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)

From 8564d251c72f2850acfafdce709d18a2ca2c7b06 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 28 Apr 2023 13:03:51 +0200
Subject: [PATCH 2/3] make style

---
 src/diffusers/models/attention.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 55ed921e3c35..356af7ee6d5d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -144,9 +144,7 @@ def forward(self, hidden_states):
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
         _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers and self._use_2_0_attn
-        use_torch_2_0_attn = (
-            hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
-        )
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
         key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)

From 917211f313ea57e3d5141a4e6445777aaaccf020 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 28 Apr 2023 12:17:36 +0100
Subject: [PATCH 3/3] Update src/diffusers/models/attention.py

---
 src/diffusers/models/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 356af7ee6d5d..fb5f6f48b324 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -143,7 +143,7 @@ def forward(self, hidden_states):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers and self._use_2_0_attn
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
         use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
 
         query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
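
Usage note (not part of the patches): a minimal sketch of how the flag added by this series could be used. With the series applied, the torch 2.0 scaled_dot_product_attention path can be turned off per block by setting the private _use_2_0_attn attribute to False, falling back to the pre-2.0 batched-matmul attention. Only the attribute name and the module path come from the diff; the class name AttentionBlock and its constructor arguments below are assumptions for illustration.

    # Hedged sketch: AttentionBlock and its constructor arguments are assumed from
    # the module path src/diffusers/models/attention.py; only the _use_2_0_attn
    # flag itself is introduced by this patch series.
    import torch
    from diffusers.models.attention import AttentionBlock

    block = AttentionBlock(channels=64, num_head_channels=32).eval()
    sample = torch.randn(1, 64, 16, 16)

    with torch.no_grad():
        # Default path: uses F.scaled_dot_product_attention on torch >= 2.0.
        out_default = block(sample)

        # Flip the private flag to force the original attention implementation
        # even when torch 2.0 attention is available.
        block._use_2_0_attn = False
        out_fallback = block(sample)

    # Both paths should return the same shape and (nearly) identical values.
    assert out_default.shape == out_fallback.shape
    print(torch.allclose(out_default, out_fallback, atol=1e-5))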