[LoRA] quality of life improvements in the loading semantics and docs #3180
@@ -115,7 +115,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model
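The "half the weights" comment refers to the LoRA scale applied at inference time. A minimal sketch of what that call could look like, assuming the usual `cross_attention_kwargs={"scale": ...}` mechanism and an illustrative prompt (neither is taken verbatim from the diff):

```py
# scale=0.5 blends the LoRA-finetuned attention weights 50/50 with the base model's
image = pipe(
    "A pokemon with blue eyes.",
    num_inference_steps=25,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
image.save("blue_pokemon.png")
```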
@@ -128,6 +128,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
 >>> image.save("blue_pokemon.png")
 ```

+<Tip>
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+</Tip>
+
 ## DreamBooth

 [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
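To round off the tip above, here is a sketch of how the snippet could continue after the `...`: the same repo id that carries the `base_model` tag also hosts the LoRA weights, so it can be passed straight to the loader (the prompt below is illustrative, not part of the diff):

```py
# the LoRA repo id doubles as the source of the attention-processor weights
pipe.unet.load_attn_procs(lora_model_id)
pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
```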
@@ -208,7 +227,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 </Tip>

 ```py
->>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.unet.load_attn_procs(lora_model_path)
 >>> pipe.to("cuda")
 # use half the weights from the LoRA finetuned model and half the weights from the base model
@@ -222,4 +241,15 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
 >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
 >>> image.save("bucket-dog.png")
 ```

+Note that the use of [`LoraLoaderMixin.load_lora_weights`] is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because
+[`LoraLoaderMixin.load_lora_weights`] can handle the following situations:
> **Member (Author)**, on lines +246 to +247: Since we decided not to deprecate things, I changed the wording here. @patrickvonplaten as an FYI.
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+```py
+pipe.load_lora_weights(lora_model_path)
+```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
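For that second case the call looks the same; `load_lora_weights` splits the state dict by the `unet` / `text_encoder` prefixes internally, as the `loaders.py` changes further down show. A sketch (the prompt is illustrative):

```py
# loads both the UNet and the text encoder LoRA parameters from one repo
pipe.load_lora_weights("sayakpaul/dreambooth")
image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25).images[0]
```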
@@ -1045,7 +1045,7 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_attn_procs(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)
> **Member (Author):** We should be using the more general method here.
>
> **Contributor:** Agree!
         # run inference
         if args.validation_prompt and args.num_validation_images > 0:
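Outside the training script, the freshly trained weights can be reloaded the same way for a quick smoke test. A minimal sketch, assuming the base checkpoint used for training and that the output directory contains the saved `pytorch_lora_weights.bin` (both are assumptions, not part of the diff):

```py
import torch
from diffusers import StableDiffusionPipeline

# assumed base checkpoint; use whatever --pretrained_model_name_or_path the training run used
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/output_dir")  # picks up UNet (and text encoder, if present) LoRA params
pipe.to("cuda")

image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25).images[0]
```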
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,8 @@
 logger = logging.get_logger(__name__)

+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"

 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -87,6 +90,9 @@ def map_from(module, state_dict, *args, **kwargs):


 class UNet2DConditionLoadersMixin:
+    text_encoder_name = TEXT_ENCODER_NAME
> **Contributor:** Think this is the best option we have right now! Just an FYI: I feel slightly uncomfortable with the UNet having to know about the text encoder name here, since we now very lightly entangle those concepts, but I also think this is the best design at the moment. Thanks for adding it!
>
> **Member (Author):** Do you want me to add the names to the utils constants?
+    unet_name = UNET_NAME

     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
         Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
@@ -225,6 +231,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

+        if is_lora:
+            is_new_lora_format = all(
+                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+            )
+            if is_new_lora_format:
+                # Strip the `"unet"` prefix.
+                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+                if is_text_encoder_present:
+                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                    warnings.warn(warn_message)
+                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
         lora_grouped_dict = defaultdict(dict)
         for key, value in state_dict.items():
             attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
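A toy illustration of what this new-format branch does to a state dict: only the UNet entries survive and their `unet.` prefix is stripped, while text encoder entries trigger the warning that points to `pipe.load_lora_weights()`. The key names below are made up for the sketch; real keys are full LoRA attention-processor module paths:

```py
unet_name, text_encoder_name = "unet", "text_encoder"
state_dict = {
    "unet.mid_block.attentions.0.processor.to_q_lora.down.weight": ...,       # placeholder for a tensor
    "text_encoder.text_model.encoder.layers.0.to_q_lora.down.weight": ...,    # placeholder for a tensor
}

# keep only UNet keys and drop their prefix, mirroring the diff above
unet_keys = [k for k in state_dict if k.startswith(unet_name)]
state_dict = {k.replace(f"{unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
# -> {"mid_block.attentions.0.processor.to_q_lora.down.weight": ...}
```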
@@ -672,8 +690,8 @@ class LoraLoaderMixin:
     </Tip>
     """
-    text_encoder_name = "text_encoder"
-    unet_name = "unet"
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME

     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         r"""
@@ -810,33 +828,33 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())

-        # Load the layers corresponding to UNet.
-        if all(key.startswith(self.unet_name) for key in keys):
> **Member (Author):** I had messed up with the filtering 😟 now, it should be fixed.
+        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to UNet.
+            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
             logger.info(f"Loading {self.unet_name}.")
-            unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
+            unet_lora_state_dict = {
+                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
+            }
             self.unet.load_attn_procs(unet_lora_state_dict)

-        # Load the layers corresponding to text encoder and make necessary adjustments.
-        elif all(key.startswith(self.text_encoder_name) for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
             logger.info(f"Loading {self.text_encoder_name}.")
             text_encoder_lora_state_dict = {
-                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
+                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
-            attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
-            self._modify_text_encoder(attn_procs_text_encoder)
+            if len(text_encoder_lora_state_dict) > 0:
> **Member (Author):** In case someone doesn't train the text encoder but the parameters are still serialized in the new format, this check is needed.
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                self._modify_text_encoder(attn_procs_text_encoder)

         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
         ):
             self.unet.load_attn_procs(state_dict)
-            deprecation_message = "You have saved the LoRA weights using the old format. This will be"
-            " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
-            " in a dictionary and then create a new dictionary like the following:"
-            " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
-            deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
+            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
+            warnings.warn(warn_message)

     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
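The conversion the warning suggests is easier to read spelled out; as printed, the expression `{f'unet'.{module_name}: ...}` isn't valid Python, but the intent is simply to prefix every old-format key with `unet.`. A sketch, assuming the old weights sit in the default `pytorch_lora_weights.bin`:

```py
import torch

old_state_dict = torch.load("pytorch_lora_weights.bin")  # old format: keys carry no "unet." prefix
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
torch.save(new_state_dict, "pytorch_lora_weights.bin")
```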
@@ -872,7 +890,9 @@ def _get_lora_layer_attribute(self, name: str) -> str:
         else:
             return "to_out_lora"

-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def _load_text_encoder_attn_procs(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
+    ):
         r"""
         Load pretrained attention processor layers for
         [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
> **Review comment** (on the `model_path` → `lora_model_path` rename above): To better distinguish between the base model path and the LoRA model path.