From ccc0c831f58c6f61f965f132a7300d152ae4ae56 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 8 May 2023 15:56:39 +0530
Subject: [PATCH 1/2] add: a warning message when using xformers in a PT 2.0
 env.

---
 src/diffusers/models/attention_processor.py | 29 ++++++++++++++++++----
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 7ac88b17999a..78688ebe97df 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
 
-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)
 
@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "So, we will default to native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ def set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
 
         self.set_processor(processor)
 
@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
 
         self.set_processor(processor)
 

From 3e4846f8515ce7fde2ecfe2df88cca82212de352 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Tue, 9 May 2023 18:05:22 +0530
Subject: [PATCH 2/2] Apply suggestions from code review

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 78688ebe97df..5d242f314184 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -181,7 +181,7 @@ def set_use_memory_efficient_attention_xformers(
             elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                 warnings.warn(
                     "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
-                    "So, we will default to native efficient flash attention implementation provided by PyTorch 2.0."
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
                 )
             else:
                 try:
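
Note (not part of the patch): the new branch keys off hasattr(F, "scaled_dot_product_attention"), which only exists from PyTorch 2.0 onward, so its presence plus self.scale_qk is used as a proxy for "native SDPA is usable". Below is a minimal sketch of how the warning would surface at the user level through the public pipeline API; it assumes a CUDA machine with torch>=2.0, diffusers, and xformers installed, and the model id is only an illustrative example.

    import warnings

    import torch
    from diffusers import DiffusionPipeline

    # Same proxy check the patch relies on: torch 2.0+ exposes this function.
    assert hasattr(torch.nn.functional, "scaled_dot_product_attention")

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # With this patch applied, requesting xFormers attention on torch>=2.0
        # now emits the warning added in PATCH 1/2 (reworded in PATCH 2/2).
        pipe.enable_xformers_memory_efficient_attention()

    for w in caught:
        print(w.message)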