diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index d834432ec4db..a1d77f66ef56 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -278,7 +278,6 @@ def __init__(self, in_features, out_features, rank=4): self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=False) - self.scale = 1.0 nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index bfdb50b7b94a..c524dbf2bed3 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -410,6 +410,10 @@ def forward( encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index a198136e78a7..0344f09a0c9c 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -477,6 +477,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -533,6 +534,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: @@ -606,7 +611,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1be36cc21bf2..fce7d3bb6361 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -474,6 +474,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -530,6 +531,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: @@ -603,7 +608,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index dd9988be520d..3511cbfdd730 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -496,6 +496,10 @@ def forward( encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index febf58fd1094..e78ae1c12518 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -38,6 +38,7 @@ from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer +from ...models.test_models_unet_2d_condition import create_lora_layers from ...test_pipelines_common import PipelineTesterMixin @@ -134,6 +135,40 @@ def test_stable_diffusion_ddim(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_lora(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward 1 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + lora_attn_procs = create_lora_layers(sd_pipe.unet) + sd_pipe.unet.set_attn_processor(lora_attn_procs) + sd_pipe = sd_pipe.to(torch_device) + + # forward 2 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + def test_stable_diffusion_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components)