diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 5edcf427f69e..966fdc616457 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -85,7 +85,49 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
self.lora_scale = lora_scale
+ def _fuse_lora(self):
+ if self.lora_linear_layer is None:
+ return
+
+ dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
+ logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+ w_orig = self.regular_linear_layer.weight.data.float()
+ w_up = self.lora_linear_layer.up.weight.data.float()
+ w_down = self.lora_linear_layer.down.weight.data.float()
+
+ if self.lora_linear_layer.network_alpha is not None:
+ w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
+
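+ # fused weight: W' = W + W_up @ W_down (the alpha / rank scaling is already folded into W_up);
+ # the [None, :] unsqueeze turns the 2-D matmul into a batch-of-one torch.bmm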
+ fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+ self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_linear_layer = None
+
+ # offload the up and down matrices to the CPU to avoid blowing up memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+
+ def _unfuse_lora(self):
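+ # w_up / w_down are only set by a prior _fuse_lora(); without them there is nothing to undo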
+ if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+ return
+ logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+ fused_weight = self.regular_linear_layer.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ self.w_up = self.w_up.to(device=device, dtype=dtype)
+ self.w_down = self.w_down.to(device=device, dtype=dtype)
+ unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+ self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
def forward(self, input):
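+ # once fused (or when no LoRA layer is attached), the base projection already holds the final weights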
+ if self.lora_linear_layer is None:
+ return self.regular_linear_layer(input)
return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
@@ -525,6 +567,20 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+ def fuse_lora(self):
+ self.apply(self._fuse_lora_apply)
+
+ def _fuse_lora_apply(self, module):
+ if hasattr(module, "_fuse_lora"):
+ module._fuse_lora()
+
+ def unfuse_lora(self):
+ self.apply(self._unfuse_lora_apply)
+
+ def _unfuse_lora_apply(self, module):
+ if hasattr(module, "_unfuse_lora"):
+ module._unfuse_lora()
+
class TextualInversionLoaderMixin:
r"""
@@ -1712,6 +1768,83 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
+ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
+ fuse_text_encoder (`bool`, defaults to `True`):
+ Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters, this has no effect.
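+
+ Example (a minimal usage sketch, mirroring the slow tests added in this PR):
+
+ ```py
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ pipe.load_lora_weights(
+ "hf-internal-testing/sdxl-1.0-lora", weight_name="sd_xl_offset_example-lora_1.0.safetensors"
+ )
+ pipe.fuse_lora() # folds the LoRA deltas into the base weights ahead of inference
+ ```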
+ """
+ if fuse_unet:
+ self.unet.fuse_lora()
+
+ def fuse_text_encoder_lora(text_encoder):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj._fuse_lora()
+ attn_module.k_proj._fuse_lora()
+ attn_module.v_proj._fuse_lora()
+ attn_module.out_proj._fuse_lora()
+
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
+ mlp_module.fc1._fuse_lora()
+ mlp_module.fc2._fuse_lora()
+
+ if fuse_text_encoder:
+ if hasattr(self, "text_encoder"):
+ fuse_text_encoder_lora(self.text_encoder)
+ if hasattr(self, "text_encoder_2"):
+ fuse_text_encoder_lora(self.text_encoder_2)
+
+ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
+
+ <Tip warning={true}>
+
+ This is an experimental API.
+
+ </Tip>
+
+ Args:
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters, this has no effect.
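+
+ Example (sketch, continuing the `fuse_lora` example above):
+
+ ```py
+ pipe.fuse_lora()
+ # ... run inference with the fused weights ...
+ pipe.unfuse_lora() # restores the original weights from the cached LoRA factors
+ ```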
+ """
+ if unfuse_unet:
+ self.unet.unfuse_lora()
+
+ def unfuse_text_encoder_lora(text_encoder):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj._unfuse_lora()
+ attn_module.k_proj._unfuse_lora()
+ attn_module.v_proj._unfuse_lora()
+ attn_module.out_proj._unfuse_lora()
+
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
+ mlp_module.fc1._unfuse_lora()
+ mlp_module.fc2._unfuse_lora()
+
+ if unfuse_text_encoder:
+ if hasattr(self, "text_encoder"):
+ unfuse_text_encoder_lora(self.text_encoder)
+ if hasattr(self, "text_encoder_2"):
+ unfuse_text_encoder_lora(self.text_encoder_2)
+
class FromSingleFileMixin:
"""
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index a7cba5dc8765..cb0b3b45eb69 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -14,9 +14,15 @@
from typing import Optional
+import torch
import torch.nn.functional as F
from torch import nn
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
@@ -91,6 +97,51 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer
+ def _fuse_lora(self):
+ if self.lora_layer is None:
+ return
+
+ dtype, device = self.weight.data.dtype, self.weight.data.device
+ logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+ w_orig = self.weight.data.float()
+ w_up = self.lora_layer.up.weight.data.float()
+ w_down = self.lora_layer.down.weight.data.float()
+
+ if self.lora_layer.network_alpha is not None:
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
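+ # conv kernels are 4-D (out_channels, in_channels, kh, kw); flatten the LoRA factors to 2-D,
+ # multiply them, then reshape the delta back to the kernel shape before adding it to the weight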
+ fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
+ fusion = fusion.reshape(w_orig.shape)
+ fused_weight = w_orig + fusion
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_layer = None
+
+ # offload the up and down matrices to the CPU to avoid blowing up memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+
+ def _unfuse_lora(self):
+ if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+ return
+ logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+ fused_weight = self.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ self.w_up = self.w_up.to(device=device, dtype=dtype)
+ self.w_down = self.w_down.to(device=device, dtype=dtype)
+
+ fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
+ fusion = fusion.reshape(fused_weight.shape)
+ unfused_weight = fused_weight - fusion
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
def forward(self, x):
if self.lora_layer is None:
# make sure to use the functional Conv2D function as otherwise torch.compile's graph will break
@@ -109,9 +160,49 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
- def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer
+ def _fuse_lora(self):
+ if self.lora_layer is None:
+ return
+
+ dtype, device = self.weight.data.dtype, self.weight.data.device
+ logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+ w_orig = self.weight.data.float()
+ w_up = self.lora_layer.up.weight.data.float()
+ w_down = self.lora_layer.down.weight.data.float()
+
+ if self.lora_layer.network_alpha is not None:
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+ fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_layer = None
+
+ # offload the up and down matrices to the CPU to avoid blowing up memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+
+ def _unfuse_lora(self):
+ if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+ return
+ logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+ fused_weight = self.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ self.w_up = self.w_up.to(device=device, dtype=dtype)
+ self.w_down = self.w_down.to(device=device, dtype=dtype)
+ unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
def forward(self, hidden_states, lora_scale: int = 1):
if self.lora_layer is None:
return super().forward(hidden_states)
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index e18e05c949ad..f17529a1680e 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -14,6 +14,7 @@
# limitations under the License.
import os
import tempfile
+import time
import unittest
import numpy as np
@@ -692,6 +693,124 @@ def test_load_lora_locally_safetensors(self):
sd_pipe.unload_lora_weights()
+ def test_lora_fusion(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ orig_image_slice = original_images[0, -3:, -3:, -1]
+
+ # Emulate training.
+ set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_layers"],
+ text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora()
+ lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+ self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3))
+
+ def test_unfuse_lora(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ orig_image_slice = original_images[0, -3:, -3:, -1]
+
+ # Emulate training.
+ set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_layers"],
+ text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora()
+ lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+ # Reverse LoRA fusion.
+ sd_pipe.unfuse_lora()
+ original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ orig_image_slice_two = original_images[0, -3:, -3:, -1]
+
+ assert not np.allclose(
+ orig_image_slice, lora_image_slice
+ ), "Fusion of LoRAs should lead to a different image slice."
+ assert not np.allclose(
+ orig_image_slice_two, lora_image_slice
+ ), "Unfusing LoRA should lead to a different image slice than the fused result."
+ assert np.allclose(
+ orig_image_slice, orig_image_slice_two, atol=1e-3
+ ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
+ def test_lora_fusion_is_not_affected_by_unloading(self):
+ pipeline_components, lora_components = self.get_dummy_components()
+ sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+ _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+
+ # Emulate training.
+ set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+ set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=tmpdirname,
+ unet_lora_layers=lora_components["unet_lora_layers"],
+ text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+ text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+ safe_serialization=True,
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+ sd_pipe.fuse_lora()
+ lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+ # Unload LoRA parameters.
+ sd_pipe.unload_lora_weights()
+ images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+ images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1]
+
+ assert np.allclose(
+ lora_image_slice, images_with_unloaded_lora_slice
+ ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused."
+
@slow
@require_torch_gpu
@@ -877,10 +996,10 @@ def test_sdxl_0_9_lora_one(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
- pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
lora_filename = "daiton-xl-lora-test.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -895,10 +1014,10 @@ def test_sdxl_0_9_lora_two(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
- pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
lora_filename = "saijo.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -913,10 +1032,10 @@ def test_sdxl_0_9_lora_three(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
- pipe.enable_model_cpu_offload()
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
@@ -931,19 +1050,127 @@ def test_sdxl_1_0_lora(self):
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+ lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
+
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+
+ images = images[0, -3:, -3:, -1].flatten()
+ expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
+
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+ def test_sdxl_1_0_lora_fusion(self):
+ generator = torch.Generator().manual_seed(0)
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.fuse_lora()
+ pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images = images[0, -3:, -3:, -1].flatten()
+ # This way we also test equivalence between LoRA fusion and the non-fusion behaviour.
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
- self.assertTrue(np.allclose(images, expected, atol=1e-3))
+ self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+ def test_sdxl_1_0_lora_unfusion(self):
+ generator = torch.Generator().manual_seed(0)
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+ lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.fuse_lora()
+ pipe.enable_model_cpu_offload()
+
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ images_with_fusion = images[0, -3:, -3:, -1].flatten()
+
+ pipe.unfuse_lora()
+ generator = torch.Generator().manual_seed(0)
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ images_without_fusion = images[0, -3:, -3:, -1].flatten()
+
+ self.assertFalse(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
+
+ def test_sdxl_1_0_lora_unfusion_effectivity(self):
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ pipe.enable_model_cpu_offload()
+
+ generator = torch.Generator().manual_seed(0)
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ original_image_slice = images[0, -3:, -3:, -1].flatten()
+
+ lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+ lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.fuse_lora()
+
+ generator = torch.Generator().manual_seed(0)
+ _ = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+
+ pipe.unfuse_lora()
+ generator = torch.Generator().manual_seed(0)
+ images = pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ images_without_fusion_slice = images[0, -3:, -3:, -1].flatten()
+
+ self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3))
+
+ def test_sdxl_1_0_lora_fusion_efficiency(self):
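+ # Fused inference should not be slower than unfused inference, since fusion removes the
+ # extra per-layer LoRA matmuls from each forward pass.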
+ generator = torch.Generator().manual_seed(0)
+ lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
+ lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.enable_model_cpu_offload()
+
+ start_time = time.time()
+ for _ in range(3):
+ pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ end_time = time.time()
+ elapsed_time_non_fusion = end_time - start_time
+
+ del pipe
+
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+ pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+ pipe.fuse_lora()
+ pipe.enable_model_cpu_offload()
+
+ start_time = time.time()
+ generator = torch.Generator().manual_seed(0)
+ for _ in range(3):
+ pipe(
+ "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+ ).images
+ end_time = time.time()
+ elapsed_time_fusion = end_time - start_time
+
+ self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
def test_sdxl_1_0_last_ben(self):
generator = torch.Generator().manual_seed(0)