diff --git a/setup.py b/setup.py index ca8928b3223c..e2619772aea4 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ _deps = [ "Pillow", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.11.0", + "peft>=0.5.0", "compel==0.1.8", "black~=23.1", "datasets", @@ -200,7 +201,7 @@ def run(self): extras = {} extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder") -extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2") +extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft") extras["test"] = deps_list( "compel", "datasets", @@ -220,7 +221,7 @@ def run(self): "torchvision", "transformers", ) -extras["torch"] = deps_list("torch", "accelerate") +extras["torch"] = deps_list("torch", "accelerate", "peft") if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4ed..4d3cea199352 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,6 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", + "peft": "peft>=0.5.0", "compel": "compel==0.1.8", "black": "black~=23.1", "datasets": "datasets", diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 45c866c1aa16..5a68cc0a4a3f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -31,10 +31,14 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, + convert_unet_state_dict_to_peft, deprecate, is_accelerate_available, is_accelerate_version, is_omegaconf_available, + is_peft_available, is_transformers_available, logging, ) @@ -48,6 +52,9 @@ from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module +if is_peft_available(): + from peft import LoraConfig + logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" @@ -1385,7 +1392,30 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) - unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage) + # load loras into unet + # TODO: @younesbelkada deal with network_alphas + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + state_dict, target_modules = convert_unet_state_dict_to_peft(state_dict) + + lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=target_modules, + ) + + inject_adapter_in_model(lora_config, unet) + + incompatible_keys = set_peft_model_state_dict(unet, state_dict) + unet._is_peft_loaded = True + + if incompatible_keys is not None: + # check only for unexpected keys + if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {incompatible_keys.unexpected_keys}. 
" + ) @classmethod def load_lora_into_text_encoder( @@ -1414,7 +1444,6 @@ def load_lora_into_text_encoder( argument to `True` will raise an error. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1433,55 +1462,33 @@ def load_lora_into_text_encoder( logger.info(f"Loading {prefix}.") rank = {} + # Old diffusers to PEFT if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. - for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_old_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) + # New diffusers format to PEFT + elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()): + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" + rank_key = f"{name}.out_proj.lora_B.weight" rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) if patch_mlp: for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" - rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) + # for diffusers format you always get the same rank everywhere + # is it possible to load with PEFT if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix @@ -1490,34 +1497,16 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + lora_rank = list(rank.values())[0] + alpha = lora_scale * lora_rank - # set correct dtype & device - text_encoder_lora_state_dict = { - k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) - for k, v in text_encoder_lora_state_dict.items() - } - if low_cpu_mem_usage: - device = next(iter(text_encoder_lora_state_dict.values())).device - dtype = next(iter(text_encoder_lora_state_dict.values())).dtype - unexpected_keys = load_model_dict_into_meta( - text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype - ) - else: - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) - unexpected_keys = load_state_dict_results.unexpected_keys + target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + if patch_mlp: + target_modules += ["fc1", "fc2"] - if len(unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) + lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + + text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) @@ -1544,83 +1533,6 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): mlp_module.fc1.lora_linear_layer = None mlp_module.fc2.lora_linear_layer = None - @classmethod - def _modify_text_encoder( - cls, - text_encoder, - lora_scale=1, - network_alphas=None, - rank: Union[Dict[str, int], int] = 4, - dtype=None, - patch_mlp=False, - low_cpu_mem_usage=False, - ): - r""" - Monkey-patches the forward passes of attention modules of the text encoder. 
- """ - - def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): - linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype) - - lora_parameters.extend(model.lora_linear_layer.parameters()) - return model - - # First, remove any monkey-patch that might have been applied before - cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - lora_parameters = [] - network_alphas = {} if network_alphas is None else network_alphas - is_network_alphas_populated = len(network_alphas) > 0 - - for name, attn_module in text_encoder_attn_modules(text_encoder): - query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) - key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) - value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) - out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) - - if isinstance(rank, dict): - current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") - else: - current_rank = rank - - attn_module.q_proj = create_patched_linear_lora( - attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters - ) - attn_module.k_proj = create_patched_linear_lora( - attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters - ) - attn_module.v_proj = create_patched_linear_lora( - attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters - ) - attn_module.out_proj = create_patched_linear_lora( - attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters - ) - - if patch_mlp: - for name, mlp_module in text_encoder_mlp_modules(text_encoder): - fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) - fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) - - current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") - current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") - - mlp_module.fc1 = create_patched_linear_lora( - mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters - ) - mlp_module.fc2 = create_patched_linear_lora( - mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters - ) - - if is_network_alphas_populated and len(network_alphas) > 0: - raise ValueError( - f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" - ) - - return lora_parameters - @classmethod def save_lora_weights( self, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 892d44a03137..d5fc2850b86c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -21,7 +21,6 @@ from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -from .lora import LoRACompatibleLinear @maybe_allow_in_graph @@ -296,17 +295,14 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + self.net.append(nn.Linear(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states, scale: float = 1.0): for module in self.net: - if isinstance(module, (LoRACompatibleLinear, GEGLU)): - hidden_states = module(hidden_states, scale) - else: - hidden_states = module(hidden_states) + hidden_states = module(hidden_states) return hidden_states @@ -343,7 +339,7 @@ class GEGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int): super().__init__() - self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) + self.proj = nn.Linear(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": @@ -351,8 +347,8 @@ def gelu(self, gate): # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states, scale: float = 1.0): - hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 36851085c483..e16c90b61a8a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,7 +21,7 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .lora import LoRACompatibleLinear, LoRALinearLayer +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -137,22 +137,22 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias) + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -413,11 +413,33 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce return lora_processor + def scale_peft_lora_layers(self, scale: float = 1.0): + from peft.tuners.lora import LoraLayer + + total_modules_to_scale = list(self.modules()) + + for module in total_modules_to_scale: + if isinstance(module, LoraLayer): + module.scale_layer(scale) + + def unscale_peft_lora_layers(self, scale: float = 1.0): + from peft.tuners.lora import LoraLayer + + total_modules_to_unscale = list(self.modules()) + + for module in total_modules_to_unscale: + if isinstance(module, LoraLayer): + 
module.unscale_layer(scale) + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # retrieve the scale of LoRA layers and optionnaly scale / unscale them + scale = cross_attention_kwargs.get("scale", 1.0) + self.scale_peft_lora_layers(scale) + # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( + output = self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -425,6 +447,10 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None **cross_attention_kwargs, ) + # unscale operation in case + self.unscale_peft_lora_layers(scale) + return output + def batch_to_head_dim(self, tensor): head_size = self.heads batch_size, seq_len, dim = tensor.shape @@ -559,15 +585,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -578,7 +604,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -711,17 +737,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, scale=scale) - value = attn.to_v(hidden_states, scale=scale) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -735,7 +761,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -771,7 +797,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 
2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -780,8 +806,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, scale=scale) - value = attn.to_v(hidden_states, scale=scale) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -798,7 +824,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -926,15 +952,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -947,7 +973,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1004,15 +1030,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1032,7 +1058,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ac66e2271c61..3278987703d9 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,7 +23,6 @@ from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import SpatialNorm -from .lora import LoRACompatibleConv, LoRACompatibleLinear class Upsample1D(nn.Module): @@ -127,7 +126,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann if 
use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: - conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1) + conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -166,15 +165,9 @@ def forward(self, hidden_states, output_size=None, scale: float = 1.0): # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states) else: - if isinstance(self.Conv2d_0, LoRACompatibleConv): - hidden_states = self.Conv2d_0(hidden_states, scale) - else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.Conv2d_0(hidden_states) return hidden_states @@ -203,7 +196,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= self.name = name if use_conv: - conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -219,16 +212,13 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= def forward(self, hidden_states, scale: float = 1.0): assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, scale) - else: - hidden_states = self.conv(hidden_states) - + hidden_states = self.conv(hidden_states) return hidden_states @@ -544,13 +534,13 @@ def __init__( else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) + self.time_emb_proj = nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: @@ -567,7 +557,7 @@ def __init__( self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) @@ -593,7 +583,7 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = LoRACompatibleConv( + self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) @@ -634,12 +624,12 @@ def forward(self, input_tensor, temb, scale: float = 
1.0): else self.downsample(hidden_states) ) - hidden_states = self.conv1(hidden_states, scale) + hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb, scale)[:, :, None, None] + temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -656,10 +646,10 @@ def forward(self, input_tensor, temb, scale: float = 1.0): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, scale) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index c96aef65f339..ae8ed8adb5f3 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,7 +23,6 @@ from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed -from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin @@ -139,9 +138,9 @@ def __init__( self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + self.proj_in = nn.Linear(in_channels, inner_dim) else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -197,9 +196,9 @@ def __init__( if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + self.proj_out = nn.Linear(inner_dim, in_channels) else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -275,7 +274,7 @@ def forward( encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 1. 
Input if self.is_input_continuous: @@ -284,13 +283,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -326,9 +325,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = self.proj_out(hidden_states) else: - hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index d695d182fa37..5d05f26357e1 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -731,6 +731,7 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7390a2f69d23..1a0500b1b0bd 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,7 @@ is_note_seq_available, is_omegaconf_available, is_onnx_available, + is_peft_available, is_scipy_available, is_tensorboard_available, is_torch_available, @@ -82,6 +83,12 @@ from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput +from .peft_utils import ( + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, + convert_peft_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, +) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 587949ab0c52..a60ffd582472 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -180,6 +180,13 @@ except importlib_metadata.PackageNotFoundError: _accelerate_available = False +_peft_available = importlib.util.find_spec("peft") is not None +try: + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + _xformers_available = importlib.util.find_spec("xformers") is not None try: _xformers_version = importlib_metadata.version("xformers") @@ -319,6 +326,10 @@ def is_note_seq_available(): return _note_seq_available +def is_peft_available(): + return _peft_available + + def is_wandb_available(): return _wandb_available diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py new file mode 100644 index 000000000000..d9a8d508f631 --- /dev/null +++ 
b/src/diffusers/utils/peft_utils.py @@ -0,0 +1,153 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PEFT utilities: Utilities related to peft library +""" + + +class PeftAdapterMixin: + r""" + Mixin class that contains the useful methods for leveraging PEFT library to load and use adapters + """ + _is_peft_adapter_loaded = False + + def add_adapter( + self, + peft_config, + ): + if not getattr(self, "_is_peft_adapter_loaded", False): + pass + + +def convert_old_state_dict_to_peft(attention_modules, state_dict): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.lora_A.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. + converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_q_lora.up.weight") + converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_k_lora.up.weight") + converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_v_lora.up.weight") + converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_out_lora.up.weight") + + converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_q_lora.down.weight") + converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_k_lora.down.weight") + converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_v_lora.down.weight") + converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_out_lora.down.weight") + + return converted_state_dict + + +def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): + # Convert from the new naming convention to the diffusers naming convention. 
+ converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[f"{name}.q_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.q_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.k_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.v_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.out_proj.lora_B.{adapter_name}.weight" + ) + + converted_state_dict[f"{name}.q_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.q_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.k_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.v_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.out_proj.lora_A.{adapter_name}.weight" + ) + + return converted_state_dict + + +def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): + # Convert from the diffusers naming convention to the new naming convention. + converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop( + f"{name}.q_proj.lora_linear_layer.up.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop( + f"{name}.k_proj.lora_linear_layer.up.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop( + f"{name}.v_proj.lora_linear_layer.up.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop( + f"{name}.out_proj.lora_linear_layer.up.weight" + ) + + converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop( + f"{name}.q_proj.lora_linear_layer.down.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop( + f"{name}.k_proj.lora_linear_layer.down.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop( + f"{name}.v_proj.lora_linear_layer.down.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop( + f"{name}.out_proj.lora_linear_layer.down.weight" + ) + + return converted_state_dict + + +def convert_unet_state_dict_to_peft(state_dict): + converted_state_dict = {} + target_modules = [] + + patterns = { + ".to_out_lora": ".to_out.0", + ".down": ".lora_A", + ".up": ".lora_B", + ".to_q_lora": ".to_q", + ".to_k_lora": ".to_k", + ".to_v_lora": ".to_v", + ".processor.": ".", + } + + for k, v in state_dict.items(): + pattern_found = False + + if any(pattern in k for pattern in patterns.keys()): + for old, new in patterns.items(): + k = k.replace(old, new) + pattern_found = True + + converted_state_dict[k] = v + if pattern_found: + target_modules.append(".".join(k.split(".")[:-2])) + + return converted_state_dict, target_modules
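Usage note: taken together, `convert_unet_state_dict_to_peft`, `LoraConfig`, `inject_adapter_in_model` and `set_peft_model_state_dict` replace the old `load_attn_procs` path inside `load_lora_into_unet`. Below is a minimal sketch of how those pieces fit together outside the loader; the model id, file name and the r / lora_alpha values are illustrative placeholders (the loader above still hard-codes r=4 / lora_alpha=4, see the TODO about network_alphas).

import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

from diffusers import UNet2DConditionModel
from diffusers.utils import convert_unet_state_dict_to_peft

# Placeholder checkpoint and LoRA file; any UNet plus a diffusers-format LoRA state dict works here.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_state_dict = torch.load("pytorch_lora_weights.bin")  # flat {key: tensor} dict in diffusers format

# 1. Remap diffusers keys (.to_q_lora.down/.up, .processor.) to PEFT's lora_A / lora_B
#    and collect the modules that should receive adapters.
peft_state_dict, target_modules = convert_unet_state_dict_to_peft(lora_state_dict)

# 2. Inject empty LoRA layers into the UNet, mirroring load_lora_into_unet above.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=target_modules)
inject_adapter_in_model(lora_config, unet)

# 3. Load the converted weights into the injected layers and surface any leftovers.
incompatible_keys = set_peft_model_state_dict(unet, peft_state_dict)
if getattr(incompatible_keys, "unexpected_keys", None):
    print("Unexpected keys:", incompatible_keys.unexpected_keys)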
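On the attention side, the removed `scale=` argument of `LoRACompatibleLinear` is replaced by `scale_peft_lora_layers` / `unscale_peft_lora_layers`, which call `scale_layer` / `unscale_layer` on every PEFT `LoraLayer` around the processor call. From the pipeline user's side the knob stays the same; a rough sketch, with a placeholder base model and LoRA repo id:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # placeholder LoRA repo id

# Attention.forward reads cross_attention_kwargs["scale"], scales every LoraLayer
# before the processor runs, and unscales it again afterwards.
image = pipe("a photo of a corgi", cross_attention_kwargs={"scale": 0.7}).images[0]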