From 55e246db2c9f17682be79e854aabe754a78524da Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Thu, 21 Sep 2023 13:45:07 +0800
Subject: [PATCH] support fuse/unfuse multiple lora

Signed-off-by: AnyISalIn
---
 src/diffusers/loaders.py     | 106 +++++++++++++++++++++--------------
 src/diffusers/models/lora.py |  90 ++++++++++++++++++++---------
 2 files changed, 130 insertions(+), 66 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index bea6e21aa64a..30bf5579cd6a 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import re
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from contextlib import nullcontext
 from io import BytesIO
 from pathlib import Path
@@ -62,7 +62,7 @@ class PatchedLoraProjection(nn.Module):
-    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None, fused_params=None):
         super().__init__()

         from .models.lora import LoRALinearLayer
@@ -83,6 +83,11 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
         )
         self.lora_scale = lora_scale

+        if fused_params is not None:
+            self._fused_params = fused_params
+        else:
+            self._fused_params = OrderedDict()
+
     # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
     # when saving the whole text encoder model and when LoRA is unloaded or fused
@@ -94,10 +99,19 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, lora_name: Optional[str] = None):
         if self.lora_linear_layer is None:
             return

+        # make sure each fused LoRA is stored under a unique name
+        if lora_name is None:
+            lora_name = "unspecified"
+            unspecified_num = 0
+            while lora_name in self._fused_params:
+                unspecified_num += 1
+                lora_name = f"unspecified{unspecified_num}"
+
+        if lora_name in self._fused_params:
+            raise ValueError(f"LoRA with name {lora_name} already fused")
+
         dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

         w_orig = self.regular_linear_layer.weight.data.float()
@@ -114,26 +128,33 @@ def _fuse_lora(self, lora_scale=1.0):
         self.lora_linear_layer = None

         # offload the up and down matrices to CPU to not blow the memory
-        self.w_up = w_up.cpu()
-        self.w_down = w_down.cpu()
-        self.lora_scale = lora_scale
+        self._fused_params[lora_name] = {
+            "w_up": w_up.cpu(),
+            "w_down": w_down.cpu(),
+            "lora_scale": lora_scale,
+        }

-    def _unfuse_lora(self):
-        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+    def _unfuse_lora(self, lora_name: Optional[str] = None):
+        if len(self._fused_params) == 0:
             return

+        if lora_name is None:
+            unfuse_lora = self._fused_params.popitem(last=True)[1]
+        else:
+            if lora_name not in self._fused_params:
+                raise ValueError(f"LoRA with name {lora_name} not found")
+            unfuse_lora = self._fused_params.pop(lora_name)
+
         fused_weight = self.regular_linear_layer.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device

-        w_up = self.w_up.to(device=device).float()
-        w_down = self.w_down.to(device).float()
+        w_up = unfuse_lora["w_up"].to(device=device).float()
+        w_down = unfuse_lora["w_down"].to(device).float()
+        lora_scale = unfuse_lora["lora_scale"]

-        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+        unfused_weight = fused_weight.float() - (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

-        self.w_up = None
-        self.w_down = None
-
     def forward(self, input):
         if self.lora_scale is None:
             self.lora_scale = 1.0
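The change above replaces the single `w_up`/`w_down`/`lora_scale` attributes with a per-name `OrderedDict`, so several LoRAs can be baked into the same weight and later removed either by name or in last-in-first-out order. A self-contained sketch of that bookkeeping pattern (illustrative only, not the diffusers class; the auto-naming is simplified here):

# Toy stand-in for the bookkeeping introduced above: each fuse adds
# lora_scale * (w_up @ w_down) to the base weight and records the factors
# under a unique key; unfuse subtracts them back, by name or LIFO.
from collections import OrderedDict

import torch
import torch.nn as nn

class ToyFusedLinear(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self._fused_params = OrderedDict()

    def fuse(self, w_up, w_down, lora_scale=1.0, lora_name=None):
        if lora_name is None:
            lora_name = f"unspecified{len(self._fused_params)}"  # simplified auto-naming
        if lora_name in self._fused_params:
            raise ValueError(f"LoRA with name {lora_name} already fused")
        self.linear.weight.data += lora_scale * torch.mm(w_up, w_down)
        self._fused_params[lora_name] = {"w_up": w_up.cpu(), "w_down": w_down.cpu(), "lora_scale": lora_scale}

    def unfuse(self, lora_name=None):
        if len(self._fused_params) == 0:
            return
        entry = self._fused_params.popitem(last=True)[1] if lora_name is None else self._fused_params.pop(lora_name)
        device = self.linear.weight.device
        self.linear.weight.data -= entry["lora_scale"] * torch.mm(entry["w_up"].to(device), entry["w_down"].to(device))

layer = ToyFusedLinear(nn.Linear(8, 8))
base = layer.linear.weight.data.clone()
layer.fuse(torch.randn(8, 4), torch.randn(4, 8), lora_scale=0.5, lora_name="style")
layer.fuse(torch.randn(8, 4), torch.randn(4, 8))  # stored under an auto-generated name
layer.unfuse("style")                             # remove a specific LoRA by name
layer.unfuse()                                    # no name given: pop the most recently fused one
assert torch.allclose(layer.linear.weight.data, base, atol=1e-5)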
@@ -645,20 +666,22 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

-    def fuse_lora(self, lora_scale=1.0):
+    def fuse_lora(self, lora_scale=1.0, lora_name: Optional[str] = None):
         self.lora_scale = lora_scale
+        self.lora_name = lora_name
         self.apply(self._fuse_lora_apply)

     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora(self.lora_scale)
+            module._fuse_lora(self.lora_scale, self.lora_name)

-    def unfuse_lora(self):
+    def unfuse_lora(self, lora_name: Optional[str] = None):
+        self.lora_name = lora_name
         self.apply(self._unfuse_lora_apply)

     def _unfuse_lora_apply(self, module):
         if hasattr(module, "_unfuse_lora"):
-            module._unfuse_lora()
+            module._unfuse_lora(self.lora_name)


 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
@@ -1678,9 +1701,10 @@ def _modify_text_encoder(
         def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
             linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+            fused_params = model._fused_params if isinstance(model, PatchedLoraProjection) else None
             ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
             with ctx():
-                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype, fused_params=fused_params)

             lora_parameters.extend(model.lora_linear_layer.parameters())
             return model
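The model-level `fuse_lora`/`unfuse_lora` above simply stash the name and forward it through `nn.Module.apply`, so every submodule that defines `_fuse_lora`/`_unfuse_lora` reacts and everything else is skipped. A minimal illustration of that dispatch pattern (the `Recorder` and `TinyModel` classes are invented for the example, not part of diffusers):

import torch.nn as nn

class Recorder(nn.Linear):
    calls = []

    def _fuse_lora(self, lora_scale=1.0, lora_name=None):
        Recorder.calls.append(("fuse", lora_scale, lora_name))

    def _unfuse_lora(self, lora_name=None):
        Recorder.calls.append(("unfuse", lora_name))

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = Recorder(4, 4)    # exposes _fuse_lora, so it reacts
        self.plain = nn.Linear(4, 4)  # no _fuse_lora, silently skipped

    def fuse_lora(self, lora_scale=1.0, lora_name=None):
        self.lora_scale, self.lora_name = lora_scale, lora_name
        self.apply(self._fuse_lora_apply)

    def _fuse_lora_apply(self, module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora(self.lora_scale, self.lora_name)

model = TinyModel()
model.fuse_lora(lora_scale=0.8, lora_name="watercolor")
assert Recorder.calls == [("fuse", 0.8, "watercolor")]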
@@ -2021,7 +2045,7 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()

-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0, lora_name: Optional[str] = None):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -2047,28 +2071,28 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
             )

         if fuse_unet:
-            self.unet.fuse_lora(lora_scale)
+            self.unet.fuse_lora(lora_scale, lora_name)

-        def fuse_text_encoder_lora(text_encoder):
+        def fuse_text_encoder_lora(text_encoder, lora_name):
             for _, attn_module in text_encoder_attn_modules(text_encoder):
                 if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._fuse_lora(lora_scale)
-                    attn_module.k_proj._fuse_lora(lora_scale)
-                    attn_module.v_proj._fuse_lora(lora_scale)
-                    attn_module.out_proj._fuse_lora(lora_scale)
+                    attn_module.q_proj._fuse_lora(lora_scale, lora_name)
+                    attn_module.k_proj._fuse_lora(lora_scale, lora_name)
+                    attn_module.v_proj._fuse_lora(lora_scale, lora_name)
+                    attn_module.out_proj._fuse_lora(lora_scale, lora_name)

             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                 if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._fuse_lora(lora_scale)
-                    mlp_module.fc2._fuse_lora(lora_scale)
+                    mlp_module.fc1._fuse_lora(lora_scale, lora_name)
+                    mlp_module.fc2._fuse_lora(lora_scale, lora_name)

         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder)
+                fuse_text_encoder_lora(self.text_encoder, lora_name)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_name)

-    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
+    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True, lora_name: Optional[str] = None):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
@@ -2086,26 +2110,26 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
                 LoRA parameters then it won't have any effect.
         """
         if unfuse_unet:
-            self.unet.unfuse_lora()
+            self.unet.unfuse_lora(lora_name)

-        def unfuse_text_encoder_lora(text_encoder):
+        def unfuse_text_encoder_lora(text_encoder, lora_name):
             for _, attn_module in text_encoder_attn_modules(text_encoder):
                 if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._unfuse_lora()
-                    attn_module.k_proj._unfuse_lora()
-                    attn_module.v_proj._unfuse_lora()
-                    attn_module.out_proj._unfuse_lora()
+                    attn_module.q_proj._unfuse_lora(lora_name)
+                    attn_module.k_proj._unfuse_lora(lora_name)
+                    attn_module.v_proj._unfuse_lora(lora_name)
+                    attn_module.out_proj._unfuse_lora(lora_name)

             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                 if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._unfuse_lora()
-                    mlp_module.fc2._unfuse_lora()
+                    mlp_module.fc1._unfuse_lora(lora_name)
+                    mlp_module.fc2._unfuse_lora(lora_name)

         if unfuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                unfuse_text_encoder_lora(self.text_encoder)
+                unfuse_text_encoder_lora(self.text_encoder, lora_name)
             if hasattr(self, "text_encoder_2"):
-                unfuse_text_encoder_lora(self.text_encoder_2)
+                unfuse_text_encoder_lora(self.text_encoder_2, lora_name)

         self.num_fused_loras -= 1
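With the loaders changes in place, the pipeline-level API could be used roughly as follows. This is a sketch assuming the patch is applied; the model and LoRA identifiers are placeholders, not real checkpoints.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("path/to/base-model", torch_dtype=torch.float16).to("cuda")

# Fuse a first LoRA into the UNet / text encoder weights under an explicit name.
pipe.load_lora_weights("path/to/style-lora")
pipe.fuse_lora(lora_scale=0.7, lora_name="style")

# Load and fuse a second LoRA on top of the first.
pipe.load_lora_weights("path/to/character-lora")
pipe.fuse_lora(lora_scale=1.0, lora_name="character")

image = pipe("a portrait, highly detailed").images[0]

# Remove only one of them; calling unfuse_lora() without a name pops the most recently fused LoRA.
pipe.unfuse_lora(lora_name="style")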
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index cc8e3e231e2b..b7787ed5fc80 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -13,6 +13,7 @@
 # limitations under the License.

+from collections import OrderedDict
 from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -108,14 +109,25 @@ class LoRACompatibleConv(nn.Conv2d):
     def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self._fused_params = OrderedDict()

     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer

-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, lora_name: Optional[str] = None):
         if self.lora_layer is None:
             return

+        # make sure each fused LoRA is stored under a unique name
+        if lora_name is None:
+            lora_name = "unspecified"
+            unspecified_num = 0
+            while lora_name in self._fused_params:
+                unspecified_num += 1
+                lora_name = f"unspecified{unspecified_num}"
+
+        if lora_name in self._fused_params:
+            raise ValueError(f"LoRA with name {lora_name} already fused")
+
         dtype, device = self.weight.data.dtype, self.weight.data.device

         w_orig = self.weight.data.float()
@@ -134,27 +146,35 @@ def _fuse_lora(self, lora_scale=1.0):
         self.lora_layer = None

         # offload the up and down matrices to CPU to not blow the memory
-        self.w_up = w_up.cpu()
-        self.w_down = w_down.cpu()
-        self._lora_scale = lora_scale
-
-    def _unfuse_lora(self):
-        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+        self._fused_params[lora_name] = {
+            "w_up": w_up.cpu(),
+            "w_down": w_down.cpu(),
+            "lora_scale": lora_scale,
+        }
+
+    def _unfuse_lora(self, lora_name: Optional[str] = None):
+        if len(self._fused_params) == 0:
             return

+        if lora_name is None:
+            unfuse_lora = self._fused_params.popitem(last=True)[1]
+        else:
+            if lora_name not in self._fused_params:
+                raise ValueError(f"LoRA with name {lora_name} not found")
+            unfuse_lora = self._fused_params.pop(lora_name)
+
         fused_weight = self.weight.data
         dtype, device = fused_weight.data.dtype, fused_weight.data.device

-        self.w_up = self.w_up.to(device=device).float()
-        self.w_down = self.w_down.to(device).float()
+        w_up = unfuse_lora["w_up"].to(device=device).float()
+        w_down = unfuse_lora["w_down"].to(device).float()
+        lora_scale = unfuse_lora["lora_scale"]

-        fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
+        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
         fusion = fusion.reshape((fused_weight.shape))
-        unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
+        unfused_weight = fused_weight.float() - (lora_scale * fusion)
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)

-        self.w_up = None
-        self.w_down = None

     def forward(self, hidden_states, scale: float = 1.0):
         if self.lora_layer is None:
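A quick numerical check of the conv fuse/unfuse arithmetic used above: the low-rank update is materialized as mm of the flattened up/down kernels reshaped to the conv weight shape, so subtracting the same product restores the original kernel. Shapes are chosen to mirror the 1x1 up and kxk down convolutions of a LoRA conv layer; illustrative only.

import torch

out_channels, in_channels, k, rank = 8, 4, 3, 2
weight = torch.randn(out_channels, in_channels, k, k)  # conv kernel being patched

w_up = torch.randn(out_channels, rank, 1, 1)            # 1x1 "up" conv weight
w_down = torch.randn(rank, in_channels, k, k)           # kxk "down" conv weight
lora_scale = 0.5

fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)).reshape(weight.shape)
fused = weight + lora_scale * fusion      # what _fuse_lora bakes into self.weight
unfused = fused - lora_scale * fusion     # what _unfuse_lora recovers from the stored factors
assert torch.allclose(unfused, weight, atol=1e-5)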
@@ -175,14 +195,25 @@ class LoRACompatibleLinear(nn.Linear):
     def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
+        self._fused_params = OrderedDict()

     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer

-    def _fuse_lora(self, lora_scale=1.0):
+    def _fuse_lora(self, lora_scale=1.0, lora_name: Optional[str] = None):
         if self.lora_layer is None:
             return

+        # make sure each fused LoRA is stored under a unique name
+        if lora_name is None:
+            lora_name = "unspecified"
+            unspecified_num = 0
+            while lora_name in self._fused_params:
+                unspecified_num += 1
+                lora_name = f"unspecified{unspecified_num}"
+
+        if lora_name in self._fused_params:
+            raise ValueError(f"LoRA with name {lora_name} already fused")
+
         dtype, device = self.weight.data.dtype, self.weight.data.device

         w_orig = self.weight.data.float()
@@ -199,25 +230,34 @@ def _fuse_lora(self, lora_scale=1.0):
         self.lora_layer = None

         # offload the up and down matrices to CPU to not blow the memory
-        self.w_up = w_up.cpu()
-        self.w_down = w_down.cpu()
-        self._lora_scale = lora_scale
-
-    def _unfuse_lora(self):
-        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+        self._fused_params[lora_name] = {
+            "w_up": w_up.cpu(),
+            "w_down": w_down.cpu(),
+            "lora_scale": lora_scale,
+        }
+
+    def _unfuse_lora(self, lora_name: Optional[str] = None):
+        if len(self._fused_params) == 0:
             return

+        if lora_name is None:
+            unfuse_lora = self._fused_params.popitem(last=True)[1]
+        else:
+            if lora_name not in self._fused_params:
+                raise ValueError(f"LoRA with name {lora_name} not found")
+            unfuse_lora = self._fused_params.pop(lora_name)
+
         fused_weight = self.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device

-        w_up = self.w_up.to(device=device).float()
-        w_down = self.w_down.to(device).float()
+        w_up = unfuse_lora["w_up"].to(device=device).float()
+        w_down = unfuse_lora["w_down"].to(device).float()
+        lora_scale = unfuse_lora["lora_scale"]

-        unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+        unfused_weight = fused_weight.float() - (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)

-        self.w_up = None
-        self.w_down = None

     def forward(self, hidden_states, scale: float = 1.0):
         if self.lora_layer is None:
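Because every fused LoRA is stored as an independent additive delta on the base weight, unfusing by name works in any order, not only LIFO: the base weights are recovered up to floating-point rounding regardless of which entry is subtracted first. A small standalone demonstration of that property (not diffusers code):

import torch

base = torch.randn(16, 16)
delta_a = 0.7 * (torch.randn(16, 4) @ torch.randn(4, 16))  # e.g. fused as "style", scale 0.7
delta_b = 1.0 * (torch.randn(16, 4) @ torch.randn(4, 16))  # e.g. fused as "character", scale 1.0

fused = base + delta_a + delta_b      # fuse "style", then "character"
restored = fused - delta_a - delta_b  # unfuse "style" first, i.e. not in LIFO order

assert torch.allclose(restored, base, atol=1e-4)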