106 changes: 65 additions & 41 deletions src/diffusers/loaders.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
import re
from collections import defaultdict
from collections import defaultdict, OrderedDict
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
@@ -62,7 +62,7 @@


class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None, fused_params=None):
super().__init__()
from .models.lora import LoRALinearLayer

@@ -83,6 +83,11 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
)

self.lora_scale = lora_scale
if fused_params is not None:
self._fused_params = fused_params
else:
self._fused_params = OrderedDict()


# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
# when saving the whole text encoder model and when LoRA is unloaded or fused
@@ -94,10 +99,19 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):

return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, lora_name: str = None):
if self.lora_linear_layer is None:
return

if lora_name is None:
lora_name = "unspecified"
unspecified_num = 0
while lora_name in self._fused_params:
lora_name = f"unspecified{unspecified_num + 1}"  # derive from the base name so retries don't concatenate
unspecified_num += 1
if lora_name in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} already fused")

dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

w_orig = self.regular_linear_layer.weight.data.float()
@@ -114,26 +128,33 @@ def _fuse_lora(self, lora_scale=1.0):
self.lora_linear_layer = None

# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self.lora_scale = lora_scale
self._fused_params[lora_name] = {
"w_up": w_up.cpu(),
"w_down": w_down.cpu(),
"lora_scale": lora_scale
}

def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
def _unfuse_lora(self, lora_name: str = None):
if len(self._fused_params) == 0:
return

if lora_name is None:
unfuse_lora = self._fused_params.popitem(last=True)[1]
else:
if lora_name not in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} not found")
unfuse_lora = self._fused_params.pop(lora_name)

fused_weight = self.regular_linear_layer.weight.data
dtype, device = fused_weight.dtype, fused_weight.device

w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
w_up = unfuse_lora["w_up"].to(device=device).float()
w_down = unfuse_lora["w_down"].to(device).float()
lora_scale = unfuse_lora["lora_scale"]

unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
unfused_weight = fused_weight.float() - (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
self.w_down = None
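For context, here is a minimal standalone sketch of the bookkeeping this change introduces (a plain `nn.Linear` stands in for the patched projection; the adapter names, shapes, and rank are invented for illustration): each layer keeps an `OrderedDict` of fused LoRAs keyed by name, storing the offloaded up/down factors and the scale, so a specific adapter (or, by default, the most recently fused one) can later be subtracted back out of the base weight.

```python
# Illustrative sketch only (not part of the diff): the per-layer bookkeeping behind
# named fusion. A plain nn.Linear stands in for the patched layer; names, shapes,
# and the rank are made up for the example.
from collections import OrderedDict

import torch
import torch.nn as nn

linear = nn.Linear(8, 8, bias=False)
fused_params = OrderedDict()  # lora_name -> {"w_up", "w_down", "lora_scale"}

def fuse(name, w_up, w_down, lora_scale=1.0):
    if name in fused_params:
        raise ValueError(f"LoRA with name {name} already fused")
    # W <- W + scale * (up @ down), the same update the patched layers apply
    linear.weight.data += lora_scale * torch.mm(w_up, w_down)
    fused_params[name] = {"w_up": w_up.cpu(), "w_down": w_down.cpu(), "lora_scale": lora_scale}

def unfuse(name=None):
    # pop a specific entry by name, or the most recently fused one (LIFO)
    entry = fused_params.popitem(last=True)[1] if name is None else fused_params.pop(name)
    linear.weight.data -= entry["lora_scale"] * torch.mm(entry["w_up"], entry["w_down"])

rank = 4
fuse("style", torch.randn(8, rank), torch.randn(rank, 8), lora_scale=0.7)
fuse("character", torch.randn(8, rank), torch.randn(rank, 8))
unfuse("style")  # remove one specific adapter
unfuse()         # remove whatever was fused last
```

Because each entry records its own `lora_scale` and factors, unfusing no longer depends on a single `self.lora_scale`/`self.w_up`/`self.w_down` triple, which is what previously limited a layer to one fused LoRA at a time.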

def forward(self, input):
if self.lora_scale is None:
self.lora_scale = 1.0
@@ -645,20 +666,22 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def fuse_lora(self, lora_scale=1.0):
def fuse_lora(self, lora_scale=1.0, lora_name: str = None):
self.lora_scale = lora_scale
self.lora_name = lora_name
self.apply(self._fuse_lora_apply)

def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale)
module._fuse_lora(self.lora_scale, self.lora_name)

def unfuse_lora(self):
def unfuse_lora(self, lora_name: str = None):
self.lora_name = lora_name
self.apply(self._unfuse_lora_apply)

def _unfuse_lora_apply(self, module):
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
module._unfuse_lora(self.lora_name)
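The mixin-level wrappers above stash `lora_scale` and `lora_name` on `self` and let `nn.Module.apply` walk every submodule, calling `_fuse_lora`/`_unfuse_lora` wherever those hooks exist. A minimal illustration of that dispatch pattern (the classes here are invented stand-ins, not diffusers types):

```python
# Invented stand-ins to illustrate the apply-based dispatch used by the mixin.
import torch.nn as nn

class FusableStub(nn.Module):
    def _fuse_lora(self, lora_scale, lora_name):
        print(f"fusing {lora_name!r} at scale {lora_scale} into {type(self).__name__}")

class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(FusableStub(), nn.ReLU(), FusableStub())

    def fuse_lora(self, lora_scale=1.0, lora_name=None):
        self.lora_scale, self.lora_name = lora_scale, lora_name
        self.apply(self._fuse_lora_apply)  # visits self and every nested submodule

    def _fuse_lora_apply(self, module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora(self.lora_scale, self.lora_name)

TinyUNet().fuse_lora(0.8, "style")  # prints once per FusableStub; nn.ReLU is skipped
```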


def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
@@ -1678,9 +1701,10 @@ def _modify_text_encoder(

def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
fused_params = model._fused_params if isinstance(model, PatchedLoraProjection) else None
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype, fused_params=fused_params)

lora_parameters.extend(model.lora_linear_layer.parameters())
return model
@@ -2021,7 +2045,7 @@ def unload_lora_weights(self):
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0, lora_name: str = None):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

@@ -2047,28 +2071,28 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora
)

if fuse_unet:
self.unet.fuse_lora(lora_scale)
self.unet.fuse_lora(lora_scale, lora_name)

def fuse_text_encoder_lora(text_encoder):
def fuse_text_encoder_lora(text_encoder, lora_name):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
attn_module.q_proj._fuse_lora(lora_scale, lora_name)
attn_module.k_proj._fuse_lora(lora_scale, lora_name)
attn_module.v_proj._fuse_lora(lora_scale, lora_name)
attn_module.out_proj._fuse_lora(lora_scale, lora_name)

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
mlp_module.fc1._fuse_lora(lora_scale, lora_name)
mlp_module.fc2._fuse_lora(lora_scale, lora_name)

if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder)
fuse_text_encoder_lora(self.text_encoder, lora_name)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2)
fuse_text_encoder_lora(self.text_encoder_2, lora_name)

def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True, lora_name: str = None):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
@@ -2086,26 +2110,26 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
LoRA parameters then it won't have any effect.
"""
if unfuse_unet:
self.unet.unfuse_lora()
self.unet.unfuse_lora(lora_name)

def unfuse_text_encoder_lora(text_encoder):
def unfuse_text_encoder_lora(text_encoder, lora_name):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
attn_module.q_proj._unfuse_lora(lora_name)
attn_module.k_proj._unfuse_lora(lora_name)
attn_module.v_proj._unfuse_lora(lora_name)
attn_module.out_proj._unfuse_lora(lora_name)

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
mlp_module.fc1._unfuse_lora(lora_name)
mlp_module.fc2._unfuse_lora(lora_name)

if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
unfuse_text_encoder_lora(self.text_encoder)
unfuse_text_encoder_lora(self.text_encoder, lora_name)
if hasattr(self, "text_encoder_2"):
unfuse_text_encoder_lora(self.text_encoder_2)
unfuse_text_encoder_lora(self.text_encoder_2, lora_name)

self.num_fused_loras -= 1

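At the pipeline level, the new `lora_name` argument threads all the way down to the patched layers shown above. A hedged usage sketch of the proposed API (the base checkpoint and the LoRA repositories are placeholders, not taken from this PR):

```python
# Usage sketch under the API proposed in this PR; model and LoRA ids are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("some-user/style-lora")       # hypothetical repo
pipe.fuse_lora(lora_scale=0.7, lora_name="style")

pipe.load_lora_weights("some-user/character-lora")   # hypothetical repo
pipe.fuse_lora(lora_scale=1.0, lora_name="character")

# Unfuse only one of the fused adapters; omitting lora_name pops the last one fused.
pipe.unfuse_lora(lora_name="style")
```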
90 changes: 65 additions & 25 deletions src/diffusers/models/lora.py
@@ -13,6 +13,7 @@
# limitations under the License.

from typing import Optional
from collections import OrderedDict

import torch
import torch.nn.functional as F
@@ -108,14 +109,25 @@ class LoRACompatibleConv(nn.Conv2d):
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
self._fused_params = OrderedDict()

def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, lora_name: str = None):
if self.lora_layer is None:
return

# make sure the LoRA has a unique name
if lora_name is None:
lora_name = "unspecified"
unspecified_num = 0
while lora_name in self._fused_params:
lora_name = f"unspecified{unspecified_num + 1}"
unspecified_num += 1
if lora_name in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} already fused")

dtype, device = self.weight.data.dtype, self.weight.data.device

w_orig = self.weight.data.float()
@@ -134,27 +146,35 @@ def _fuse_lora(self, lora_scale=1.0):
self.lora_layer = None

# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale

def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
self._fused_params[lora_name] = {
"w_up": w_up.cpu(),
"w_down": w_down.cpu(),
"lora_scale": lora_scale
}

def _unfuse_lora(self, lora_name: str = None):
if len(self._fused_params) == 0:
return

if lora_name is None:
unfuse_lora = self._fused_params.popitem(last=True)[1]
else:
if lora_name not in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} not found")
unfuse_lora = self._fused_params.pop(lora_name)

fused_weight = self.weight.data
dtype, device = fused_weight.data.dtype, fused_weight.data.device

self.w_up = self.w_up.to(device=device).float()
self.w_down = self.w_down.to(device).float()
w_up = unfuse_lora["w_up"].to(device=device).float()
w_down = unfuse_lora["w_down"].to(device).float()
lora_scale = unfuse_lora["lora_scale"]

fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((fused_weight.shape))
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
unfused_weight = fused_weight.float() - (lora_scale * fusion)
self.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
self.w_down = None
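The conv variant differs from the linear one only in the flatten/reshape around the matrix product, since the up and down factors are 4-D conv kernels. A small sketch with invented shapes (rank 4, 3x3 kernel) showing that fuse followed by unfuse restores the original kernel up to float32 round-off:

```python
# Illustrative shapes only; mirrors the flatten -> mm -> reshape used by LoRACompatibleConv.
import torch

out_ch, in_ch, rank, k = 16, 8, 4, 3
w_orig = torch.randn(out_ch, in_ch, k, k)  # base conv weight
w_up = torch.randn(out_ch, rank, 1, 1)     # stands in for lora_layer.up.weight
w_down = torch.randn(rank, in_ch, k, k)    # stands in for lora_layer.down.weight
lora_scale = 0.7

delta = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)).reshape(w_orig.shape)
fused = w_orig + lora_scale * delta        # what _fuse_lora writes into self.weight
restored = fused - lora_scale * delta      # what _unfuse_lora computes for that entry
assert torch.allclose(restored, w_orig, atol=1e-4)
```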

def forward(self, hidden_states, scale: float = 1.0):
if self.lora_layer is None:
@@ -175,14 +195,25 @@ class LoRACompatibleLinear(nn.Linear):
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
self._fused_params = OrderedDict()


def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer

def _fuse_lora(self, lora_scale=1.0):
def _fuse_lora(self, lora_scale=1.0, lora_name: str = None):
if self.lora_layer is None:
return

if lora_name is None:
lora_name = "unspecified"
unspecified_num = 0
while lora_name in self._fused_params:
lora_name = f"unspecified{unspecified_num + 1}"
unspecified_num += 1
if lora_name in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} already fused")

dtype, device = self.weight.data.dtype, self.weight.data.device

w_orig = self.weight.data.float()
@@ -199,25 +230,34 @@ def _fuse_lora(self, lora_scale=1.0):
self.lora_layer = None

# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale

def _unfuse_lora(self):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
self._fused_params[lora_name] = {
"w_up": w_up.cpu(),
"w_down": w_down.cpu(),
"lora_scale": lora_scale
}

def _unfuse_lora(self, lora_name: str = None):
if len(self._fused_params) == 0:
return

if lora_name is None:
unfuse_lora = self._fused_params.popitem(last=True)[1]
else:
if lora_name not in self._fused_params:
raise ValueError(f"LoRA with name {lora_name} not found")
unfuse_lora = self._fused_params.pop(lora_name)

fused_weight = self.weight.data
dtype, device = fused_weight.dtype, fused_weight.device

w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
w_up = unfuse_lora["w_up"].to(device=device).float()
w_down = unfuse_lora["w_down"].to(device).float()
lora_scale = unfuse_lora["lora_scale"]

unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
unfused_weight = fused_weight.float() - (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
self.w_down = None

def forward(self, hidden_states, scale: float = 1.0):
if self.lora_layer is None:
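Finally, a test-style sketch of the per-layer round trip these changes enable (it assumes the patched `diffusers.models.lora` from this PR is importable; sizes and names are arbitrary, and the up projection is perturbed because `LoRALinearLayer` zero-initializes it by default):

```python
# Test-style sketch assuming this PR's LoRACompatibleLinear; names and sizes are arbitrary.
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

def make_lora():
    lora = LoRALinearLayer(8, 8, rank=4)
    torch.nn.init.normal_(lora.up.weight)  # up starts at zero, perturb it so the delta is non-zero
    return lora

layer = LoRACompatibleLinear(8, 8)
w_orig = layer.weight.data.clone()

layer.set_lora_layer(make_lora())
layer._fuse_lora(lora_scale=0.5, lora_name="first")   # fusing clears lora_layer again

layer.set_lora_layer(make_lora())
layer._fuse_lora(lora_scale=1.0, lora_name="second")

# Unfuse by name, out of order; the base weight should come back unchanged.
layer._unfuse_lora(lora_name="first")
layer._unfuse_lora(lora_name="second")
assert torch.allclose(layer.weight.data, w_orig, atol=1e-4)
```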