From ca11165ca8a187dcabaa7498e91e3974c1a483af Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 25 Aug 2023 12:48:48 +0000 Subject: [PATCH 1/9] v1 --- src/diffusers/loaders.py | 130 +++++++++++++++++------------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/peft_utils.py | 57 +++++++++++++ 3 files changed, 133 insertions(+), 55 deletions(-) create mode 100644 src/diffusers/utils/peft_utils.py diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ea657ccbdf63..37703ff7a68a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -37,6 +37,7 @@ is_omegaconf_available, is_transformers_available, logging, + convert_state_dict_to_peft ) from .utils.import_utils import BACKENDS_MAPPING @@ -903,7 +904,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], use_peft=True, **kwargs): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. @@ -931,6 +932,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di network_alphas=network_alphas, text_encoder=self.text_encoder, lora_scale=self.lora_scale, + use_peft=use_peft ) @classmethod @@ -1210,7 +1212,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet): unet.load_attn_procs(state_dict, network_alphas=network_alphas) @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0): + def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1., use_peft=True): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1228,7 +1230,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. """ - + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1248,44 +1250,48 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p rank = {} if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. 
- for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + if use_peft: + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) + # # Convert from the old naming convention to the new naming convention. + # # + # # Previously, the old LoRA layers were stored on the state dict at the + # # same level as the attention block i.e. + # # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. + # # + # # This is no actual module at that point, they were monkey patched on to the + # # existing module. We want to be able to load them via their actual state dict. + # # They're in `PatchedLoraProjection.lora_linear_layer` now. 
+ else: + for name, _ in text_encoder_attn_modules(text_encoder): + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") + + text_encoder_lora_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") + text_encoder_lora_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" + rank_key = f"{name}.out_proj.lora_B.weight" if use_peft else f"{name}.out_proj.lora_linear_layer.up.weight" rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) @@ -1296,6 +1302,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) + # for diffusers format you always get the same rank everywhere + # is it possible to load with PEFT if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix @@ -1304,25 +1312,37 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - ) + if not use_peft: + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + ) - # set correct dtype & device - text_encoder_lora_state_dict = { - k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) - for k, v in text_encoder_lora_state_dict.items() - } - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) - if len(load_state_dict_results.unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + # set correct dtype & device + text_encoder_lora_state_dict = { + k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) + for k, v in text_encoder_lora_state_dict.items() + } + load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) + if len(load_state_dict_results.unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) + else: + 
lora_rank = list(rank.values())[0] + alpha = lora_scale * lora_rank + + lora_config = LoraConfig( + r=lora_rank, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + lora_alpha=alpha ) + text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 9b710d214d92..a11c3644cf3a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -82,6 +82,7 @@ from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .torch_utils import is_compiled_module, randn_tensor +from .peft_utils import convert_state_dict_to_peft if is_torch_available(): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py new file mode 100644 index 000000000000..f07ed661f19d --- /dev/null +++ b/src/diffusers/utils/peft_utils.py @@ -0,0 +1,57 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PEFT utilities: Utilities related to peft library +""" + +def convert_state_dict_to_peft(attention_modules, state_dict): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.lora_A.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. 
+ converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_q_lora.up.weight") + converted_state_dict[ + f"{name}.k_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_k_lora.up.weight") + converted_state_dict[ + f"{name}.v_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_v_lora.up.weight") + converted_state_dict[ + f"{name}.out_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_out_lora.up.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_q_lora.down.weight") + converted_state_dict[ + f"{name}.k_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_k_lora.down.weight") + converted_state_dict[ + f"{name}.v_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_v_lora.down.weight") + converted_state_dict[ + f"{name}.out_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_out_lora.down.weight") + + return converted_state_dict \ No newline at end of file From 3e742fb8fc5fcaf6f63bfaa3ad8712055463ae9e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 25 Aug 2023 13:46:48 +0000 Subject: [PATCH 2/9] now saving works --- src/diffusers/loaders.py | 180 +++++----------------------- src/diffusers/utils/__init__.py | 3 +- src/diffusers/utils/import_utils.py | 9 ++ src/diffusers/utils/peft_utils.py | 70 ++++++++++- 4 files changed, 110 insertions(+), 152 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 37703ff7a68a..d87ee75df6a6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -36,8 +36,10 @@ is_accelerate_available, is_omegaconf_available, is_transformers_available, + is_peft_available, logging, - convert_state_dict_to_peft + convert_old_state_dict_to_peft, + convert_diffusers_state_dict_to_peft ) from .utils.import_utils import BACKENDS_MAPPING @@ -49,6 +51,9 @@ from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device +if is_peft_available(): + from peft import LoraConfig + logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" @@ -904,7 +909,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], use_peft=True, **kwargs): + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. 
@@ -932,7 +937,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di network_alphas=network_alphas, text_encoder=self.text_encoder, lora_scale=self.lora_scale, - use_peft=use_peft + adapter_name=adapter_name, ) @classmethod @@ -1212,7 +1217,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet): unet.load_attn_procs(state_dict, network_alphas=network_alphas) @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1., use_peft=True): + def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1230,7 +1235,6 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. """ - from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1249,56 +1253,24 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p logger.info(f"Loading {prefix}.") rank = {} + # Old diffusers to PEFT if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - if use_peft: - attention_modules = text_encoder_attn_modules(text_encoder) - text_encoder_lora_state_dict = convert_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) - # # Convert from the old naming convention to the new naming convention. - # # - # # Previously, the old LoRA layers were stored on the state dict at the - # # same level as the attention block i.e. - # # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # # - # # This is no actual module at that point, they were monkey patched on to the - # # existing module. We want to be able to load them via their actual state dict. - # # They're in `PatchedLoraProjection.lora_linear_layer` now. 
- else: - for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_old_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) + # New diffusers format to PEFT + elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()): + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_B.weight" if use_peft else f"{name}.out_proj.lora_linear_layer.up.weight" + rank_key = f"{name}.out_proj.lora_B.weight" rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) if patch_mlp: for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" - rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) @@ -1312,36 +1284,20 @@ def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, p k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - if not use_peft: - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - ) + lora_rank = list(rank.values())[0] + alpha = lora_scale * lora_rank - # set correct dtype & device - text_encoder_lora_state_dict = { - k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) - for k, v in text_encoder_lora_state_dict.items() - } - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) - if len(load_state_dict_results.unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) - else: - lora_rank = list(rank.values())[0] - alpha = lora_scale * lora_rank + target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + if patch_mlp: + target_modules += ["fc1", "fc2"] - lora_config = LoraConfig( - r=lora_rank, - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - lora_alpha=alpha - ) + lora_config = LoraConfig( + r=lora_rank, + target_modules=target_modules, + lora_alpha=alpha + ) - text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) + text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) @property def lora_scale(self) -> float: @@ -1366,82 +1322,6 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): mlp_module.fc1 = mlp_module.fc1.regular_linear_layer mlp_module.fc2 = mlp_module.fc2.regular_linear_layer - @classmethod - def _modify_text_encoder( - cls, - text_encoder, - lora_scale=1, - network_alphas=None, - rank: Union[Dict[str, int], int] = 4, - dtype=None, - patch_mlp=False, - ): - r""" - Monkey-patches the forward passes of attention modules of the text encoder. 
- """ - - # First, remove any monkey-patch that might have been applied before - cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - lora_parameters = [] - network_alphas = {} if network_alphas is None else network_alphas - is_network_alphas_populated = len(network_alphas) > 0 - - for name, attn_module in text_encoder_attn_modules(text_encoder): - query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) - key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) - value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) - out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) - - if isinstance(rank, dict): - current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") - else: - current_rank = rank - - attn_module.q_proj = PatchedLoraProjection( - attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype - ) - lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) - - attn_module.k_proj = PatchedLoraProjection( - attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype - ) - lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) - - attn_module.v_proj = PatchedLoraProjection( - attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype - ) - lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) - - attn_module.out_proj = PatchedLoraProjection( - attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype - ) - lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) - - if patch_mlp: - for name, mlp_module in text_encoder_mlp_modules(text_encoder): - fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) - fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) - - current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") - current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") - - mlp_module.fc1 = PatchedLoraProjection( - mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype - ) - lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters()) - - mlp_module.fc2 = PatchedLoraProjection( - mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype - ) - lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters()) - - if is_network_alphas_populated and len(network_alphas) > 0: - raise ValueError( - f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" - ) - - return lora_parameters @classmethod def save_lora_weights( diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a11c3644cf3a..78ac4d227de4 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -76,13 +76,14 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_peft_available, requires_backends, ) from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .torch_utils import is_compiled_module, randn_tensor -from .peft_utils import convert_state_dict_to_peft +from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft if is_torch_available(): diff --git 
a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 7fe5eacb25b0..f31624708a27 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -215,6 +215,13 @@ except importlib_metadata.PackageNotFoundError: _accelerate_available = False +_peft_available = importlib.util.find_spec("peft") is not None +try: + _accelerate_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + _xformers_available = importlib.util.find_spec("xformers") is not None try: _xformers_version = importlib_metadata.version("xformers") @@ -357,6 +364,8 @@ def is_k_diffusion_available(): def is_note_seq_available(): return _note_seq_available +def is_peft_available(): + return _peft_available def is_wandb_available(): return _wandb_available diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index f07ed661f19d..acf22bbd3e84 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -15,7 +15,7 @@ PEFT utilities: Utilities related to peft library """ -def convert_state_dict_to_peft(attention_modules, state_dict): +def convert_old_state_dict_to_peft(attention_modules, state_dict): # Convert from the old naming convention to the new naming convention. # # Previously, the old LoRA layers were stored on the state dict at the @@ -54,4 +54,72 @@ def convert_state_dict_to_peft(attention_modules, state_dict): f"{name}.out_proj.lora_A.weight" ] = state_dict.pop(f"{name}.to_out_lora.down.weight") + return converted_state_dict + + +def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): + # Convert from the new naming convention to the diffusers naming convention. + converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_B.{adapter_name}.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_A.{adapter_name}.weight") + + return converted_state_dict + + +def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): + # Convert from the diffusers naming convention to the new naming convention. 
+ converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.k_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.v_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.out_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.up.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.k_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.v_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.out_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.down.weight") + return converted_state_dict \ No newline at end of file From adad99d480feda04bbe18ebb9f6e87db63747ce8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 25 Aug 2023 13:53:57 +0000 Subject: [PATCH 3/9] add peft as dep --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 4c7005329abe..d7741ea7048e 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ _deps = [ "Pillow", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.11.0", + "peft>=0.5.0", "compel==0.1.8", "black~=23.1", "datasets", @@ -189,7 +190,7 @@ def run(self): extras = {} extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder") -extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2") +extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft") extras["test"] = deps_list( "compel", "datasets", @@ -209,7 +210,7 @@ def run(self): "torchvision", "transformers", ) -extras["torch"] = deps_list("torch", "accelerate") +extras["torch"] = deps_list("torch", "accelerate", "peft") if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows From 595eb73c0d67da47c1b6d2e36d5e5f8275b451b1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 25 Aug 2023 14:08:00 +0000 Subject: [PATCH 4/9] oops --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d87ee75df6a6..496b126429cd 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -909,7 +909,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs): + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. 
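For reference, the key renaming that the conversion helpers in this series implement reduces to a small mechanical mapping: the diffusers `lora_linear_layer.up` / `lora_linear_layer.down` suffixes become PEFT's `lora_B` / `lora_A`. The sketch below illustrates this on a toy state dict; it is not part of the patches themselves — the module prefix and tensor shapes are invented for the example, while the per-key mapping and the note on adapter-scoped names follow `convert_diffusers_state_dict_to_peft` and `convert_peft_state_dict_to_diffusers` as introduced in PATCH 2/9.

import torch

# Toy diffusers-format LoRA entries for one attention module (rank 4, hidden size 8).
# The module prefix below is illustrative only, not taken from the patches.
name = "text_model.encoder.layers.0.self_attn"
diffusers_sd = {
    f"{name}.q_proj.lora_linear_layer.up.weight": torch.zeros(8, 4),    # "up"   -> lora_B
    f"{name}.q_proj.lora_linear_layer.down.weight": torch.zeros(4, 8),  # "down" -> lora_A
}

# Same per-key mapping that convert_diffusers_state_dict_to_peft applies for q/k/v/out_proj.
peft_sd = {
    k.replace("lora_linear_layer.up", "lora_B").replace("lora_linear_layer.down", "lora_A"): v
    for k, v in diffusers_sd.items()
}

print(sorted(peft_sd))
# ['text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight',
#  'text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight']
# Once PEFT has wrapped the layer, the stored keys also carry the adapter name
# (e.g. "...q_proj.lora_B.default.weight"), which is why
# convert_peft_state_dict_to_diffusers takes an adapter_name argument.

In `load_lora_into_text_encoder`, the rank is then read from the `lora_B` weight (`shape[1]`) and `lora_alpha` is set to `lora_scale * rank`, so PEFT's `lora_alpha / r` scaling factor equals the `lora_scale` previously applied by the monkey-patched `PatchedLoraProjection` layers.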
From ad5480fc93ec4fa95cffba49c42554f2d59b257d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 25 Aug 2023 14:15:05 +0000 Subject: [PATCH 5/9] oops --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 496b126429cd..a2973121c21f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -937,7 +937,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di network_alphas=network_alphas, text_encoder=self.text_encoder, lora_scale=self.lora_scale, - adapter_name=adapter_name, ) @classmethod From cbd8494083b0048d58328ac7804aa98f3095dc94 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 14 Sep 2023 11:32:10 +0000 Subject: [PATCH 6/9] more fixes --- src/diffusers/loaders.py | 16 ++++++++++++++-- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/peft_utils.py | 22 ++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a2973121c21f..e455b7737cf7 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -39,7 +39,8 @@ is_peft_available, logging, convert_old_state_dict_to_peft, - convert_diffusers_state_dict_to_peft + convert_diffusers_state_dict_to_peft, + convert_unet_state_dict_to_peft, ) from .utils.import_utils import BACKENDS_MAPPING @@ -1213,7 +1214,18 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet): warnings.warn(warn_message) # load loras into unet - unet.load_attn_procs(state_dict, network_alphas=network_alphas) + # unet.load_attn_procs(state_dict, network_alphas=network_alphas) + from peft import inject_adapter_in_model, LoraConfig, set_peft_model_state_dict + + lora_config = LoraConfig( + r=4, + target_modules=["to_q", "to_k", "to_o"], + ) + + inject_adapter_in_model(lora_config, unet) + state_dict = convert_unet_state_dict_to_peft(state_dict) + + load_results = set_peft_model_state_dict(unet, state_dict) @classmethod def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 78ac4d227de4..0e03d483e93e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -83,7 +83,7 @@ from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .torch_utils import is_compiled_module, randn_tensor -from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft +from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft, convert_unet_state_dict_to_peft if is_torch_available(): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index acf22bbd3e84..995e4b9357ef 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -122,4 +122,26 @@ def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): f"{name}.out_proj.lora_A.weight" ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.down.weight") + return converted_state_dict + + +def convert_unet_state_dict_to_peft(state_dict): + converted_state_dict = {} + + patterns = { + ".to_out_lora": ".to_o", + ".down": ".lora_A", + ".up": ".lora_B", + ".to_q_lora": ".to_q", + ".to_k_lora": ".to_k", + ".to_v_lora": ".to_v", + } + + for k, v in state_dict.items(): + if 
any(pattern in k for pattern in patterns.keys()): + for old, new in patterns.items(): + k = k.replace(old, new) + + converted_state_dict[k] = v + return converted_state_dict \ No newline at end of file From be0451f5955743976c15493dcbe0e87884fb24b9 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 09:09:04 +0000 Subject: [PATCH 7/9] working v1 unet --- src/diffusers/dependency_versions_table.py | 1 + src/diffusers/loaders.py | 45 ++++-- src/diffusers/models/attention_processor.py | 64 ++++---- src/diffusers/models/resnet.py | 24 +-- src/diffusers/models/transformer_2d.py | 19 ++- src/diffusers/utils/__init__.py | 10 +- src/diffusers/utils/import_utils.py | 2 + src/diffusers/utils/peft_utils.py | 170 ++++++++++---------- 8 files changed, 185 insertions(+), 150 deletions(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4ed..4d3cea199352 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,6 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", + "peft": "peft>=0.5.0", "compel": "compel==0.1.8", "black": "black~=23.1", "datasets": "datasets", diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5e66ddb37e2c..faba2f5433ba 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -31,16 +31,16 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, + convert_unet_state_dict_to_peft, deprecate, is_accelerate_available, is_accelerate_version, is_omegaconf_available, - is_transformers_available, is_peft_available, + is_transformers_available, logging, - convert_old_state_dict_to_peft, - convert_diffusers_state_dict_to_peft, - convert_unet_state_dict_to_peft, ) from .utils.import_utils import BACKENDS_MAPPING @@ -1394,17 +1394,28 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage # load loras into unet # unet.load_attn_procs(state_dict, network_alphas=network_alphas) - from peft import inject_adapter_in_model, LoraConfig, set_peft_model_state_dict + # TODO: deal with network_alphas + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + state_dict, target_modules = convert_unet_state_dict_to_peft(state_dict) lora_config = LoraConfig( r=4, - target_modules=["to_q", "to_k", "to_o"], + lora_alpha=4, + target_modules=target_modules, ) inject_adapter_in_model(lora_config, unet) - state_dict = convert_unet_state_dict_to_peft(state_dict) - load_results = set_peft_model_state_dict(unet, state_dict) + incompatible_keys = set_peft_model_state_dict(unet, state_dict) + + if incompatible_keys is not None: + # check only for unexpected keys + if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {incompatible_keys.unexpected_keys}. 
" + ) @classmethod def load_lora_into_text_encoder( @@ -1454,11 +1465,15 @@ def load_lora_into_text_encoder( # Old diffusers to PEFT if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): attention_modules = text_encoder_attn_modules(text_encoder) - text_encoder_lora_state_dict = convert_old_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) + text_encoder_lora_state_dict = convert_old_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) # New diffusers format to PEFT elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()): attention_modules = text_encoder_attn_modules(text_encoder) - text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft(attention_modules, text_encoder_lora_state_dict) + text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) for name, _ in text_encoder_attn_modules(text_encoder): rank_key = f"{name}.out_proj.lora_B.weight" @@ -1472,8 +1487,8 @@ def load_lora_into_text_encoder( rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) - # for diffusers format you always get the same rank everywhere - # is it possible to load with PEFT + # for diffusers format you always get the same rank everywhere + # is it possible to load with PEFT if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix @@ -1489,11 +1504,7 @@ def load_lora_into_text_encoder( if patch_mlp: target_modules += ["fc1", "fc2"] - lora_config = LoraConfig( - r=lora_rank, - target_modules=target_modules, - lora_alpha=alpha - ) + lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 36851085c483..fa4bc4732cc6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -21,7 +21,7 @@ from ..utils import deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph -from .lora import LoRACompatibleLinear, LoRALinearLayer +from .lora import LoRALinearLayer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -137,22 +137,28 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias) + # self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias) + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) - self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + # self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + # self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: - self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) - self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + # self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + # self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim) + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) self.to_out = nn.ModuleList([]) - self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias)) + # self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor @@ -559,15 +565,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -578,7 +584,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -711,17 +717,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, scale=scale) - value = 
attn.to_v(hidden_states, scale=scale) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -735,7 +741,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -771,7 +777,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -780,8 +786,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, scale=scale) - value = attn.to_v(hidden_states, scale=scale) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -798,7 +804,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -926,15 +932,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -947,7 +953,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1004,15 +1010,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, scale=scale) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, scale=scale) - value = attn.to_v(encoder_hidden_states, scale=scale) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1032,7 +1038,7 @@ def __call__( hidden_states 
= hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, scale=scale) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ac66e2271c61..7eb597481655 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,7 +23,7 @@ from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import SpatialNorm -from .lora import LoRACompatibleConv, LoRACompatibleLinear +from .lora import LoRACompatibleConv class Upsample1D(nn.Module): @@ -544,13 +544,16 @@ def __init__( else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + # self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) + # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) + self.time_emb_proj = nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) + # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: @@ -593,7 +596,10 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = LoRACompatibleConv( + # self.conv_shortcut = LoRACompatibleConv( + # in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias + # ) + self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) @@ -634,12 +640,12 @@ def forward(self, input_tensor, temb, scale: float = 1.0): else self.downsample(hidden_states) ) - hidden_states = self.conv1(hidden_states, scale) + hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb, scale)[:, :, None, None] + temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -656,10 +662,10 @@ def forward(self, input_tensor, temb, scale: float = 1.0): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, scale) + hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, scale) + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index c96aef65f339..ae8ed8adb5f3 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,7 +23,6 @@ from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed -from .lora import 
LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin @@ -139,9 +138,9 @@ def __init__( self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: - self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + self.proj_in = nn.Linear(in_channels, inner_dim) else: - self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -197,9 +196,9 @@ def __init__( if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: - self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + self.proj_out = nn.Linear(inner_dim, in_channels) else: - self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -275,7 +274,7 @@ def forward( encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # Retrieve lora scale. - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 1. Input if self.is_input_continuous: @@ -284,13 +283,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states, scale=lora_scale) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -326,9 +325,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = self.proj_out(hidden_states) else: - hidden_states = self.proj_out(hidden_states, scale=lora_scale) + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 80a234445c0d..1a0500b1b0bd 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,7 @@ is_note_seq_available, is_omegaconf_available, is_onnx_available, + is_peft_available, is_scipy_available, is_tensorboard_available, is_torch_available, @@ -77,15 +78,18 @@ is_unidecode_available, is_wandb_available, is_xformers_available, - is_peft_available, requires_backends, ) from .loading_utils import load_image from .logging import get_logger from .outputs import 
BaseOutput +from .peft_utils import ( + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, + convert_peft_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, +) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil -from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft, convert_unet_state_dict_to_peft - logger = get_logger(__name__) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 40e588b99a02..a60ffd582472 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -325,9 +325,11 @@ def is_k_diffusion_available(): def is_note_seq_available(): return _note_seq_available + def is_peft_available(): return _peft_available + def is_wandb_available(): return _wandb_available diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 995e4b9357ef..d9a8d508f631 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -15,6 +15,21 @@ PEFT utilities: Utilities related to peft library """ + +class PeftAdapterMixin: + r""" + Mixin class that contains the useful methods for leveraging PEFT library to load and use adapters + """ + _is_peft_adapter_loaded = False + + def add_adapter( + self, + peft_config, + ): + if not getattr(self, "_is_peft_adapter_loaded", False): + pass + + def convert_old_state_dict_to_peft(attention_modules, state_dict): # Convert from the old naming convention to the new naming convention. # @@ -26,122 +41,113 @@ def convert_old_state_dict_to_peft(attention_modules, state_dict): # existing module. We want to be able to load them via their actual state dict. # They're in `PatchedLoraProjection.lora_linear_layer` now. 
converted_state_dict = {} - + for name, _ in attention_modules: - converted_state_dict[ - f"{name}.q_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_q_lora.up.weight") - converted_state_dict[ - f"{name}.k_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_k_lora.up.weight") - converted_state_dict[ - f"{name}.v_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_v_lora.up.weight") - converted_state_dict[ - f"{name}.out_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_out_lora.up.weight") - - converted_state_dict[ - f"{name}.q_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_q_lora.down.weight") - converted_state_dict[ - f"{name}.k_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_k_lora.down.weight") - converted_state_dict[ - f"{name}.v_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_v_lora.down.weight") - converted_state_dict[ - f"{name}.out_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_out_lora.down.weight") - + converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_q_lora.up.weight") + converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_k_lora.up.weight") + converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_v_lora.up.weight") + converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_out_lora.up.weight") + + converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_q_lora.down.weight") + converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_k_lora.down.weight") + converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_v_lora.down.weight") + converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_out_lora.down.weight") + return converted_state_dict def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): # Convert from the new naming convention to the diffusers naming convention. converted_state_dict = {} - + for name, _ in attention_modules: - converted_state_dict[ + converted_state_dict[f"{name}.q_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.q_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.k_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.v_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.out_proj.lora_B.{adapter_name}.weight" + ) + + converted_state_dict[f"{name}.q_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.q_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.k_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.v_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.out_proj.lora_A.{adapter_name}.weight" + ) + + return converted_state_dict + + +def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): + # Convert from the diffusers naming convention to the new naming convention. 
+    converted_state_dict = {}
+
+    for name, _ in attention_modules:
+        converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop(
             f"{name}.q_proj.lora_linear_layer.up.weight"
-        ] = state_dict.pop(f"{name}.q_proj.lora_B.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop(
             f"{name}.k_proj.lora_linear_layer.up.weight"
-        ] = state_dict.pop(f"{name}.k_proj.lora_B.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop(
             f"{name}.v_proj.lora_linear_layer.up.weight"
-        ] = state_dict.pop(f"{name}.v_proj.lora_B.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop(
             f"{name}.out_proj.lora_linear_layer.up.weight"
-        ] = state_dict.pop(f"{name}.out_proj.lora_B.{adapter_name}.weight")
+        )

-        converted_state_dict[
+        converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop(
             f"{name}.q_proj.lora_linear_layer.down.weight"
-        ] = state_dict.pop(f"{name}.q_proj.lora_A.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop(
             f"{name}.k_proj.lora_linear_layer.down.weight"
-        ] = state_dict.pop(f"{name}.k_proj.lora_A.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop(
             f"{name}.v_proj.lora_linear_layer.down.weight"
-        ] = state_dict.pop(f"{name}.v_proj.lora_A.{adapter_name}.weight")
-        converted_state_dict[
+        )
+        converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop(
             f"{name}.out_proj.lora_linear_layer.down.weight"
-        ] = state_dict.pop(f"{name}.out_proj.lora_A.{adapter_name}.weight")
-
-    return converted_state_dict
-
+        )
-def convert_diffusers_state_dict_to_peft(attention_modules, state_dict):
-    # Convert from the diffusers naming convention to the new naming convention.
-    converted_state_dict = {}
-
-    for name, _ in attention_modules:
-        converted_state_dict[
-            f"{name}.q_proj.lora_B.weight"
-        ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.up.weight")
-        converted_state_dict[
-            f"{name}.k_proj.lora_B.weight"
-        ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.up.weight")
-        converted_state_dict[
-            f"{name}.v_proj.lora_B.weight"
-        ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.up.weight")
-        converted_state_dict[
-            f"{name}.out_proj.lora_B.weight"
-        ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
-
-        converted_state_dict[
-            f"{name}.q_proj.lora_A.weight"
-        ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.down.weight")
-        converted_state_dict[
-            f"{name}.k_proj.lora_A.weight"
-        ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.down.weight")
-        converted_state_dict[
-            f"{name}.v_proj.lora_A.weight"
-        ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.down.weight")
-        converted_state_dict[
-            f"{name}.out_proj.lora_A.weight"
-        ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.down.weight")
-
     return converted_state_dict


 def convert_unet_state_dict_to_peft(state_dict):
     converted_state_dict = {}
+    target_modules = []

     patterns = {
-        ".to_out_lora": ".to_o",
+        ".to_out_lora": ".to_out.0",
         ".down": ".lora_A",
         ".up": ".lora_B",
         ".to_q_lora": ".to_q",
         ".to_k_lora": ".to_k",
         ".to_v_lora": ".to_v",
+        ".processor.": ".",
     }

     for k, v in state_dict.items():
+        pattern_found = False
+
         if any(pattern in k for pattern in patterns.keys()):
             for old, new in patterns.items():
                 k = k.replace(old, new)
-
+                pattern_found = True
+
         converted_state_dict[k] = v
-
-    return converted_state_dict
\ No newline at end of file
+        if pattern_found:
+            target_modules.append(".".join(k.split(".")[:-2]))
+
+    return converted_state_dict, target_modules

From 61e3983024de55f1bbd1161307680420df32553a Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Fri, 15 Sep 2023 09:10:21 +0000
Subject: [PATCH 8/9] add comment

---
 src/diffusers/loaders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index faba2f5433ba..3c122cfc3542 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1394,7 +1394,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
         # load loras into unet
         # unet.load_attn_procs(state_dict, network_alphas=network_alphas)
-        # TODO: deal with network_alphas
+        # TODO: @younesbelkada deal with network_alphas
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

         state_dict, target_modules = convert_unet_state_dict_to_peft(state_dict)

From 8ccaafc44a329f563ec7be1724108bbe393c7987 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Fri, 15 Sep 2023 10:49:35 +0000
Subject: [PATCH 9/9] v1 scale - need to dig into why I am not getting the same output

---
 src/diffusers/loaders.py                    |  2 +-
 src/diffusers/models/attention.py           | 14 +++------
 src/diffusers/models/attention_processor.py | 34 ++++++++++++++++-----
 src/diffusers/models/resnet.py              | 30 +++++-------------
 src/diffusers/models/unet_2d_condition.py   |  1 +
 5 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 3c122cfc3542..5a68cc0a4a3f 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1393,7 +1393,6 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
         warnings.warn(warn_message)

         # load loras into unet
-        # unet.load_attn_procs(state_dict, network_alphas=network_alphas)
         # TODO: @younesbelkada deal with network_alphas
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

@@ -1408,6 +1407,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
         inject_adapter_in_model(lora_config, unet)
         incompatible_keys = set_peft_model_state_dict(unet, state_dict)
+        unet._is_peft_loaded = True

         if incompatible_keys is not None:
             # check only for unexpected keys
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 892d44a03137..d5fc2850b86c 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -21,7 +21,6 @@
 from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
-from .lora import LoRACompatibleLinear


 @maybe_allow_in_graph
@@ -296,17 +295,14 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        self.net.append(nn.Linear(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))

     def forward(self, hidden_states, scale: float = 1.0):
         for module in self.net:
-            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states


@@ -343,7 +339,7 @@ class GEGLU(nn.Module):

     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
+        self.proj = nn.Linear(dim_in, dim_out * 2)

     def gelu(self, gate):
         if gate.device.type != "mps":
@@ -351,8 +347,8 @@ def gelu(self, gate):
             # mps: gelu is not implemented for float16
             return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

-    def forward(self, hidden_states, scale: float = 1.0):
-        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)


diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fa4bc4732cc6..e16c90b61a8a 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -137,13 +137,10 @@ def __init__(
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )

-        # self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
         self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            # self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
-            # self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
             self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
             self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
@@ -151,14 +148,11 @@ def __init__(
             self.to_v = None

         if self.added_kv_proj_dim is not None:
-            # self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
-            # self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
-        # self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
@@ -419,11 +413,33 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce

         return lora_processor

+    def scale_peft_lora_layers(self, scale: float = 1.0):
+        from peft.tuners.lora import LoraLayer
+
+        total_modules_to_scale = list(self.modules())
+
+        for module in total_modules_to_scale:
+            if isinstance(module, LoraLayer):
+                module.scale_layer(scale)
+
+    def unscale_peft_lora_layers(self, scale: float = 1.0):
+        from peft.tuners.lora import LoraLayer
+
+        total_modules_to_unscale = list(self.modules())
+
+        for module in total_modules_to_unscale:
+            if isinstance(module, LoraLayer):
+                module.unscale_layer(scale)
+
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+        # retrieve the scale of LoRA layers and optionally scale / unscale them
+        scale = cross_attention_kwargs.get("scale", 1.0)
+        self.scale_peft_lora_layers(scale)
+
         # The `Attention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-        return self.processor(
+        output = self.processor(
             self,
             hidden_states,
             encoder_hidden_states=encoder_hidden_states,
@@ -431,6 +447,10 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
             **cross_attention_kwargs,
         )

+        # unscale the LoRA layers in case they were scaled above
+        self.unscale_peft_lora_layers(scale)
+        return output
+
     def batch_to_head_dim(self, tensor):
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 7eb597481655..3278987703d9 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -23,7 +23,6 @@
 from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
-from .lora import LoRACompatibleConv


 class Upsample1D(nn.Module):
@@ -127,7 +126,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         if use_conv_transpose:
             conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
@@ -166,15 +165,9 @@ def forward(self, hidden_states, output_size=None, scale: float = 1.0):
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
-                if isinstance(self.conv, LoRACompatibleConv):
-                    hidden_states = self.conv(hidden_states, scale)
-                else:
-                    hidden_states = self.conv(hidden_states)
+                hidden_states = self.conv(hidden_states)
             else:
-                if isinstance(self.Conv2d_0, LoRACompatibleConv):
-                    hidden_states = self.Conv2d_0(hidden_states, scale)
-                else:
-                    hidden_states = self.Conv2d_0(hidden_states)
+                hidden_states = self.Conv2d_0(hidden_states)

         return hidden_states

@@ -203,7 +196,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
         self.name = name

         if use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -219,16 +212,13 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name=
     def forward(self, hidden_states, scale: float = 1.0):
         assert hidden_states.shape[1] == self.channels
+
         if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

         assert hidden_states.shape[1] == self.channels

-        if isinstance(self.conv, LoRACompatibleConv):
-            hidden_states = self.conv(hidden_states, scale)
-        else:
-            hidden_states = self.conv(hidden_states)
-
+        hidden_states = self.conv(hidden_states)
         return hidden_states

@@ -544,15 +534,12 @@ def __init__(
         else:
             self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

-        # self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
                 self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
                 self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                 self.time_emb_proj = None
@@ -570,7 +557,7 @@ def __init__(
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

         self.nonlinearity = get_activation(non_linearity)

@@ -596,9 +583,6 @@ def __init__(
         self.conv_shortcut = None

         if self.use_in_shortcut:
-            # self.conv_shortcut = LoRACompatibleConv(
-            #     in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
-            # )
             self.conv_shortcut = nn.Conv2d(
                 in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
             )
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index d695d182fa37..5d05f26357e1 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -731,6 +731,7 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+
     def forward(
         self,
         sample: torch.FloatTensor,
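
The key-conversion helpers in src/diffusers/utils/peft_utils.py are pure dictionary renames, so they can be exercised on a toy state dict without loading any model. The sketch below is illustrative and not part of the patches: the attention-module name and tensor shapes are made up, and it assumes a diffusers checkout with these patches applied so the helpers are importable from `diffusers.utils.peft_utils`.

    import torch

    from diffusers.utils.peft_utils import (
        convert_old_state_dict_to_peft,
        convert_peft_state_dict_to_diffusers,
    )

    # hypothetical text-encoder attention module name; the helpers only use the name
    attn = "text_model.encoder.layers.0.self_attn"
    attention_modules = [(attn, None)]

    # old serialization format: `<attn>.to_*_lora.up/down.weight`
    old_state_dict = {}
    for proj in ("to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"):
        old_state_dict[f"{attn}.{proj}.up.weight"] = torch.zeros(32, 4)
        old_state_dict[f"{attn}.{proj}.down.weight"] = torch.zeros(4, 32)

    # old -> PEFT naming: `<attn>.q_proj.lora_A/lora_B.weight`
    peft_sd = convert_old_state_dict_to_peft(attention_modules, old_state_dict)
    assert f"{attn}.q_proj.lora_B.weight" in peft_sd

    # PEFT naming with an adapter name inserted -> diffusers naming:
    # `<attn>.q_proj.lora_linear_layer.up/down.weight`
    named_sd = {
        k.replace(".lora_A.", ".lora_A.default.").replace(".lora_B.", ".lora_B.default."): v
        for k, v in peft_sd.items()
    }
    diffusers_sd = convert_peft_state_dict_to_diffusers(attention_modules, named_sd, adapter_name="default")
    assert f"{attn}.q_proj.lora_linear_layer.up.weight" in diffusers_sd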
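
For the UNet path, the last two patches stop calling `unet.load_attn_procs` and instead rename the keys with `convert_unet_state_dict_to_peft`, build a `LoraConfig` over the returned `target_modules`, inject the adapter with `inject_adapter_in_model`, and load the weights with `set_peft_model_state_dict`. A minimal free-standing sketch of that flow follows; the helper name, the fixed rank, and the explicit `adapter_name` handling are assumptions for illustration (the patch itself still carries a TODO for deriving the config from `network_alphas`), and it requires a peft release that exposes these entry points.

    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

    from diffusers.utils.peft_utils import convert_unet_state_dict_to_peft


    def load_unet_lora_with_peft(unet, unet_lora_state_dict, rank=4, adapter_name="default"):
        # 1. rename diffusers LoRA keys (".down"/".up", ".to_q_lora", ...) to peft's
        #    ".lora_A"/".lora_B" convention and collect the module names to target
        peft_state_dict, target_modules = convert_unet_state_dict_to_peft(unet_lora_state_dict)

        # 2. let peft wrap the targeted modules with LoRA layers; the rank here is a
        #    placeholder rather than a value read from the checkpoint
        lora_config = LoraConfig(r=rank, lora_alpha=rank, target_modules=sorted(set(target_modules)))
        inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)

        # 3. load the renamed weights into the freshly injected LoRA layers
        incompatible_keys = set_peft_model_state_dict(unet, peft_state_dict, adapter_name=adapter_name)
        return incompatible_keys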
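
The new `scale_peft_lora_layers` / `unscale_peft_lora_layers` methods walk `self.modules()` and call `scale_layer` / `unscale_layer` on every peft `LoraLayer`, which is how the `cross_attention_kwargs["scale"]` value reaches the LoRA weights now that the `LoRACompatible*` wrappers are gone. The same idea as a standalone helper, assuming a peft version that provides `LoraLayer.scale_layer`/`unscale_layer` (the methods the patch itself calls):

    import torch.nn as nn
    from peft.tuners.lora import LoraLayer


    def scale_lora_layers(model: nn.Module, scale: float = 1.0) -> None:
        # multiply every LoRA layer's contribution by `scale` before the forward pass
        for module in model.modules():
            if isinstance(module, LoraLayer):
                module.scale_layer(scale)


    def unscale_lora_layers(model: nn.Module, scale: float = 1.0) -> None:
        # undo the scaling afterwards so repeated forwards do not compound the factor
        for module in model.modules():
            if isinstance(module, LoraLayer):
                module.unscale_layer(scale)

Scaling before the processor call and unscaling right after keeps the adjustment local to a single forward pass; per the commit message, this path does not yet reproduce the previous outputs, so it should be read as work in progress.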