
Add implementation of LyCORIS LoHa (FedPara-like adapter) for SD&SDXL models #956

Merged
BenjaminBossan merged 24 commits into huggingface:main from kovalexal:loha on Oct 2, 2023

Conversation

@kovalexal
Contributor

@kovalexal kovalexal commented Sep 22, 2023

This PR focuses on increasing the compatibility of SD & SDXL adapters in peft with other open-source tools like LyCORIS. Feel free to learn more about LyCORIS adapters from resources like this.

This specific PR is aimed at adding proper compatibility of peft with LoHa adapters. The original paper is called FedPara; LoHa is essentially FedPara under the hood, just without the federated-learning setting. LoHa allows training adapters with higher output quality and more detail compared to basic LoRA. As far as I know, it is probably the second most popular adapter type for SD & SDXL on civitai.com, so from my perspective the ability to use it within the Hugging Face 🤗 ecosystem will be beneficial for peft (and for a future integration with diffusers).
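
For context, LoHa parameterizes the weight delta as the element-wise (Hadamard) product of two low-rank matrices, so the effective rank can reach r² while the parameter count stays comparable to a rank-r LoRA. Below is a minimal, illustrative sketch of that parameterization for a Linear layer - not the actual PEFT implementation; class and attribute names are made up for the example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoHaLinearSketch(nn.Module):
    """Illustrative LoHa / FedPara-style adapter around a frozen nn.Linear."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 8.0):
        super().__init__()
        self.base = base
        out_f, in_f = base.weight.shape
        # Two independent low-rank factor pairs; their rank-r products are
        # combined element-wise, so the delta's rank can be as high as r**2.
        self.w1_a = nn.Parameter(torch.randn(out_f, r) * 0.01)
        self.w1_b = nn.Parameter(torch.randn(r, in_f) * 0.01)
        self.w2_a = nn.Parameter(torch.randn(out_f, r) * 0.01)
        # One factor starts at zero so the adapter is initially a no-op,
        # similar in spirit to LoRA's zero init of one matrix.
        self.w2_b = nn.Parameter(torch.zeros(r, in_f))
        self.scale = alpha / r

    def delta_weight(self) -> torch.Tensor:
        # delta_W = (w1_a @ w1_b) * (w2_a @ w2_b), scaled like LoRA's alpha / r
        return (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + F.linear(x, self.delta_weight())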

Currently I've implemented and tested all the core functionality required for it to work properly with SD&SDXL models.

However, there are some pieces missing:

  • Unit tests
  • Documentation & examples
  • Conversion script for SD & SDXL for kohya_ss / LyCORIS-trained LoHas
  • Sample training script for SD / SDXL

Also, there are some open questions:

  • Conv1d implementation for the LoHa layer?
  • Adding compatibility of LoHa with LLMs?
  • Adding the ability to merge multiple LoHas together, like LoRAs?
  • Adding the remaining LyCORIS adapters like LoKr / DyLoRA (maybe in a separate PR)?
  • Adding the ability to merge LoRAs / LoHas / LoKrs / DyLoRAs together (maybe in a separate PR once all the adapters are implemented)?

@pacman100 @BenjaminBossan @younesbelkada may I kindly ask you for your comments while I am still working on it?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@BenjaminBossan
Member

Thanks for providing this feature! Could you please run make style so that CI can do its thing?

Contributor

@younesbelkada younesbelkada left a comment

Wow, great work @kovalexal!
Thanks a lot for this great contribution! Let us know when this PR is ready for a first review.

@kovalexal
Contributor Author

Could you please run make style so that CI can do its thing?

@BenjaminBossan hi! Sorry, I didn't know that hf-doc-builder also refactors the docs a bit. It should be fixed now.

@BenjaminBossan
Member

This looks very promising, thanks for working on this. In general, we can add a new fine-tuning method even without tests or docs, but it would be good to have at least one working example to verify that the method works. Do you have something we can use to test the new feature?

@kovalexal
Contributor Author

@BenjaminBossan I've tested inference with a civitai LoHa model and Stable Diffusion (and got results identical to the automatic1111 output).

For a lightweight training example, I can modify the peft dreambooth example script next week to incorporate LoHa into it.

@kovalexal
Contributor Author

kovalexal commented Sep 25, 2023

@BenjaminBossan I've added a simple training example for LoHa based on the dreambooth script.

I've tested it with my personal photos and used the following settings for training:

python train_dreambooth_loha.py \
--pretrained_model_name_or_path=... \
--instance_data_dir=... \
--instance_prompt="AlexanderKovalchuk" \
--output_dir=./output_loha \
--seed=42 \
--resolution=512 \
--train_text_encoder \
--use_loha \
--r=32 --alpha=32 \
--loha_text_encoder_r=32 --loha_text_encoder_alpha=32 \
--train_batch_size=2 \
--max_train_steps=3000 \
--learning_rate=1e-4 \
--num_validation_images=4 \
--validation_steps=50 \
--validation_prompt="AlexanderKovalchuk" \
--logging_dir=./output_logs \
--report_to=tensorboard \
--lr_warmup_steps=300 \
--lr_scheduler=constant_with_warmup \
--use_effective_conv2d

After some time, TensorBoard shows the following results:

[TensorBoard screenshot, 2023-09-25]

Should anything else be added to demonstrate that the method can train?

Member

@BenjaminBossan BenjaminBossan left a comment

Hey, thank you so much for adding this feature to PEFT. You did a really good job here, ensuring that the tuner is well integrated and even making some improvements over existing code where you saw fit. Well done.

I did an initial review purely of the code; I have checked neither the examples nor whether the code corresponds to what the paper states. I assume that this is already well tested, especially since it's based on the kohya code.

I left a couple of comments, please check them out. I also created a PR on top of your PR to add the first couple of tests for LoHa.

One thing that confused me a bit is the naming between hada vs loha vs lycoris. I wonder if we can unify this?

@kovalexal
Contributor Author

One thing that confused me a bit is the naming between hada vs loha vs lycoris. I wonder if we can unify this?

@BenjaminBossan I was also a bit confused at the beginning. Basically, LyCORIS is a family of adapters that are used to modify Stable Diffusion checkpoints:

  • LoCon / LoRA-C3Lier - these adapters are already supported by PEFT (they are essentially just LoRA)
  • LoHa (aka FedPara) - added in the current PR
  • LoKr - similar to LoHa, but uses a Kronecker product instead of LoHa's Hadamard product (see the sketch below). I was also going to add this adapter in the current PR if appropriate, or I can implement it in a separate PR
  • DyLoRA - a LoRA with a simple but effective addition: weight dropout, which makes it possible to change the LoRA rank after training

So only once all of these adapters are supported can we say that LyCORIS is fully supported in PEFT.
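
To make the LoHa vs. LoKr distinction concrete, here is a tiny illustrative snippet (assumed shapes, not PEFT code) contrasting how the two build their weight deltas:

import torch

out_f, in_f, r = 8, 6, 2

# LoHa / FedPara: element-wise (Hadamard) product of two rank-r matrices.
w1 = torch.randn(out_f, r) @ torch.randn(r, in_f)
w2 = torch.randn(out_f, r) @ torch.randn(r, in_f)
delta_loha = w1 * w2  # shape (out_f, in_f)

# LoKr: Kronecker product of two small blocks whose shapes multiply up to
# the full weight shape (here a 2x2 block and a 4x3 block give 8x6).
delta_lokr = torch.kron(torch.randn(2, 2), torch.randn(4, 3))

assert delta_loha.shape == delta_lokr.shape == (out_f, in_f)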

Also, AFAIK there are some successful applications of IA3 to Stable Diffusion - but I am not sure whether it's currently supported by PEFT.

That is why I've decided not to come up with adapter key names similar to LoRA's and to just stick with the names from existing checkpoints. If you would like me to, I can modify the key names, but that will probably make the conversion logic more complex.

@BenjaminBossan
Member

Thanks for explaining the terms better.

Also, AFAIK there are some successful applications of IA3 to Stable Diffusion - but I am not sure whether it's currently supported by PEFT.

I'm not sure if anyone has successfully tried it, but I see no obstacle to it working in principle. One caveat is that IA³ currently only has a Linear layer in PEFT, no Conv2d, though I don't see a technical reason why that couldn't be added.

That is why I've decided not to come up with adapter key names similar to LoRA's and to just stick with the names from existing checkpoints.

Compatibility with existing checkpoints is a perfectly good reason, so in this case I'm fine with it. Otherwise, I would have suggested prefixing the params with loha instead of hada, as I think that name would be more fitting (Hadamard is only part of the whole concept), but this is not worth having to add extra conversion logic.

@kovalexal
Contributor Author

@BenjaminBossan the authors of LyCORIS have recently published a paper with more details and experiments on these adapters and Stable Diffusion.

@kovalexal kovalexal marked this pull request as ready for review September 27, 2023 11:23
@kovalexal
Contributor Author

@BenjaminBossan I suppose we can move forward and do a full review of this PR.

If it's not a problem, I'll add the missing documentation in one of the next PRs (probably after the next adapter is finished).

Member

@BenjaminBossan BenjaminBossan left a comment

Again, thanks so much for adding LoHa to PEFT. You did a great job here and the code is super clean, thanks for that.

Unfortunately, GitHub shows unrelated changes in the diff. I assume that you didn't make any changes on top of those, so I will not review them.

I have some comments; mostly they are not blockers, but I think there are a few small mistakes too, please take a look.

Apart from that, I have a larger design question:

From my understanding, LoHa and LoKr are almost identical. Ideally, it would be possible to extend this code in the future to switch to LoKr using a single parameter, something like config = LoHaConfig(..., method="kronecker") (sketched hypothetically below). But from what I can tell, this wouldn't be easy to add, and we would essentially need to copy almost all of the code added here and create a completely separate LoKr tuner. I wonder if it's possible to make some changes to facilitate the addition of LoKr (and possibly other similar methods) in the future. Do you have an idea of whether this could be achieved?
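
Purely as a hypothetical illustration of that question - neither this class nor the method argument exists in PEFT - the idea would be something along these lines:

from dataclasses import dataclass, field

# Hypothetical sketch only: one config whose `method` flag selects the
# factorization, instead of duplicating the whole tuner for LoKr.
@dataclass
class LycorisConfigSketch:
    r: int = 8
    alpha: int = 8
    target_modules: list = field(default_factory=lambda: ["to_q", "to_k", "to_v"])
    method: str = "hadamard"  # "hadamard" -> LoHa, "kronecker" -> LoKr

loha_cfg = LycorisConfigSketch(r=16, alpha=16)
lokr_cfg = LycorisConfigSketch(r=16, alpha=16, method="kronecker")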

Member

@BenjaminBossan BenjaminBossan left a comment

Great work, this looks good from my point of view. I'll ask one of the others for an additional review, as this is a pretty big PR.

@kovalexal
Contributor Author

kovalexal commented Sep 28, 2023

@BenjaminBossan thank you very much for your valuable comments and for your time!

I wonder if it's possible to make some changes to facilitate the addition of LoKr.

On the one hand, having them as separate classes may be better for end users - there will be less confusion when somebody decides to try out these adapters for downstream tasks.

On the other hand, it will require a lot of code duplication, though the code for each adapter will of course be much simpler. Also, in that case, we may need to rename the LoHa adapter to something like LyCORIS to prevent future misunderstandings.

Do you have an idea if this could be achieved?

I've also been thinking about this a lot. I really enjoy using PEFT, but there is one general feature missing: the ability to load and work with several adapter types at the same time (at least at inference time).

In my experience this is a pretty common thing in Stable Diffusion - you load and mix several adapter types, each with some scale (LoRAs, LoHas, and LoKrs), to get the end result, and it is pretty useful to mix different styles / characters / concepts to create something unique. I am not sure if it is common for LLMs, but it would be great to be able to do it at least for SD/SDXL. Of course, this also means being able to easily switch between different mixtures (just as PEFT currently allows easy switching between different adapters of the same type).

As far as I know, webui achieves this by summing up the weight additions of each of the requested adapters (see the rough sketch below). Maybe we could incorporate something like a DeltaModel with a simple config (we only need target_modules for it); there exist some adapters for SD that have diff keys, which are probably full diffs that need to be added to the base model. Also, we could reuse the existing add_weighted_adapter from LoRA with SVD to transform a DeltaModel into a LoraModel. But to do that, we definitely need to be able to load and switch between different types of adapters in the first place.
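
A rough sketch of that webui-style merging (not an existing PEFT API; the function and variable names here are made up for illustration):

import torch

def merge_weight_deltas(base_weight: torch.Tensor,
                        deltas: list[torch.Tensor],
                        scales: list[float]) -> torch.Tensor:
    # Sum the scaled weight additions of every requested adapter
    # (LoRA, LoHa, LoKr, ...) on top of the frozen base weight.
    merged = base_weight.clone()
    for delta, scale in zip(deltas, scales):
        merged = merged + scale * delta
    return merged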

What do you think about it? In case you also find it useful, I would be glad to help your team implement this.

@BenjaminBossan
Member

On the one hand, having them as separate classes may be better for end users - there will be less confusion when somebody decides to try out these adapters for downstream tasks.

On the other hand, it will require a lot of code duplication, though the code for each adapter will of course be much simpler. Also, in that case, we may need to rename the LoHa adapter to something like LyCORIS to prevent future misunderstandings.

The large amount of code duplication is a burden for maintaining the library. A few times already, an issue was fixed in one place but forgotten somewhere else. I think we should strive to reduce duplication in the future. Perhaps it would be possible to expose separate LoHaConfig and LoKrConfig classes to the user to avoid confusion, while still using 99% of the same code underneath.

the ability to load and work with several adapter types at the same time (at least at inference time).

Yes, for sure this would be a nice feature, and there is no technical limitation that would prevent it. At the moment, this only really works with PEFT when merging the adapters, which can be undesirable.

On top of the technical question of how to combine different tuner types, we would also like each setting to be representable as a single config file, which could prove difficult. Personally, I would be fine with not having this possibility, but it would make sharing of models more difficult (basically a custom script would be required to set up the model, which has security implications).

Some combination of tuners applied to the same module could prove problematic if they make incompatible changes, but I think it would be up to the user to avoid such combinations.

Anyway, I think this is a topic for a different day and does not directly alter this PR.

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your huge work @kovalexal - I made a first pass. Before we merge the PR, can you try to merge / rebase your branch with main and see if this fixes the diff issue on the PR? 🙏

@kovalexal
Contributor Author

Hmm, seems strange that check_code_quality fails...
I ran make style & make quality several times, but it does not change anything.

@younesbelkada
Contributor

@kovalexal
can you try:

pip install -U ".[quality]"

@BenjaminBossan
Member

Strange, maybe the ruff version differs?

I think the issue is the order of these imports in test_custom_models.py:

from peft import AdaLoraConfig, IA3Config, LoraConfig, LoHaConfig, PeftModel, get_peft_model

LoHa should come before LoRA.
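
Presumably the sorted line would then read:

from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoraConfig, PeftModel, get_peft_model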

@kovalexal
Contributor Author

@BenjaminBossan thank you for the clarification, I fixed the wrong import order.

Contributor

@younesbelkada younesbelkada left a comment

Looking great on my side, thanks for your huge work!
Could you link the diffusion model example in the README, for example at the end of this section? https://github.com/huggingface/peft#parameter-efficient-tuning-of-diffusion-models

@kovalexal
Contributor Author

@younesbelkada done! 😊

@younesbelkada
Contributor

Thanks! 🎉

@kovalexal
Contributor Author

@younesbelkada maybe we should also mention the civitai.com adapter conversion script? It could probably be done after I finish the next adapter and also work on the SDXL conversion script.

@younesbelkada
Contributor

@kovalexal yes, good idea! No problem, we can do it after this PR gets merged 🙏

@kovalexal
Contributor Author

@BenjaminBossan sorry to disturb you, I am not sure - are there any open questions left?

Could you please merge this PR, as I don't have access to perform the merge myself?

@BenjaminBossan
Member

We wanted to give @pacman100 a chance to take a final look before merging.

@BenjaminBossan BenjaminBossan merged commit 7a5f17f into huggingface:main Oct 2, 2023
@BenjaminBossan
Member

Thanks so much @kovalexal, great addition to PEFT.
