diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 6701122fc13b..b727c76e2137 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from typing import Callable, Optional, Union
 
 import torch
@@ -72,7 +73,8 @@ def __init__(
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
 
-        self.scale = dim_head**-0.5 if scale_qk else 1.0
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
 
         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -140,7 +142,7 @@ def __init__(
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
             processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
             )
         self.set_processor(processor)
 
@@ -176,6 +178,11 @@ def set_use_memory_efficient_attention_xformers(
                     "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                     " only available for GPU "
                 )
+            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                warnings.warn(
+                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
+                    "We will default to PyTorch's native efficient flash attention implementation provided by PyTorch 2.0."
+                )
             else:
                 try:
                     # Make sure we can run the memory efficient attention
@@ -229,7 +236,15 @@ def set_use_memory_efficient_attention_xformers(
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
             else:
-                processor = AttnProcessor()
+                # set attention processor
+                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+                processor = (
+                    AttnProcessor2_0()
+                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                    else AttnProcessor()
+                )
 
         self.set_processor(processor)
 
@@ -244,7 +259,13 @@ def set_attention_slice(self, slice_size):
         elif self.added_kv_proj_dim is not None:
             processor = AttnAddedKVProcessor()
         else:
-            processor = AttnProcessor()
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
 
         self.set_processor(processor)
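
For reference, a minimal sketch (not part of this patch) of the fallback behavior the change introduces, assuming the enclosing class is `Attention` from `diffusers.models.attention_processor` and PyTorch 2.x is installed; the `query_dim`/`heads`/`dim_head` values below are illustrative only:

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention, AttnProcessor2_0

# Hypothetical standalone attention module, only for illustration; scale_qk=True is the default.
attn = Attention(query_dim=320, heads=8, dim_head=40, scale_qk=True)

# With this patch, turning xformers off (or resetting attention slicing) falls back to the
# torch 2.x processor instead of the plain AttnProcessor, as long as
# F.scaled_dot_product_attention exists and scale_qk is True.
attn.set_use_memory_efficient_attention_xformers(False)
if hasattr(F, "scaled_dot_product_attention"):
    assert isinstance(attn.processor, AttnProcessor2_0)

attn.set_attention_slice(None)
if hasattr(F, "scaled_dot_product_attention"):
    assert isinstance(attn.processor, AttnProcessor2_0)

# Enabling xformers on PyTorch 2.x (with xformers installed and a GPU available)
# now emits the new warning before proceeding:
# attn.set_use_memory_efficient_attention_xformers(True)
```

The warning path is only commented out above because reaching it requires xformers and CUDA; the two fallback branches exercise the same `hasattr(F, "scaled_dot_product_attention") and self.scale_qk` check that the patch adds.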