Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/diffusers/models/cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def __init__(self, in_features, out_features, rank=4):

self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
self.scale = 1.0

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ def forward(
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
Expand Down
14 changes: 12 additions & 2 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch

Expand Down Expand Up @@ -477,6 +477,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -533,6 +534,10 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

Examples:

Expand Down Expand Up @@ -606,7 +611,12 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample

# perform guidance
if do_classifier_free_guidance:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch

Expand Down Expand Up @@ -474,6 +474,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -530,6 +531,10 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

Examples:

Expand Down Expand Up @@ -603,7 +608,12 @@ def __call__(
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample

# perform guidance
if do_classifier_free_guidance:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def forward(
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
Expand Down
35 changes: 35 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from ...models.test_models_unet_2d_condition import create_lora_layers
from ...test_pipelines_common import PipelineTesterMixin


Expand Down Expand Up @@ -134,6 +135,40 @@ def test_stable_diffusion_ddim(self):

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_lora(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator

components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

# forward 1
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]

# set lora layers
lora_attn_procs = create_lora_layers(sd_pipe.unet)
sd_pipe.unet.set_attn_processor(lora_attn_procs)
sd_pipe = sd_pipe.to(torch_device)

# forward 2
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
image = output.images
image_slice_1 = image[0, -3:, -3:, -1]

# forward 3
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
image = output.images
image_slice_2 = image[0, -3:, -3:, -1]

assert np.abs(image_slice - image_slice_1).max() < 1e-2
assert np.abs(image_slice - image_slice_2).max() > 1e-2

def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
Expand Down