6 changes: 3 additions & 3 deletions docs/source/en/_toctree.yml
@@ -171,7 +171,7 @@
      - local: api/pipelines/semantic_stable_diffusion
        title: Semantic Guidance
      - local: api/pipelines/spectrogram_diffusion
        title: "Spectrogram Diffusion"
        title: Spectrogram Diffusion
  - sections:
      - local: api/pipelines/stable_diffusion/overview
        title: Overview
@@ -238,6 +238,8 @@
        title: DPM Discrete Scheduler
      - local: api/schedulers/dpm_discrete_ancestral
        title: DPM Discrete Scheduler with ancestral sampling
      - local: api/schedulers/dpm_sde
        title: DPMSolverSDEScheduler
      - local: api/schedulers/euler_ancestral
        title: Euler Ancestral Scheduler
      - local: api/schedulers/euler
@@ -266,8 +268,6 @@
        title: VP-SDE
      - local: api/schedulers/vq_diffusion
        title: VQDiffusionScheduler
      - local: api/schedulers/dpm_sde
        title: DPMSolverSDEScheduler
    title: Schedulers
  - sections:
      - local: api/experimental/rl
36 changes: 33 additions & 3 deletions docs/source/en/training/lora.mdx
@@ -115,7 +115,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
</Tip>

```py
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.unet.load_attn_procs(lora_model_path)
# ^ Reviewer (Member Author): to better distinguish between the base model path
# and the LoRA model path.
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model

@@ -128,6 +128,25 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
>>> image.save("blue_pokemon.png")
```

<Tip>

If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)),
you can do:

```py
import torch

from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```
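
From there, the remaining steps are the same as in the snippet above; a minimal sketch (the prompt is just illustrative):

```py
pipe.unet.load_attn_procs(lora_model_id)
pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
```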

</Tip>

## DreamBooth

[DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
@@ -208,7 +227,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
</Tip>

```py
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model

@@ -222,4 +241,15 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m

>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("bucket-dog.png")
```

Note that [`LoraLoaderMixin.load_lora_weights`] is preferred over [`UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters, because it can handle the following situations (see the combined example after the list):

> Reviewer note (Member Author, on lines +246 to +247): Since we decided not to deprecate things, I changed the wording here. @patrickvonplaten as an FYI.


* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:

```py
pipe.load_lora_weights(lora_model_path)
```

* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
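
In both cases, the call is the same; a minimal sketch using the second repository from the list:

```py
pipe.load_lora_weights("sayakpaul/dreambooth")
```

Under the hood, `load_lora_weights` routes the parameters by their `unet` / `text_encoder` key prefixes when they are present.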
29 changes: 28 additions & 1 deletion examples/dreambooth/README.md
@@ -355,7 +355,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject as training progressed.

Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested in knowing more about how we
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in knowing more about how we
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).

With the default hyperparameters from above, training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
@@ -387,6 +387,33 @@ Finally, we can run the model in inference.
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```

If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)),
you can do:

```py
import torch

from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```

**Note** that we will gradually be deprecating the use of [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs), since we now have a more general
method to load the LoRA parameters -- [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights). This is because
[`LoraLoaderMixin.load_lora_weights`] can handle the following situations (see the sketch after the list):

* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:

```py
pipe.load_lora_weights(lora_model_path)
```

* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
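
If you're unsure which format a given checkpoint uses, you can inspect the key prefixes yourself; a minimal sketch, assuming the repository stores its weights under the default `pytorch_lora_weights.bin` filename:

```py
import torch
from huggingface_hub import hf_hub_download

lora_path = hf_hub_download("sayakpaul/dreambooth", "pytorch_lora_weights.bin")
lora_state_dict = torch.load(lora_path)

# New-format checkpoints prefix every key with `unet` or `text_encoder`.
print(all(k.startswith(("unet", "text_encoder")) for k in lora_state_dict))
```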

## Training with Flax/JAX

For faster training on TPUs and GPUs, you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora.py
@@ -1045,7 +1045,7 @@ def main(args):
        pipeline = pipeline.to(accelerator.device)

        # load attention processors
        pipeline.load_attn_procs(args.output_dir)
        pipeline.load_lora_weights(args.output_dir)
        # ^ Reviewer (Member Author): we should be using the more general method here.
        # ^ Reviewer (Contributor): Agree!

        # run inference
        if args.validation_prompt and args.num_validation_images > 0:
8 changes: 6 additions & 2 deletions examples/test_examples.py
@@ -281,10 +281,14 @@ def test_dreambooth_lora_with_text_encoder(self):
        # save_pretrained smoke test
        self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

        # the names of the keys of the state dict should either start with `unet`
        # or `text_encoder`.
        # check that the `text_encoder` keys are present at all.
        lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
        keys = lora_state_dict.keys()
        is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
        self.assertTrue(is_text_encoder_present)

        # the names of the keys of the state dict should either start with `unet`
        # or `text_encoder`.
        is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
        self.assertTrue(is_correct_naming)

15 changes: 15 additions & 0 deletions examples/text_to_image/README.md
@@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```

If you are loading the LoRA parameters from the Hub and the Hub repository has
a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)),
you can do:

```py
import torch

from diffusers import StableDiffusionPipeline
from huggingface_hub.repocard import RepoCard

lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
card = RepoCard.load(lora_model_id)
base_model_id = card.data.to_dict()["base_model"]

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
...
```
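
From there, you can load the LoRA parameters on top of the base pipeline; a minimal sketch (`load_lora_weights` is the more general loading method introduced by this PR, and the prompt is just illustrative):

```py
pipe.load_lora_weights(lora_model_id)
pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=30, guidance_scale=7.5).images[0]
```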

## Training with Flax/JAX

For faster training on TPUs and GPUs, you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
54 changes: 37 additions & 17 deletions src/diffusers/loaders.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,8 @@

logger = logging.get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
@@ -87,6 +90,9 @@ def map_from(module, state_dict, *args, **kwargs):


class UNet2DConditionLoadersMixin:
    text_encoder_name = TEXT_ENCODER_NAME
    # ^ Reviewer (Contributor): I think this is the best option we have right now!
    # Just FYI: I feel slightly uncomfortable with the UNet having to know about the
    # text encoder name here, as we now lightly entangle those concepts, but I also
    # think this is the best design at the moment. Thanks for adding it!
    # ^ Reviewer (Member Author): Do you want me to add the names to the utils constants?
    unet_name = UNET_NAME

    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
@@ -225,6 +231,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

        if is_lora:
            is_new_lora_format = all(
                key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
            )
            if is_new_lora_format:
                # Warn if the checkpoint also carries text encoder LoRA params, since
                # `load_attn_procs` only consumes the UNet ones.
                is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
                if is_text_encoder_present:
                    warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
                    warnings.warn(warn_message)
                # Strip the `"unet"` prefix.
                unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
                state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}

            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
@@ -672,8 +690,8 @@ class LoraLoaderMixin:

    </Tip>
    """
    text_encoder_name = "text_encoder"
    unet_name = "unet"
    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
@@ -810,33 +828,33 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())

        # Load the layers corresponding to UNet.
        if all(key.startswith(self.unet_name) for key in keys):
        # ^ Reviewer (Member Author): I had messed up the filtering 😟 it should be fixed now.
        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
            unet_keys = [k for k in keys if k.startswith(self.unet_name)]
            logger.info(f"Loading {self.unet_name}.")
            unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)}
            unet_lora_state_dict = {
                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
            }
            self.unet.load_attn_procs(unet_lora_state_dict)

        # Load the layers corresponding to text encoder and make necessary adjustments.
        elif all(key.startswith(self.text_encoder_name) for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
            logger.info(f"Loading {self.text_encoder_name}.")
            text_encoder_lora_state_dict = {
                k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name)
                k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }
            attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict)
            self._modify_text_encoder(attn_procs_text_encoder)
            if len(text_encoder_lora_state_dict) > 0:
                # ^ Reviewer (Member Author): in case someone doesn't train the text encoder
                # but the parameters are still serialized in the new format, this check is needed.
                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
                self._modify_text_encoder(attn_procs_text_encoder)

        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
        # contain the module names of the `unet` as its keys WITHOUT any prefix.
        elif not all(
            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
        ):
            self.unet.load_attn_procs(state_dict)
            deprecation_message = "You have saved the LoRA weights using the old format. This will be"
            " deprecated soon. To convert the old LoRA weights to the new format, you can first load them"
            " in a dictionary and then create a new dictionary like the following:"
            " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
            deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False)
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)

    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
        r"""
@@ -872,7 +890,9 @@ def _get_lora_layer_attribute(self, name: str) -> str:
        else:
            return "to_out_lora"

    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
    def _load_text_encoder_attn_procs(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
    ):
        r"""
        Load pretrained attention processor layers for
        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).