diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5edcf427f69e..966fdc616457 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -85,7 +85,49 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank= self.lora_scale = lora_scale + def _fuse_lora(self): + if self.lora_linear_layer is None: + return + + dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device + logger.info(f"Fusing LoRA weights for {self.__class__}") + + w_orig = self.regular_linear_layer.weight.data.float() + w_up = self.lora_linear_layer.up.weight.data.float() + w_down = self.lora_linear_layer.down.weight.data.float() + + if self.lora_linear_layer.network_alpha is not None: + w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank + + fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0] + self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_linear_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + logger.info(f"Unfusing LoRA weights for {self.__class__}") + + fused_weight = self.regular_linear_layer.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + self.w_up = self.w_up.to(device=device, dtype=dtype) + self.w_down = self.w_down.to(device, dtype=dtype) + unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0] + self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + def forward(self, input): + if self.lora_linear_layer is None: + return self.regular_linear_layer(input) return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) @@ -525,6 +567,20 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + def fuse_lora(self): + self.apply(self._fuse_lora_apply) + + def _fuse_lora_apply(self, module): + if hasattr(module, "_fuse_lora"): + module._fuse_lora() + + def unfuse_lora(self): + self.apply(self._unfuse_lora_apply) + + def _unfuse_lora_apply(self, module): + if hasattr(module, "_unfuse_lora"): + module._unfuse_lora() + class TextualInversionLoaderMixin: r""" @@ -1712,6 +1768,83 @@ def unload_lora_weights(self): # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() + def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters. + fuse_text_encoder (`bool`, defaults to `True`): + Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + if fuse_unet: + self.unet.fuse_lora() + + def fuse_text_encoder_lora(text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._fuse_lora() + attn_module.k_proj._fuse_lora() + attn_module.v_proj._fuse_lora() + attn_module.out_proj._fuse_lora() + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._fuse_lora() + mlp_module.fc2._fuse_lora() + + if fuse_text_encoder: + if hasattr(self, "text_encoder"): + fuse_text_encoder_lora(self.text_encoder) + if hasattr(self, "text_encoder_2"): + fuse_text_encoder_lora(self.text_encoder_2) + + def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + if unfuse_unet: + self.unet.unfuse_lora() + + def unfuse_text_encoder_lora(text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._unfuse_lora() + attn_module.k_proj._unfuse_lora() + attn_module.v_proj._unfuse_lora() + attn_module.out_proj._unfuse_lora() + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._unfuse_lora() + mlp_module.fc2._unfuse_lora() + + if unfuse_text_encoder: + if hasattr(self, "text_encoder"): + unfuse_text_encoder_lora(self.text_encoder) + if hasattr(self, "text_encoder_2"): + unfuse_text_encoder_lora(self.text_encoder_2) + class FromSingleFileMixin: """ diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index a7cba5dc8765..cb0b3b45eb69 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -14,9 +14,15 @@ from typing import Optional +import torch import torch.nn.functional as F from torch import nn +from ..utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class LoRALinearLayer(nn.Module): def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): @@ -91,6 +97,51 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): self.lora_layer = lora_layer + def _fuse_lora(self): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + logger.info(f"Fusing LoRA weights for {self.__class__}") + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) + fusion = fusion.reshape((w_orig.shape)) + fused_weight = w_orig + fusion + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + logger.info(f"Unfusing LoRA weights for {self.__class__}") + + fused_weight = self.weight.data + dtype, device = fused_weight.data.dtype, fused_weight.data.device + + self.w_up = self.w_up.to(device=device, dtype=dtype) + self.w_down = self.w_down.to(device, dtype=dtype) + + fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) + fusion = fusion.reshape((fused_weight.shape)) + unfused_weight = fused_weight - fusion + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + def forward(self, x): if self.lora_layer is None: # make sure to the functional Conv2D function as otherwise torch.compile's graph will break @@ -109,9 +160,49 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs super().__init__(*args, **kwargs) self.lora_layer = lora_layer - def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): self.lora_layer = lora_layer + def _fuse_lora(self): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + logger.info(f"Fusing LoRA weights for {self.__class__}") + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0] + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + logger.info(f"Unfusing LoRA weights for {self.__class__}") + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + self.w_up = self.w_up.to(device=device, dtype=dtype) + self.w_down = self.w_down.to(device, dtype=dtype) + unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0] + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + def forward(self, hidden_states, lora_scale: int = 1): if self.lora_layer is None: return super().forward(hidden_states) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index e18e05c949ad..f17529a1680e 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -14,6 +14,7 @@ # limitations under the License. import os import tempfile +import time import unittest import numpy as np @@ -692,6 +693,124 @@ def test_load_lora_locally_safetensors(self): sd_pipe.unload_lora_weights() + def test_lora_fusion(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3)) + + def test_unfuse_lora(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Reverse LoRA fusion. + sd_pipe.unfuse_lora() + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + orig_image_slice_two = original_images[0, -3:, -3:, -1] + + assert not np.allclose( + orig_image_slice, lora_image_slice + ), "Fusion of LoRAs should lead to a different image slice." + assert not np.allclose( + orig_image_slice_two, lora_image_slice + ), "Fusion of LoRAs should lead to a different image slice." + assert np.allclose( + orig_image_slice, orig_image_slice_two, atol=1e-3 + ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters." + + def test_lora_fusion_is_not_affected_by_unloading(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Unload LoRA parameters. + sd_pipe.unload_lora_weights() + images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1] + + assert np.allclose( + lora_image_slice, images_with_unloaded_lora_slice + ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." + @slow @require_torch_gpu @@ -877,10 +996,10 @@ def test_sdxl_0_9_lora_one(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - pipe.enable_model_cpu_offload() lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora" lora_filename = "daiton-xl-lora-test.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 @@ -895,10 +1014,10 @@ def test_sdxl_0_9_lora_two(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - pipe.enable_model_cpu_offload() lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora" lora_filename = "saijo.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 @@ -913,10 +1032,10 @@ def test_sdxl_0_9_lora_three(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9") - pipe.enable_model_cpu_offload() lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora" lora_filename = "kame_sdxl_v2-000020-16rank.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 @@ -931,19 +1050,127 @@ def test_sdxl_1_0_lora(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_sdxl_1_0_lora_fusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") lora_model_id = "hf-internal-testing/sdxl-1.0-lora" lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 ).images images = images[0, -3:, -3:, -1].flatten() + # This way we also test equivalence between LoRA fusion and the non-fusion behaviour. expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) - self.assertTrue(np.allclose(images, expected, atol=1e-3)) + self.assertTrue(np.allclose(images, expected, atol=1e-4)) + + def test_sdxl_1_0_lora_unfusion(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_with_fusion = images[0, -3:, -3:, -1].flatten() + + pipe.unfuse_lora() + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion = images[0, -3:, -3:, -1].flatten() + + self.assertFalse(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3)) + + def test_sdxl_1_0_lora_unfusion_effectivity(self): + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() + + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + original_image_slice = images[0, -3:, -3:, -1].flatten() + + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + + generator = torch.Generator().manual_seed(0) + _ = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + pipe.unfuse_lora() + generator = torch.Generator().manual_seed(0) + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + images_without_fusion_slice = images[0, -3:, -3:, -1].flatten() + + self.assertTrue(np.allclose(original_image_slice, images_without_fusion_slice, atol=1e-3)) + + def test_sdxl_1_0_lora_fusion_efficiency(self): + generator = torch.Generator().manual_seed(0) + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.enable_model_cpu_offload() + + start_time = time.time() + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_non_fusion = end_time - start_time + + del pipe + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.fuse_lora() + pipe.enable_model_cpu_offload() + + start_time = time.time() + generator = torch.Generator().manual_seed(0) + for _ in range(3): + pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + end_time = time.time() + elapsed_time_fusion = end_time - start_time + + self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion) def test_sdxl_1_0_last_ben(self): generator = torch.Generator().manual_seed(0)