setup.py (5 changes: 3 additions & 2 deletions)
@@ -91,6 +91,7 @@
_deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"peft>=0.5.0",
"compel==0.1.8",
"black~=23.1",
"datasets",
@@ -200,7 +201,7 @@ def run(self):
extras = {}
extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
extras["test"] = deps_list(
"compel",
"datasets",
@@ -220,7 +221,7 @@ def run(self):
"torchvision",
"transformers",
)
extras["torch"] = deps_list("torch", "accelerate")
extras["torch"] = deps_list("torch", "accelerate", "peft")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
src/diffusers/dependency_versions_table.py (1 change: 1 addition & 0 deletions)
@@ -4,6 +4,7 @@
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"peft": "peft>=0.5.0",
"compel": "compel==0.1.8",
"black": "black~=23.1",
"datasets": "datasets",
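For context, setup.py's `deps_list(...)` calls resolve bare package names against this `deps` table, which mirrors the `_deps` list in setup.py. The helper itself is not shown in this diff, so the sketch below is an assumed, minimal reconstruction of the mechanism, only to illustrate how the new `peft>=0.5.0` pin flows into `extras["torch"]` and `extras["training"]`.

```python
# Assumed sketch of setup.py's pin-resolution mechanism (not part of this diff).
import re

_deps = ["accelerate>=0.11.0", "peft>=0.5.0", "compel==0.1.8"]  # excerpt of the full list

# Map a bare package name to its full specifier, e.g. "peft" -> "peft>=0.5.0".
deps = {re.split(r"[!=<>~ ]", spec, maxsplit=1)[0]: spec for spec in _deps}

def deps_list(*pkgs):
    # Return the pinned requirement string for each requested package name.
    return [deps[pkg] for pkg in pkgs]

print(deps_list("accelerate", "peft"))  # ['accelerate>=0.11.0', 'peft>=0.5.0']
```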
src/diffusers/loaders.py (198 changes: 55 additions & 143 deletions)
@@ -31,10 +31,14 @@
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
convert_diffusers_state_dict_to_peft,
convert_old_state_dict_to_peft,
convert_unet_state_dict_to_peft,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
)
@@ -48,6 +52,9 @@
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

if is_peft_available():
from peft import LoraConfig

logger = logging.get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
@@ -1385,7 +1392,30 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)

unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
# load loras into unet
# TODO: @younesbelkada deal with network_alphas
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

state_dict, target_modules = convert_unet_state_dict_to_peft(state_dict)

lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=target_modules,
)

inject_adapter_in_model(lora_config, unet)

incompatible_keys = set_peft_model_state_dict(unet, state_dict)
unet._is_peft_loaded = True

if incompatible_keys is not None:
# check only for unexpected keys
if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {incompatible_keys.unexpected_keys}. "
)

@classmethod
def load_lora_into_text_encoder(
@@ -1414,7 +1444,6 @@ def load_lora_into_text_encoder(
argument to `True` will raise an error.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
@@ -1433,55 +1462,33 @@ def load_lora_into_text_encoder(
logger.info(f"Loading {prefix}.")
rank = {}

# Old diffusers to PEFT
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# This is no actual module at that point, they were monkey patched on to the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
for name, _ in text_encoder_attn_modules(text_encoder):
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")

text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
attention_modules = text_encoder_attn_modules(text_encoder)
text_encoder_lora_state_dict = convert_old_state_dict_to_peft(
attention_modules, text_encoder_lora_state_dict
)
# New diffusers format to PEFT
elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()):
attention_modules = text_encoder_attn_modules(text_encoder)
text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft(
attention_modules, text_encoder_lora_state_dict
)

for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank_key = f"{name}.out_proj.lora_B.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})

patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})

# for the diffusers format the rank is the same everywhere,
# so it is possible to load it directly with PEFT
if network_alphas is not None:
alpha_keys = [
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
@@ -1490,34 +1497,16 @@
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}

cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
lora_rank = list(rank.values())[0]
alpha = lora_scale * lora_rank

# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]

if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)

text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config)

text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
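A quick hedged check of the rank/alpha choice above: PEFT scales the LoRA update by `lora_alpha / r`, so setting `lora_alpha = lora_scale * lora_rank` makes the effective multiplier equal to `lora_scale`, which is what the old `PatchedLoraProjection` path applied directly.

```python
# Hedged numeric check (assumes PEFT's standard lora_alpha / r scaling of the LoRA delta).
lora_scale = 0.8                      # user-provided scale, as in the loader above
lora_rank = 4                         # derived from the lora_B weight shapes in the diff
lora_alpha = lora_scale * lora_rank   # the value passed to LoraConfig above

effective_scale = lora_alpha / lora_rank
assert effective_scale == lora_scale  # the LoRA delta ends up scaled by exactly lora_scale
```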

@@ -1544,83 +1533,6 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
mlp_module.fc1.lora_linear_layer = None
mlp_module.fc2.lora_linear_layer = None

@classmethod
def _modify_text_encoder(
cls,
text_encoder,
lora_scale=1,
network_alphas=None,
rank: Union[Dict[str, int], int] = 4,
dtype=None,
patch_mlp=False,
low_cpu_mem_usage=False,
):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""

def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)

lora_parameters.extend(model.lora_linear_layer.parameters())
return model

# First, remove any monkey-patch that might have been applied before
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)

lora_parameters = []
network_alphas = {} if network_alphas is None else network_alphas
is_network_alphas_populated = len(network_alphas) > 0

for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

if isinstance(rank, dict):
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
else:
current_rank = rank

attn_module.q_proj = create_patched_linear_lora(
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
)
attn_module.k_proj = create_patched_linear_lora(
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
)
attn_module.v_proj = create_patched_linear_lora(
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
)
attn_module.out_proj = create_patched_linear_lora(
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
)

if patch_mlp:
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)

current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

mlp_module.fc1 = create_patched_linear_lora(
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
)
mlp_module.fc2 = create_patched_linear_lora(
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
)

if is_network_alphas_populated and len(network_alphas) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)

return lora_parameters

@classmethod
def save_lora_weights(
self,
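Taken together, the loaders.py changes route LoRA loading through PEFT: the UNet branch converts the checkpoint with `convert_unet_state_dict_to_peft`, builds a `LoraConfig`, and applies it via `inject_adapter_in_model` and `set_peft_model_state_dict`, while the text encoder branch converts old- and diffusers-format keys to PEFT naming, derives the rank from the `lora_B` weight shapes, and hands everything to `text_encoder.load_adapter`. The public pipeline API is unchanged; the sketch below is a hedged picture of the intended user-facing flow (the LoRA repository id is hypothetical).

```python
# Hedged usage sketch: the public entry point is unchanged; internally the
# weights are now injected through PEFT rather than monkey-patched layers.
import torch
from diffusers import StableDiffusionPipeline

# Assumes a CUDA device is available.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical LoRA checkpoint id; old-format and diffusers-format state dicts
# are converted to PEFT keys on load.
pipe.load_lora_weights("some-user/some-lora")

image = pipe("a photo of a sks dog", num_inference_steps=25).images[0]
```

Note that the UNet branch currently hard-codes `r=4, lora_alpha=4` and leaves `network_alphas` handling as a TODO, so checkpoints trained with other ranks or per-module alphas may not round-trip through this branch yet.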
src/diffusers/models/attention.py (14 changes: 5 additions & 9 deletions)
@@ -21,7 +21,6 @@
from .activations import get_activation
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear


@maybe_allow_in_graph
@@ -296,17 +295,14 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
self.net.append(nn.Linear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states, scale: float = 1.0):
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
hidden_states = module(hidden_states)
return hidden_states


@@ -343,16 +339,16 @@ class GEGLU(nn.Module):

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
self.proj = nn.Linear(dim_in, dim_out * 2)

def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)


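The attention.py changes drop `LoRACompatibleLinear` and the per-call `scale` argument in favor of plain `nn.Linear`, since LoRA is now injected by PEFT instead of being baked into these layers. A minimal sketch of the resulting gated-GELU computation, with illustrative dimensions:

```python
# Minimal sketch of GEGLU after this change: a plain nn.Linear projection to
# 2 * dim_out, split into value and gate, with no LoRA scale argument.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim_in, dim_out = 64, 128          # illustrative sizes
proj = nn.Linear(dim_in, dim_out * 2)

x = torch.randn(2, 16, dim_in)     # (batch, sequence, dim_in)
hidden, gate = proj(x).chunk(2, dim=-1)
out = hidden * F.gelu(gate)        # same math as GEGLU.forward in the diff

print(out.shape)                   # torch.Size([2, 16, 128])
```

The FeedForward.forward loop simplifies for the same reason: every submodule is now called with a single argument.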