From 9ebee45e3167b468c423d19c9c91d7202f136341 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 21 Apr 2023 09:18:00 +0530
Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=91=BD=20qol=20improvements=20for=20L?=
 =?UTF-8?q?oRA.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/source/en/training/lora.mdx | 37 +++++++++++++++++++++++++++++---
 examples/dreambooth/README.md    | 29 ++++++++++++++++++++++++-
 examples/text_to_image/README.md | 15 +++++++++++++
 src/diffusers/loaders.py         | 26 ++++++++++++++++++----
 4 files changed, 99 insertions(+), 8 deletions(-)

diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index ac2311df9f1e..474734904071 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -113,7 +113,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 
 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 
 # use half the weights from the LoRA finetuned model and half the weights from the base model
@@ -126,6 +126,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 >>> image.save("blue_pokemon.png")
 ```
 
+<Tip>
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+</Tip>
+
 ## DreamBooth
 
 [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
 
@@ -204,7 +223,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 
 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 
 # use half the weights from the LoRA finetuned model and half the weights from the base model
@@ -218,4 +237,16 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
 >>> image.save("bucket-dog.png")
-```
\ No newline at end of file
+```
+
+Note that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`] since we now have a more general
+method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`]. This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)).
So, you can just do:
+
+  ```py
+  pipe.load_lora_weights(lora_model_path)
+  ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
\ No newline at end of file
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 8447c7560720..e1eb8a06b0ff 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -355,7 +355,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
 The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
 You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
 
-Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in learning more about how we
 enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
 
 With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
@@ -387,6 +387,33 @@ Finally, we can run the model in inference.
 image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
 ```
 
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+**Note** that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) since we now have a more general
+method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights). This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+  ```py
+  pipe.load_lora_weights(lora_model_path)
+  ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). A consolidated end-to-end sketch follows right after this file's diff.
+
 ## Training with Flax/JAX
 
 For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
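Taken together, the README additions above amount to the following end-to-end flow. This is a minimal sketch rather than part of the patch: it assumes the LoRA repository carries a `base_model` tag and ships weights in the new serialization format, and the output file name is illustrative.

```py
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import StableDiffusionPipeline

lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"

# Resolve the base checkpoint from the repo card's `base_model` tag.
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.to("cuda")

# `load_lora_weights()` covers both UNet-only checkpoints and checkpoints that
# also carry text encoder LoRA parameters.
pipe.load_lora_weights(lora_model_id)

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```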
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 406a64b3759f..160e73fa02bb 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] image.save("pokemon.png") ``` +If you are loading the LoRA parameters from the Hub and if the Hub repository has +a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then +you can do: + +```py +from huggingface_hub.repocard import RepoCard + +lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +... +``` + ## Training with Flax/JAX For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b4c443fd303b..c1ee9afecb2c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -45,6 +45,8 @@ logger = logging.get_logger(__name__) +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" @@ -87,6 +89,9 @@ def map_from(module, state_dict, *args, **kwargs): class UNet2DConditionLoadersMixin: + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be @@ -225,6 +230,19 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: + is_lora_legacy = all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ) + if not is_lora_legacy: + deprecation_message = ( + "Using `pipe.unet.load_attn_procs()` is deprecated. Please change to: `pipe.load_lora_weights()`." + ) + deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) + else: + raise ValueError( + "You are using the new LoRA serialization format introduced in https://github.com/huggingface/diffusers/pull/2918. Please use `pipe.load_lora_weights(...)`" + ) + lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) @@ -672,8 +690,8 @@ class LoraLoaderMixin: """ - text_encoder_name = "text_encoder" - unet_name = "unet" + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" @@ -823,7 +841,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder_lora_state_dict = { k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) } - attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict) + attn_procs_text_encoder = self._load_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) # Otherwise, we're dealing with the old format. 
This means the `state_dict` should only @@ -872,7 +890,7 @@ def _get_lora_layer_attribute(self, name: str) -> str: else: return "to_out_lora" - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def _load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers for [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). From 4d8b1ebff5ca8e60a6064c6a2715600ddee55c69 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Apr 2023 09:30:03 +0530 Subject: [PATCH 2/7] better function name? --- src/diffusers/loaders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c1ee9afecb2c..8eb2c620ccdc 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -841,7 +841,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder_lora_state_dict = { k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) } - attn_procs_text_encoder = self._load_attn_procs(text_encoder_lora_state_dict) + attn_procs_text_encoder = self.__load_text_encoder_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) # Otherwise, we're dealing with the old format. This means the `state_dict` should only @@ -890,7 +890,9 @@ def _get_lora_layer_attribute(self, name: str) -> str: else: return "to_out_lora" - def _load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def __load_text_encoder_attn_procs( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs + ): r""" Load pretrained attention processor layers for [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). From 0675e42ce7acaaeccadd23776f67b37807618f74 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Apr 2023 10:31:03 +0530 Subject: [PATCH 3/7] fix: LoRA weight loading with the new format. 
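In outline, the fix below routes a single new-format `state_dict` through both loaders by splitting it on the `unet` / `text_encoder` key prefixes before delegating. A standalone sketch of that splitting step, where the keys and tensor shapes are toy placeholders:

```py
import torch

# Toy new-format LoRA state dict; real keys are full module paths.
state_dict = {
    "unet.mid_block.attn1.to_q_lora.down.weight": torch.zeros(4, 320),
    "text_encoder.text_model.encoder.layers.0.to_k_lora.up.weight": torch.zeros(768, 4),
}

# Split by prefix and strip it, mirroring the logic in `load_lora_weights()`.
unet_sd = {k[len("unet."):]: v for k, v in state_dict.items() if k.startswith("unet.")}
text_encoder_sd = {
    k[len("text_encoder."):]: v for k, v in state_dict.items() if k.startswith("text_encoder.")
}

assert "mid_block.attn1.to_q_lora.down.weight" in unet_sd
assert "text_model.encoder.layers.0.to_k_lora.up.weight" in text_encoder_sd
```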
--- examples/dreambooth/train_dreambooth_lora.py | 2 +- src/diffusers/loaders.py | 33 +++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 1b75402c3550..1772208dc07b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1045,7 +1045,7 @@ def main(args): pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.load_attn_procs(args.output_dir) + pipeline.load_lora_weights(args.output_dir) # run inference if args.validation_prompt and args.num_validation_images > 0: diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 8eb2c620ccdc..6c7b8ccdf7ac 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -230,18 +230,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: - is_lora_legacy = all( + is_new_lora_format = all( key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ) - if not is_lora_legacy: + if is_new_lora_format: + raise ValueError( + "You are using the new LoRA serialization format introduced in https://github.com/huggingface/diffusers/pull/2918. Please use `pipe.load_lora_weights(...)`" + ) + else: deprecation_message = ( "Using `pipe.unet.load_attn_procs()` is deprecated. Please change to: `pipe.load_lora_weights()`." ) deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) - else: - raise ValueError( - "You are using the new LoRA serialization format introduced in https://github.com/huggingface/diffusers/pull/2918. Please use `pipe.load_lora_weights(...)`" - ) lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): @@ -828,21 +828,24 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - - # Load the layers corresponding to UNet. - if all(key.startswith(self.unet_name) for key in keys): + if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys): + # Load the layers corresponding to UNet. + unet_keys = [k for k in keys if k.startswith(self.unet_name)] logger.info(f"Loading {self.unet_name}.") - unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)} + unet_lora_state_dict = { + k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys + } self.unet.load_attn_procs(unet_lora_state_dict) - # Load the layers corresponding to text encoder and make necessary adjustments. - elif all(key.startswith(self.text_encoder_name) for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. 
+            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
             logger.info(f"Loading {self.text_encoder_name}.")
             text_encoder_lora_state_dict = {
-                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
+                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
-            attn_procs_text_encoder = self.__load_text_encoder_attn_procs(text_encoder_lora_state_dict)
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if len(text_encoder_lora_state_dict) > 0:
+                attn_procs_text_encoder = self.__load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                self._modify_text_encoder(attn_procs_text_encoder)
 
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.

From ad837b0891775ef4a649e4dd710f5f81e21a213a Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 24 Apr 2023 10:23:22 +0530
Subject: [PATCH 4/7] address Patrick's comments.

---
 examples/test_examples.py |  8 ++++++--
 src/diffusers/loaders.py  | 23 ++++++++++-------------
 2 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/examples/test_examples.py b/examples/test_examples.py
index d4a5ef5046f0..648c2cb8a1b7 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -281,10 +281,14 @@ def test_dreambooth_lora_with_text_encoder(self):
         # save_pretrained smoke test
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
 
-        # the names of the keys of the state dict should either start with `unet`
-        # or `text_encoder`.
+        # check that the `text_encoder` keys are present at all.
         lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
         keys = lora_state_dict.keys()
+        is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+        self.assertTrue(is_text_encoder_present)
+
+        # the names of the keys of the state dict should either start with `unet`
+        # or `text_encoder`.
         is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
         self.assertTrue(is_correct_naming)
 
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 6c7b8ccdf7ac..c406975ce4bf 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -234,14 +235,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
             )
             if is_new_lora_format:
-                raise ValueError(
-                    "You are using the new LoRA serialization format introduced in https://github.com/huggingface/diffusers/pull/2918. Please use `pipe.load_lora_weights(...)`"
-                )
-            else:
-                deprecation_message = (
-                    "Using `pipe.unet.load_attn_procs()` is deprecated. Please change to: `pipe.load_lora_weights()`."
-                )
-                deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+                # Strip the `"unet"` prefix.
+                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+                if is_text_encoder_present:
+                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. 
To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                    warnings.warn(warn_message)
+                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
 
             lora_grouped_dict = defaultdict(dict)
             for key, value in state_dict.items():
@@ -853,11 +853,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
             ):
                 self.unet.load_attn_procs(state_dict)
-                deprecation_message = "You have saved the LoRA weights using the old format. This will be"
-                " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
-                " in a dictionary and then create a new dictionary like the following:"
-                " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
-                deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                warnings.warn(warn_message)
 
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""

From 3db4f694407633625e44236e4fa218e072cb9b74 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 28 Apr 2023 10:41:05 +0530
Subject: [PATCH 5/7] Apply suggestions from code review

Co-authored-by: Patrick von Platen
---
 src/diffusers/loaders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 37876ee022a7..d80834ad0a4b 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -890,7 +890,7 @@ def _get_lora_layer_attribute(self, name: str) -> str:
         else:
             return "to_out_lora"
 
-    def __load_text_encoder_attn_procs(
+    def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
     ):
         r"""
         Load pretrained attention processor layers for
         [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

From fa03c3ddf1eecff22db0de294629f269893e6134 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 28 Apr 2023 10:44:22 +0530
Subject: [PATCH 6/7] change wording around encouraging the use of load_lora_weights().
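For background on why the docs now steer readers toward `load_lora_weights()`: which loader applies can be read off a checkpoint's key prefixes, mirroring the `is_new_lora_format` check above. A small sketch, with the weight file name as an assumed placeholder:

```py
import torch

state_dict = torch.load("pytorch_lora_weights.bin", map_location="cpu")

# New-format checkpoints prefix every key with "unet." or "text_encoder.";
# legacy checkpoints store bare UNet module names without any prefix.
is_new_format = all(
    k.startswith("unet") or k.startswith("text_encoder") for k in state_dict
)
print("new serialization format" if is_new_format else "legacy format")
```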
---
 docs/source/en/_toctree.yml      | 6 +++---
 docs/source/en/training/lora.mdx | 3 +--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 35c5fd78a1f6..26d3dbcf4e83 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -171,7 +171,7 @@
       - local: api/pipelines/semantic_stable_diffusion
         title: Semantic Guidance
       - local: api/pipelines/spectrogram_diffusion
-        title: "Spectrogram Diffusion"
+        title: Spectrogram Diffusion
       - sections:
         - local: api/pipelines/stable_diffusion/overview
           title: Overview
@@ -238,6 +238,8 @@
         title: DPM Discrete Scheduler
       - local: api/schedulers/dpm_discrete_ancestral
         title: DPM Discrete Scheduler with ancestral sampling
+      - local: api/schedulers/dpm_sde
+        title: DPMSolverSDEScheduler
       - local: api/schedulers/euler_ancestral
         title: Euler Ancestral Scheduler
       - local: api/schedulers/euler
@@ -266,8 +268,6 @@
         title: VP-SDE
       - local: api/schedulers/vq_diffusion
         title: VQDiffusionScheduler
-      - local: api/schedulers/dpm_sde
-        title: DPMSolverSDEScheduler
     title: Schedulers
 - sections:
   - local: api/experimental/rl
diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index 7264b36deed1..3c7cc7ebfeec 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -243,8 +243,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 >>> image.save("bucket-dog.png")
 ```
 
-Note that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`] since we now have a more general
-method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`]. This is because
+Note that the use of [`LoraLoaderMixin.load_lora_weights`] is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
 [`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
 
 * LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:

From 14f1aa994158cf761f1c0759bbdc8a2fc18f1c60 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 28 Apr 2023 10:48:53 +0530
Subject: [PATCH 7/7] fix: function name.

---
 src/diffusers/loaders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index d80834ad0a4b..b4b0f4bb3bd6 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -844,7 +844,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self.__load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
                 self._modify_text_encoder(attn_procs_text_encoder)
 
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
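Finally, the conversion that the old-format warning added in PATCH 4/7 describes, spelled out as a runnable sketch; the input and output file names are illustrative:

```py
import torch

old_state_dict = torch.load("pytorch_lora_weights.bin", map_location="cpu")

# Prefix every module name with "unet." to obtain the new serialization format.
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}

torch.save(new_state_dict, "pytorch_lora_weights_new.bin")
```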