From bfcd0ad13372f911e7decb621b16ac6e896d3b59 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 29 May 2023 09:50:25 +0530 Subject: [PATCH 01/19] feat: add lora attention processor for pt 2.0. --- src/diffusers/models/attention_processor.py | 114 +++++++++++++++++--- 1 file changed, 98 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4b65d164bda1..21a348d6d76a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1137,6 +1137,87 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class LoRAAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product + attention. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, 
-2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class CustomDiffusionXFormersAttnProcessor(nn.Module): r""" Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. @@ -1406,22 +1487,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states -AttentionProcessor = Union[ - AttnProcessor, - AttnProcessor2_0, - XFormersAttnProcessor, - SlicedAttnProcessor, - AttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - LoRAAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnAddedKVProcessor, - CustomDiffusionAttnProcessor, - CustomDiffusionXFormersAttnProcessor, -] - - class SpatialNorm(nn.Module): """ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 @@ -1443,3 +1508,20 @@ def forward(self, f, zq): norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f + + +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + LoRAAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, +] From 53c51997fc8ca28e8be83983e4e33d7c295de1db Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 09:06:37 +0530 Subject: [PATCH 02/19] explicit context manager for SDPA. --- src/diffusers/models/attention_processor.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 21a348d6d76a..d3d1a0fafbf6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F from torch import nn +from torch.backends.cuda import SDPBackend, sdp_kernel from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -24,6 +25,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +BACKEND_MAP = { + SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False}, + SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False}, + SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}, +} + if is_xformers_available(): import xformers @@ -1197,9 +1204,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a value = attn.head_to_batch_dim(value) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + with sdp_kernel(**BACKEND_MAP[SDPBackend.EFFICIENT_ATTENTION]): + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj From 00df0a41391b7b57cf40cc143f517863b56b9d5f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 09:13:04 +0530 Subject: [PATCH 03/19] switch to flash attention --- 
src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d3d1a0fafbf6..fc4acfaf0b76 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1204,7 +1204,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a value = attn.head_to_batch_dim(value) # TODO: add support for attn.scale when we move to Torch 2.1 - with sdp_kernel(**BACKEND_MAP[SDPBackend.EFFICIENT_ATTENTION]): + with sdp_kernel(**BACKEND_MAP[SDPBackend.FLASH_ATTENTION]): hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) From 71b8ad27c54183090523745c8356ecb74877c4e4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 09:53:59 +0530 Subject: [PATCH 04/19] make shapes compatible to work optimally with SDPA. --- src/diffusers/models/attention_processor.py | 27 +++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fc4acfaf0b76..fa6cf3be23a4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.backends.cuda import SDPBackend, sdp_kernel +from torch.backends.cuda import SDPBackend from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -1184,13 +1184,18 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query) if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -1200,15 +1205,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # TODO: add support for attn.scale when we move to Torch 2.1 - with sdp_kernel(**BACKEND_MAP[SDPBackend.FLASH_ATTENTION]): - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = 
attn.batch_to_head_dim(hidden_states) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) From bf60598cd9d1f26c6fab48b7c95708b419b3f66c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 10:16:05 +0530 Subject: [PATCH 05/19] fix: circular import problem. --- src/diffusers/models/attention_processor.py | 34 ++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index fd9a7ac55e11..d26e0cdb14e4 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1596,6 +1596,23 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + LoRAAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, +] + + class SpatialNorm(nn.Module): """ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 @@ -1617,20 +1634,3 @@ def forward(self, f, zq): norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f - - -AttentionProcessor = Union[ - AttnProcessor, - AttnProcessor2_0, - XFormersAttnProcessor, - SlicedAttnProcessor, - AttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - LoRAAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - LoRAAttnAddedKVProcessor, - CustomDiffusionAttnProcessor, - CustomDiffusionXFormersAttnProcessor, -] From e5fad840283120fc0d553cb12a5e447a37385377 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 10:47:05 +0530 Subject: [PATCH 06/19] explicitly specify the flash attention kernel in sdpa --- src/diffusers/models/attention_processor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d26e0cdb14e4..b1fb5883240e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.backends.cuda import SDPBackend +from torch.backends.cuda import SDPBackend, sdp_kernel from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -1301,9 +1301,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + with sdp_kernel(**BACKEND_MAP[SDPBackend.FLASH_ATTENTION]): + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = 
hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 151907523e830ec0067d11129b4b9da05ec38e9f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 10:55:47 +0530 Subject: [PATCH 07/19] fall back to efficient attention context manager. --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b1fb5883240e..55d174da70db 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1301,7 +1301,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # TODO: add support for attn.scale when we move to Torch 2.1 - with sdp_kernel(**BACKEND_MAP[SDPBackend.FLASH_ATTENTION]): + with sdp_kernel(**BACKEND_MAP[SDPBackend.EFFICIENT_ATTENTION]): hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) From a193d2659fdb74c13a5f95678405873051b67e0e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 11:14:39 +0530 Subject: [PATCH 08/19] remove explicit dispatch. --- src/diffusers/models/attention_processor.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 55d174da70db..14355dc26ac7 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -17,7 +17,6 @@ import torch import torch.nn.functional as F from torch import nn -from torch.backends.cuda import SDPBackend, sdp_kernel from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -25,12 +24,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -BACKEND_MAP = { - SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False}, - SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False}, - SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True}, -} - if is_xformers_available(): import xformers @@ -1301,10 +1294,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # TODO: add support for attn.scale when we move to Torch 2.1 - with sdp_kernel(**BACKEND_MAP[SDPBackend.EFFICIENT_ATTENTION]): - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) From 7898c11254a458fec6cb1c3bb0e1381d27176633 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 09:59:13 +0530 Subject: [PATCH 09/19] fix: removed processor. 
--- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 14355dc26ac7..12a7685b1688 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1597,6 +1597,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, From 686742745e9f3d1c57c002b71e1eb8af583fda68 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 10:00:19 +0530 Subject: [PATCH 10/19] fix: remove optional from type annotation. --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 12a7685b1688..adc8af06581d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1233,7 +1233,7 @@ class LoRAAttnProcessor2_0(nn.Module): attention. Args: - hidden_size (`int`, *optional*): + hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*): The number of channels in the `encoder_hidden_states`. From 4d3afd2d59660a738c460d3ce59fa43b901531d2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 10:40:39 +0530 Subject: [PATCH 11/19] feat: make changes regarding LoRAAttnProcessor2_0. --- examples/dreambooth/train_dreambooth_lora.py | 5 ++++- src/diffusers/loaders.py | 6 +++++- src/diffusers/models/attention_processor.py | 18 ++++++++++------- tests/models/test_lora_layers.py | 21 +++++++++++++++++--- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 4ff759dcd6d4..bd977fbd95e4 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -55,6 +55,7 @@ AttnAddedKVProcessor2_0, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, + LoRAAttnProcessor2_0, SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler @@ -831,7 +832,9 @@ def main(args): if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): lora_attn_processor_class = LoRAAttnAddedKVProcessor else: - lora_attn_processor_class = LoRAAttnProcessor + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 84e6b4e61f0f..7208b9a5bc9a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union import torch +import torch.nn.functional as F from huggingface_hub import hf_hub_download from .models.attention_processor import ( @@ -27,6 +28,7 @@ CustomDiffusionXFormersAttnProcessor, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, + LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, XFormersAttnProcessor, @@ -284,7 +286,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict if isinstance(attn_processor, (XFormersAttnProcessor, 
LoRAXFormersAttnProcessor)): attn_processor_class = LoRAXFormersAttnProcessor else: - attn_processor_class = LoRAAttnProcessor + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index adc8af06581d..185f8daf9aea 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -166,7 +166,8 @@ def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor) + self.processor, + (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor), ) is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) @@ -202,11 +203,9 @@ def set_use_memory_efficient_attention_xformers( ) elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk: warnings.warn( - "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. " - "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) " - "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall " - "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 " - "native efficient flash attention." + "You have specified using efficient attention using xFormers but you have PyTorch 2.0 already installed. " + "We will default to PyTorch's native efficient attention implementation (`F.scaled_dot_product_attention`) " + "introduced in PyTorch 2.0." ) else: try: @@ -220,6 +219,8 @@ def set_use_memory_efficient_attention_xformers( raise e if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? 
processor = LoRAXFormersAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, @@ -252,7 +253,10 @@ def set_use_memory_efficient_attention_xformers( processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_lora: - processor = LoRAAttnProcessor( + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 64e30ba4057d..849a4d5dc697 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel @@ -27,6 +28,7 @@ AttnProcessor, AttnProcessor2_0, LoRAAttnProcessor, + LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) @@ -45,16 +47,26 @@ def create_unet_lora_layers(unet: nn.Module): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) unet_lora_layers = AttnProcsLayers(lora_attn_procs) return lora_attn_procs, unet_lora_layers def create_text_encoder_lora_layers(text_encoder: nn.Module): text_lora_attn_procs = {} + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) for name, module in text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + text_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=module.out_features, cross_attention_dim=None + ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers @@ -249,7 +261,10 @@ def test_lora_unet_attn_processors(self): # check if lora attention processors are used for _, module in sd_pipe.unet.named_modules(): if isinstance(module, Attention): - self.assertIsInstance(module.processor, LoRAAttnProcessor) + attn_proc_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + self.assertIsInstance(module.processor, attn_proc_class) @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") def test_lora_unet_attn_processors_with_xformers(self): From b694e3f9edab7827535c945263fb4e8d09558980 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:08:23 +0530 Subject: [PATCH 12/19] remove confusing warning. 
--- src/diffusers/models/attention_processor.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 185f8daf9aea..501dc98aafa7 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -201,12 +201,6 @@ def set_use_memory_efficient_attention_xformers( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" " only available for GPU " ) - elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk: - warnings.warn( - "You have specified using efficient attention using xFormers but you have PyTorch 2.0 already installed. " - "We will default to PyTorch's native efficient attention implementation (`F.scaled_dot_product_attention`) " - "introduced in PyTorch 2.0." - ) else: try: # Make sure we can run the memory efficient attention From ffb136d81fc09415ec1264c04168e33f15a0fd2a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:10:10 +0530 Subject: [PATCH 13/19] formatting. --- src/diffusers/models/attention_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 501dc98aafa7..0c418c92c1a9 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import Callable, Optional, Union import torch From 8c304bca97781a21aa0ead74d3bec1cf6f02b279 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:31:12 +0530 Subject: [PATCH 14/19] relax tolerance for PT 2.0 --- tests/models/test_models_unet_3d_condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index 928f6bcbe960..a32098883177 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -261,7 +261,7 @@ def test_lora_save_load(self): with torch.no_grad(): new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - new_sample).abs().max() < 5e-4 # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4 @@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self): with torch.no_grad(): new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - new_sample).abs().max() < 2e-4 # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4 From 9d12c34a3fc376c8e32b1ca849bcb6675f6c5b9d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:45:04 +0530 Subject: [PATCH 15/19] fix: loading message. 
--- examples/dreambooth/train_dreambooth_lora.py | 2 +- src/diffusers/loaders.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index bd977fbd95e4..c2b41b47935b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -835,7 +835,7 @@ def main(args): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - + logger.info(f"Using {lora_attn_processor_class} as the LoRA attention processor class.") unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim ) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7208b9a5bc9a..15981702dfb0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -917,11 +917,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Load the layers corresponding to text encoder and make necessary adjustments. text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] - logger.info(f"Loading {self.text_encoder_name}.") text_encoder_lora_state_dict = { k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {self.text_encoder_name}.") attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) From 3c3c2f7b47bcd436ec8980d9689d93ae57c62d29 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:46:13 +0530 Subject: [PATCH 16/19] remove unnecessary logging. --- examples/dreambooth/train_dreambooth_lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index c2b41b47935b..1c7049182f17 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -835,7 +835,6 @@ def main(args): lora_attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) - logger.info(f"Using {lora_attn_processor_class} as the LoRA attention processor class.") unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim ) From ba3f7ad2f871bb0e381ad02a68b4de68ab12babe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 2 Jun 2023 11:51:11 +0530 Subject: [PATCH 17/19] add: entry to the docs. --- docs/source/en/api/attnprocessor.mdx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/en/api/attnprocessor.mdx b/docs/source/en/api/attnprocessor.mdx index ead639feffe0..7a4812e0961e 100644 --- a/docs/source/en/api/attnprocessor.mdx +++ b/docs/source/en/api/attnprocessor.mdx @@ -11,6 +11,9 @@ An attention processor is a class for applying different types of attention mech ## LoRAAttnProcessor [[autodoc]] models.attention_processor.LoRAAttnProcessor +## LoRAAttnProcessor2_0 +[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0 + ## CustomDiffusionAttnProcessor [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor From 0c764515a5bc524a4697f7f41417f744f2bf076b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 6 Jun 2023 14:29:20 +0530 Subject: [PATCH 18/19] add: network_alpha argument. 
--- src/diffusers/models/attention_processor.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4714481dcff0..e0404a83cc9a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -545,6 +545,8 @@ class LoRAAttnProcessor(nn.Module): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): @@ -840,6 +842,7 @@ class LoRAAttnAddedKVProcessor(nn.Module): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. + """ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): @@ -1159,6 +1162,9 @@ class LoRAXFormersAttnProcessor(nn.Module): [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ def __init__( @@ -1245,9 +1251,11 @@ class LoRAAttnProcessor2_0(nn.Module): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -1256,10 +1264,10 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states From b13c5df909bd1fe8dd996c8f224ceac953c4bb59 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 6 Jun 2023 14:38:03 +0530 Subject: [PATCH 19/19] relax tolerance. 
--- tests/models/test_models_unet_3d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index a32098883177..762c4975da51 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -295,7 +295,7 @@ def test_lora_save_load_safetensors(self): with torch.no_grad(): new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample - new_sample).abs().max() < 2e-4 + assert (sample - new_sample).abs().max() < 3e-4 # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4
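
A minimal usage sketch of the processor this series adds, assuming PyTorch 2.0 and a Stable Diffusion v1-5 UNet: it mirrors the fallback pattern the patches introduce in train_dreambooth_lora.py and loaders.py (prefer LoRAAttnProcessor2_0 when F.scaled_dot_product_attention exists, otherwise LoRAAttnProcessor). The checkpoint name and rank are illustrative assumptions, not values taken from the patches.

import torch.nn.functional as F

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint
)

# Prefer the SDPA-based LoRA processor on PyTorch >= 2.0, otherwise fall back.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim), attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = lora_attn_processor_class(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=4
    )

unet.set_attn_processor(lora_attn_procs)

Note on the design: patches 02-08 experimented with forcing a specific SDPA backend via torch.backends.cuda.sdp_kernel, but the explicit dispatch was removed in the end, so the final processor leaves backend selection (flash, memory-efficient, or math) to PyTorch's own heuristics.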