From b42e787a2e129fe4eb49f1825e6ef6b5c112cd65 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 18:23:27 +0100 Subject: [PATCH 1/7] rename file --- .../models/{cross_attention.py => attention_processor.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/diffusers/models/{cross_attention.py => attention_processor.py} (100%) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/attention_processor.py similarity index 100% rename from src/diffusers/models/cross_attention.py rename to src/diffusers/models/attention_processor.py From 2c8e2ed5cb086582ae07712f6df81ebbec4ec0b8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 18:27:19 +0100 Subject: [PATCH 2/7] rename attention --- docs/source/en/optimization/torch2.0.mdx | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- .../train_dreambooth_inpaint_lora.py | 2 +- .../lora/train_text_to_image_lora.py | 2 +- .../text_to_image/train_text_to_image_lora.py | 2 +- src/diffusers/loaders.py | 2 +- src/diffusers/models/attention.py | 6 ++--- src/diffusers/models/attention_flax.py | 6 ++--- src/diffusers/models/attention_processor.py | 24 ++++++++----------- src/diffusers/models/controlnet.py | 4 ++-- src/diffusers/models/dual_transformer_2d.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 12 +++++----- src/diffusers/models/unet_2d_condition.py | 4 ++-- ...line_stable_diffusion_attend_and_excite.py | 4 ++-- .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++-- .../versatile_diffusion/modeling_text_unet.py | 8 +++---- tests/models/test_models_unet_2d_condition.py | 2 +- 17 files changed, 42 insertions(+), 46 deletions(-) diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx index a55ac3634522..bf00c1dd408c 100644 --- a/docs/source/en/optimization/torch2.0.mdx +++ b/docs/source/en/optimization/torch2.0.mdx @@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl ```Python import torch from diffusers import StableDiffusionPipeline - from diffusers.models.cross_attention import AttnProcessor2_0 + from diffusers.models.attention_processor import AttnProcessor2_0 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") pipe.unet.set_attn_processor(AttnProcessor2_0()) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 92d08b64b638..8de512462ad5 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -47,7 +47,7 @@ UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers -from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index 4acc6e501b32..20c96dc7fcf2 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -22,7 +22,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 2d0f807bdff3..00dbc59c8603 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -43,7 +43,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ba4450f7d82c..e48a6e16e0de 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -41,7 +41,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRACrossAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index eaf6e594a278..251a28be58ce 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -17,7 +17,7 @@ import torch -from .models.cross_attention import LoRACrossAttnProcessor +from .models.attention_processor import LoRACrossAttnProcessor from .models.modeling_utils import _get_model_file from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6da318e65593..aa10bdd0e952 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from torch import nn from ..utils.import_utils import is_xformers_available -from .cross_attention import CrossAttention +from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings @@ -220,7 +220,7 @@ def __init__( ) # 1. Self-Attn - self.attn1 = CrossAttention( + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, @@ -234,7 +234,7 @@ def __init__( # 2. Cross-Attn if cross_attention_dim is not None: - self.attn2 = CrossAttention( + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index b0717356bec1..1a47d728c2f9 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -16,7 +16,7 @@ import jax.numpy as jnp -class FlaxCrossAttention(nn.Module): +class FlaxAttention(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 @@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention - self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a0ecfb0f406d..69eb46a02e25 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -31,7 +31,7 @@ xformers = None -class CrossAttention(nn.Module): +class Attention(nn.Module): r""" A cross attention layer. @@ -204,7 +204,7 @@ def set_processor(self, processor: "AttnProcessor"): self.processor = processor def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): - # The `CrossAttention` class can call different attention processors / attention functions + # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty return self.processor( @@ -295,7 +295,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None) class CrossAttnProcessor: def __call__( self, - attn: CrossAttention, + attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, @@ -366,9 +366,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - def __call__( - self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 - ): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -398,7 +396,7 @@ def __call__( class CrossAttnAddedKVProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -443,7 +441,7 @@ class XFormersCrossAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -482,7 +480,7 @@ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -539,9 +537,7 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - def __call__( - self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 - ): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -575,7 +571,7 @@ class SlicedAttnProcessor: def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -628,7 +624,7 @@ class SlicedAttnAddedKVProcessor: def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 5895ae4de5b9..8923c238fcd4 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .cross_attention import AttnProcessor +from .attention_processor import AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -343,7 +343,7 @@ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProce Parameters: `processor (`dict` of `AttnProcessor` or `AttnProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. + of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: """ diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index 8b805c98147c..3db7e73ca6af 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -114,7 +114,7 @@ def forward( timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. attention_mask (`torch.FloatTensor`, *optional*): - Optional attention mask to be applied in CrossAttention + Optional attention mask to be applied in Attention return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 8269b77f54d4..5453530b00bb 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor +from .attention_processor import Attention, CrossAttnAddedKVProcessor from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -591,7 +591,7 @@ def __init__( for _ in range(num_layers): attentions.append( - CrossAttention( + Attention( query_dim=in_channels, cross_attention_dim=in_channels, heads=self.num_heads, @@ -1365,7 +1365,7 @@ def __init__( ) ) attentions.append( - CrossAttention( + Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, @@ -2358,7 +2358,7 @@ def __init__( ) ) attentions.append( - CrossAttention( + Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, @@ -2677,7 +2677,7 @@ def __init__( # 1. Self-Attn if add_self_attention: self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn1 = CrossAttention( + self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, @@ -2689,7 +2689,7 @@ def __init__( # 2. Cross-Attn self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn2 = CrossAttention( + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 24e827a328f7..ac43a60c79ce 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging -from .cross_attention import AttnProcessor +from .attention_processor import AttnProcessor from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -390,7 +390,7 @@ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProce Parameters: `processor (`dict` of `AttnProcessor` or `AttnProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. + of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: """ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 028b7390e906..447855a5be2c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -22,7 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.cross_attention import CrossAttention +from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -127,7 +127,7 @@ def __init__(self, attnstore, place_in_unet): self.attnstore = attnstore self.place_in_unet = place_in_unet - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 0e58701d93a7..29899edc3117 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -29,7 +29,7 @@ ) from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.cross_attention import CrossAttention +from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...utils import ( @@ -229,7 +229,7 @@ def __init__(self, is_pix2pix_zero=False): def __call__( self, - attn: CrossAttention, + attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 76bfdc4313ca..057455523f99 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -6,8 +6,8 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin -from ...models.attention import CrossAttention -from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor +from ...models.attention import Attention +from ...models.attention_processor import AttnProcessor, CrossAttnAddedKVProcessor from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel @@ -480,7 +480,7 @@ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProce Parameters: `processor (`dict` of `AttnProcessor` or `AttnProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. + of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: """ @@ -1425,7 +1425,7 @@ def __init__( for _ in range(num_layers): attentions.append( - CrossAttention( + Attention( query_dim=in_channels, cross_attention_dim=in_channels, heads=self.num_heads, diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index e313fcfb0b29..b79d97c82b28 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -22,7 +22,7 @@ from parameterized import parameterized from diffusers import UNet2DConditionModel -from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor +from diffusers.models.attention_processor import CrossAttnProcessor, LoRACrossAttnProcessor from diffusers.utils import ( floats_tensor, load_hf_numpy, From c7a74fe89191a82e97c17a700875de80469c8e4b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 18:42:05 +0100 Subject: [PATCH 3/7] fix more --- .../stable_diffusion_controlnet_img2img.py | 2 +- .../stable_diffusion_controlnet_inpaint.py | 2 +- ...le_diffusion_controlnet_inpaint_img2img.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 6 ++-- .../train_dreambooth_inpaint_lora.py | 6 ++-- .../lora/train_text_to_image_lora.py | 6 ++-- .../text_to_image/train_text_to_image_lora.py | 6 ++-- src/diffusers/loaders.py | 4 +-- src/diffusers/models/attention_processor.py | 32 +++++++++---------- src/diffusers/models/controlnet.py | 10 +++--- src/diffusers/models/unet_2d_condition.py | 12 +++---- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- ...line_stable_diffusion_attend_and_excite.py | 8 ++--- .../pipeline_stable_diffusion_controlnet.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 6 ++-- .../pipeline_stable_diffusion_sag.py | 2 +- .../pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 12 +++---- tests/models/test_models_unet_2d_condition.py | 22 ++++--------- 22 files changed, 66 insertions(+), 84 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 2b3a45a9b787..60f16c820b9b 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -713,7 +713,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index 98d73bb4ac02..cbe93fb6b995 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -868,7 +868,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index acddb71e74de..88dcc26b61d9 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -911,7 +911,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 8de512462ad5..daef268ff8f3 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -47,7 +47,7 @@ UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -723,9 +723,7 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index 20c96dc7fcf2..e415e6965317 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -22,7 +22,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available @@ -561,9 +561,7 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 00dbc59c8603..a53af7bcffd2 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -43,7 +43,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -536,9 +536,7 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index e48a6e16e0de..43bbd8ebf415 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -41,7 +41,7 @@ import diffusers from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRACrossAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -474,9 +474,7 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 251a28be58ce..9848ce7988c3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -17,7 +17,7 @@ import torch -from .models.attention_processor import LoRACrossAttnProcessor +from .models.attention_processor import LoRAAttnProcessor from .models.modeling_utils import _get_model_file from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging @@ -207,7 +207,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processors[key] = LoRACrossAttnProcessor( + attn_processors[key] = LoRAAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank ) attn_processors[key].load_state_dict(value_dict) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 69eb46a02e25..798664d2b5ed 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -106,7 +106,7 @@ def __init__( # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor() + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor() ) self.set_processor(processor) @@ -114,7 +114,7 @@ def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) + self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) ) if use_memory_efficient_attention_xformers: @@ -151,7 +151,7 @@ def set_use_memory_efficient_attention_xformers( raise e if is_lora: - processor = LoRAXFormersCrossAttnProcessor( + processor = LoRAXFormersAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, @@ -160,10 +160,10 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) else: - processor = XFormersCrossAttnProcessor(attention_op=attention_op) + processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_lora: - processor = LoRACrossAttnProcessor( + processor = LoRAAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, @@ -171,7 +171,7 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) else: - processor = CrossAttnProcessor() + processor = AttnProcessor() self.set_processor(processor) @@ -186,7 +186,7 @@ def set_attention_slice(self, slice_size): elif self.added_kv_proj_dim is not None: processor = CrossAttnAddedKVProcessor() else: - processor = CrossAttnProcessor() + processor = AttnProcessor() self.set_processor(processor) @@ -292,7 +292,7 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None) return attention_mask -class CrossAttnProcessor: +class AttnProcessor: def __call__( self, attn: Attention, @@ -353,7 +353,7 @@ def forward(self, hidden_states): return up_hidden_states.to(orig_dtype) -class LoRACrossAttnProcessor(nn.Module): +class LoRAAttnProcessor(nn.Module): def __init__(self, hidden_size, cross_attention_dim=None, rank=4): super().__init__() @@ -437,7 +437,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class XFormersCrossAttnProcessor: +class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op @@ -523,7 +523,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class LoRAXFormersCrossAttnProcessor(nn.Module): +class LoRAXFormersAttnProcessor(nn.Module): def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): super().__init__() @@ -684,12 +684,12 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states -AttnProcessor = Union[ - CrossAttnProcessor, - XFormersCrossAttnProcessor, +AttnProcessors = Union[ + AttnProcessor, + XFormersAttnProcessor, SlicedAttnProcessor, CrossAttnAddedKVProcessor, SlicedAttnAddedKVProcessor, - LoRACrossAttnProcessor, - LoRAXFormersCrossAttnProcessor, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, ] diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 8923c238fcd4..4fd99cbeb686 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .attention_processor import AttnProcessor +from .attention_processor import AttnProcessors from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -314,7 +314,7 @@ def from_unet( @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttnProcessor]: + def attn_processors(self) -> Dict[str, AttnProcessors]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -323,7 +323,7 @@ def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -338,10 +338,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): r""" Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + `processor (`dict` of `AttnProcessors` or `AttnProcessors`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ac43a60c79ce..50de1ad406e8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging -from .attention_processor import AttnProcessor +from .attention_processor import AttnProcessors from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -362,7 +362,7 @@ def __init__( ) @property - def attn_processors(self) -> Dict[str, AttnProcessor]: + def attn_processors(self) -> Dict[str, AttnProcessors]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -371,7 +371,7 @@ def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -385,10 +385,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): r""" Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + `processor (`dict` of `AttnProcessors` or `AttnProcessors`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: @@ -505,7 +505,7 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 71e98480ed2d..175da9c4943c 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -585,7 +585,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 504479798617..ee5ba1f172a3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -588,7 +588,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 447855a5be2c..5d8b473d05f1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -121,7 +121,7 @@ def __init__(self, attn_res=16): self.attn_res = attn_res -class AttendExciteCrossAttnProcessor: +class AttendExciteAttnProcessor: def __init__(self, attnstore, place_in_unet): super().__init__() self.attnstore = attnstore @@ -679,9 +679,7 @@ def register_attention_control(self): continue cross_att_count += 1 - attn_procs[name] = AttendExciteCrossAttnProcessor( - attnstore=self.attention_store, place_in_unet=place_in_unet - ) + attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet) self.unet.set_attn_processor(attn_procs) self.attention_store.num_att_layers = cross_att_count @@ -777,7 +775,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). max_iter_to_alter (`int`, *optional*, defaults to `25`): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 08643c6b891a..042e02406e1c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -789,7 +789,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index b1f29fbef12b..7c853219659f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -525,7 +525,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 29899edc3117..7de12bd291fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -200,10 +200,10 @@ def prepare_unet(unet: UNet2DConditionModel): module_name = name.replace(".processor", "") module = unet.get_submodule(module_name) if "attn2" in name: - pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True) + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True) module.requires_grad_(True) else: - pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False) + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False) module.requires_grad_(False) unet.set_attn_processor(pix2pix_zero_attn_procs) @@ -218,7 +218,7 @@ def compute_loss(self, predictions, targets): self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) -class Pix2PixZeroCrossAttnProcessor: +class Pix2PixZeroAttnProcessor: """An attention processor class to store the attention weights. In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index afb473105512..493fbcd9ba13 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -530,7 +530,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 94780c9eb260..de0122251c43 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -684,7 +684,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index e98dfd6f0d3a..f2d2ba479650 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -653,7 +653,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 057455523f99..b0ed57a81172 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -7,7 +7,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.attention import Attention -from ...models.attention_processor import AttnProcessor, CrossAttnAddedKVProcessor +from ...models.attention_processor import AttnProcessors, CrossAttnAddedKVProcessor from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel @@ -452,7 +452,7 @@ def __init__( ) @property - def attn_processors(self) -> Dict[str, AttnProcessor]: + def attn_processors(self) -> Dict[str, AttnProcessors]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -461,7 +461,7 @@ def attn_processors(self) -> Dict[str, AttnProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -475,10 +475,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): + def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): r""" Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): + `processor (`dict` of `AttnProcessors` or `AttnProcessors`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: @@ -595,7 +595,7 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index b79d97c82b28..24707df9d94d 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -22,7 +22,7 @@ from parameterized import parameterized from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import CrossAttnProcessor, LoRACrossAttnProcessor +from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor from diffusers.utils import ( floats_tensor, load_hf_numpy, @@ -54,9 +54,7 @@ def create_lora_layers(model): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) # add 1 to weights to mock trained weights @@ -119,7 +117,7 @@ def test_xformers_enable_works(self): assert ( model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersCrossAttnProcessor" + == "XFormersAttnProcessor" ), "xformers is not enabled" @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") @@ -324,9 +322,7 @@ def test_lora_processors(self): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) # add 1 to weights to mock trained weights with torch.no_grad(): @@ -413,9 +409,7 @@ def test_lora_save_load_safetensors(self): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) # add 1 to weights to mock trained weights @@ -468,9 +462,7 @@ def test_lora_save_load_safetensors_load_torch(self): block_id = int(name[len("down_blocks.")]) hidden_size = model.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) lora_attn_procs[name] = lora_attn_procs[name].to(model.device) model.set_attn_processor(lora_attn_procs) @@ -502,7 +494,7 @@ def test_lora_on_off(self): with torch.no_grad(): sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample - model.set_attn_processor(CrossAttnProcessor()) + model.set_attn_processor(AttnProcessor()) with torch.no_grad(): new_sample = model(**inputs_dict).sample From ea8842cd6c9b032f812f2d232f423e82cbfe6bfd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 18:43:50 +0100 Subject: [PATCH 4/7] rename more --- .../community/stable_diffusion_controlnet_img2img.py | 2 +- .../community/stable_diffusion_controlnet_inpaint.py | 2 +- .../stable_diffusion_controlnet_inpaint_img2img.py | 2 +- src/diffusers/models/attention_processor.py | 2 +- src/diffusers/models/controlnet.py | 10 +++++----- src/diffusers/models/unet_2d_condition.py | 12 ++++++------ .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_attend_and_excite.py | 2 +- .../pipeline_stable_diffusion_controlnet.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_sag.py | 2 +- .../stable_diffusion/pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 12 ++++++------ 15 files changed, 29 insertions(+), 29 deletions(-) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index 60f16c820b9b..5aa5e47c6578 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -713,7 +713,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index cbe93fb6b995..02e71fb97ed1 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -868,7 +868,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index 88dcc26b61d9..a7afe26fa91c 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -911,7 +911,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 798664d2b5ed..8d1a3906db23 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -684,7 +684,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states -AttnProcessors = Union[ +AttentionProcessor = Union[ AttnProcessor, XFormersAttnProcessor, SlicedAttnProcessor, diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 4fd99cbeb686..0d59605fe046 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging -from .attention_processor import AttnProcessors +from .attention_processor import AttentionProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -314,7 +314,7 @@ def from_unet( @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttnProcessors]: + def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -323,7 +323,7 @@ def attn_processors(self) -> Dict[str, AttnProcessors]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -338,10 +338,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Parameters: - `processor (`dict` of `AttnProcessors` or `AttnProcessors`): + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 50de1ad406e8..8cd3dcf42307 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging -from .attention_processor import AttnProcessors +from .attention_processor import AttentionProcessor from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( @@ -362,7 +362,7 @@ def __init__( ) @property - def attn_processors(self) -> Dict[str, AttnProcessors]: + def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -371,7 +371,7 @@ def attn_processors(self) -> Dict[str, AttnProcessors]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -385,10 +385,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Parameters: - `processor (`dict` of `AttnProcessors` or `AttnProcessors`): + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: @@ -505,7 +505,7 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 175da9c4943c..b94a2ec05649 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -585,7 +585,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ee5ba1f172a3..5294fa4cfa06 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -588,7 +588,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 5d8b473d05f1..2d32c0ba8b62 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -775,7 +775,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). max_iter_to_alter (`int`, *optional*, defaults to `25`): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 042e02406e1c..fd82281005ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -789,7 +789,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 7c853219659f..3fea4c2d83bb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -525,7 +525,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 493fbcd9ba13..b24354a8e568 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -530,7 +530,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index de0122251c43..a8ba0b504628 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -684,7 +684,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index f2d2ba479650..99caa8be65a5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -653,7 +653,7 @@ def __call__( The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index b0ed57a81172..28c9504ef31e 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -7,7 +7,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.attention import Attention -from ...models.attention_processor import AttnProcessors, CrossAttnAddedKVProcessor +from ...models.attention_processor import AttentionProcessor, CrossAttnAddedKVProcessor from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel @@ -452,7 +452,7 @@ def __init__( ) @property - def attn_processors(self) -> Dict[str, AttnProcessors]: + def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with @@ -461,7 +461,7 @@ def attn_processors(self) -> Dict[str, AttnProcessors]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessors]): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor @@ -475,10 +475,10 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor(self, processor: Union[AttnProcessors, Dict[str, AttnProcessors]]): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Parameters: - `processor (`dict` of `AttnProcessors` or `AttnProcessors`): + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor of **all** `Attention` layers. In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: @@ -595,7 +595,7 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessors` as defined under + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). From 9baad504da34e3a30d0e1b97a72c58612ab0d34f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 18:59:31 +0100 Subject: [PATCH 5/7] up --- src/diffusers/models/attention_processor.py | 6 +- src/diffusers/models/cross_attention.py | 83 +++++++++++++++++++ src/diffusers/models/unet_2d_blocks.py | 8 +- .../versatile_diffusion/modeling_text_unet.py | 4 +- 4 files changed, 92 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/models/cross_attention.py diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8d1a3906db23..30026cd89ff9 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -184,7 +184,7 @@ def set_attention_slice(self, slice_size): elif slice_size is not None: processor = SlicedAttnProcessor(slice_size) elif self.added_kv_proj_dim is not None: - processor = CrossAttnAddedKVProcessor() + processor = AttnAddedKVProcessor() else: processor = AttnProcessor() @@ -395,7 +395,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class CrossAttnAddedKVProcessor: +class AttnAddedKVProcessor: def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -688,7 +688,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnProcessor, XFormersAttnProcessor, SlicedAttnProcessor, - CrossAttnAddedKVProcessor, + AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py new file mode 100644 index 000000000000..ed269212bc19 --- /dev/null +++ b/src/diffusers/models/cross_attention.py @@ -0,0 +1,83 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from attention_processor import ( + Attention, + AttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + SlicedAttnProcessor, + XFormersAttnProcessor, +) +from attention_processor import ( + AttnProcessor as AttnProcessorRename, +) + +from ..utils import deprecate + + +deprecate( + "cross_attention", + "0.18.0", + "Importing from cross_attention is deprecated. Please import from attention_processor instead.", + standard_warn=False, +) + + +class CrossAttention(Attention): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class CrossAttnProcessor(AttnProcessorRename): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class LoRACrossAttnProcessor(LoRAAttnProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class XFormersCrossAttnProcessor(XFormersAttnProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class SlicedCrossAttnProcessor(SlicedAttnProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + + +class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor): + def __init__(self, args, **kwargs): + deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." + deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 5453530b00bb..f865b42eb9d5 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .attention_processor import Attention, CrossAttnAddedKVProcessor +from .attention_processor import Attention, AttnAddedKVProcessor from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -600,7 +600,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), + processor=AttnAddedKVProcessor(), ) ) resnets.append( @@ -1374,7 +1374,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), + processor=AttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) @@ -2367,7 +2367,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), + processor=AttnAddedKVProcessor(), ) ) self.attentions = nn.ModuleList(attentions) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 28c9504ef31e..7b021c597d10 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -7,7 +7,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.attention import Attention -from ...models.attention_processor import AttentionProcessor, CrossAttnAddedKVProcessor +from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel @@ -1434,7 +1434,7 @@ def __init__( norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), + processor=AttnAddedKVProcessor(), ) ) resnets.append( From 1af90af17c8d01655ed4fd4a171e81918f3cdac2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 23:44:16 +0100 Subject: [PATCH 6/7] more deprecation imports --- src/diffusers/models/cross_attention.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index ed269212bc19..098812f426d3 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from attention_processor import ( +from attention_processor import ( # noqa: F401 Attention, AttnAddedKVProcessor, LoRAAttnProcessor, @@ -19,9 +19,12 @@ SlicedAttnAddedKVProcessor, SlicedAttnProcessor, XFormersAttnProcessor, + LoRALinearLayer, + AttnProcessor2_0, ) -from attention_processor import ( +from attention_processor import ( # noqa: F401 AttnProcessor as AttnProcessorRename, + AttnProcessors as AttnProcessor, ) from ..utils import deprecate From 985d83cae9c8dd419d1e0dddf76598f1ca320192 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 15 Mar 2023 23:58:06 +0100 Subject: [PATCH 7/7] fixes --- src/diffusers/models/cross_attention.py | 40 +++++++++++++++---------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 098812f426d3..1bb4ad2f4a67 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -11,24 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from attention_processor import ( # noqa: F401 +from ..utils import deprecate +from .attention_processor import ( # noqa: F401 Attention, + AttentionProcessor, AttnAddedKVProcessor, + AttnProcessor2_0, LoRAAttnProcessor, + LoRALinearLayer, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, SlicedAttnProcessor, XFormersAttnProcessor, - LoRALinearLayer, - AttnProcessor2_0, ) -from attention_processor import ( # noqa: F401 +from .attention_processor import ( # noqa: F401 AttnProcessor as AttnProcessorRename, - AttnProcessors as AttnProcessor, ) -from ..utils import deprecate - deprecate( "cross_attention", @@ -38,49 +37,60 @@ ) +AttnProcessor = AttentionProcessor + + class CrossAttention(Attention): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class CrossAttnProcessor(AttnProcessorRename): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class LoRACrossAttnProcessor(LoRAAttnProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class XFormersCrossAttnProcessor(XFormersAttnProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class SlicedCrossAttnProcessor(SlicedAttnProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs) class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor): - def __init__(self, args, **kwargs): + def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False) + super().__init__(*args, **kwargs)