[LoRA] allow fine-tuning of the text encoder with LoRA (using peft) #2719

@sayakpaul

Description

We have had many requests (rightfully so) to allow fine-tuning the text encoder with LoRA as well (see #2683, for example). This is quite useful for improving the quality of the generated samples. This issue thread aims to discuss a solution candidate that combines our LoRA attention processors with peft.

Proposed solution candidate

We stick to using AttnProcsLayers and LoRAAttnProcessor for handling the UNet:

from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import UNet2DConditionModel

model_id = "runwayml/stable-diffusion-v1-5"
output_dir = "demo-peft"

unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision=None
)

# Create a LoRA attention processor for every attention module in the UNet.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # `attn1` is self-attention; only `attn2` (cross-attention) attends to the text embeddings.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # Infer the hidden size from the block the processor belongs to.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
# Saves only the LoRA weights (pytorch_lora_weights.bin).
unet.save_attn_procs(output_dir)
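
For context on why AttnProcsLayers is mentioned above: in the training loop it is what gathers the LoRA parameters so they can be handed to the optimizer. A minimal sketch, with an AdamW learning rate chosen purely for illustration:

import torch

# AttnProcsLayers wraps the dict of LoRA attention processors in a single nn.Module,
# so that only the LoRA parameters end up in the optimizer.
lora_layers = AttnProcsLayers(unet.attn_processors)
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)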

For the text encoder, we do:

from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    model_id, subfolder="text_encoder", revision=None
)

config = LoraConfig(
    r=4,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0,
    bias="none",
)
lora_text_encoder = get_peft_model(text_encoder, config)
lora_text_encoder.save_pretrained(output_dir)
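
As an optional sanity check, peft's print_trainable_parameters helper can confirm that only the LoRA matrices of the targeted projections are trainable:

# Prints the number of trainable parameters vs. the total parameter count.
lora_text_encoder.print_trainable_parameters()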

Now, if we do:

ls -lh {output_dir}

It should give us:

total 3.8M
-rw-r--r-- 1 root root  363 Mar 17 05:28 adapter_config.json
-rw-r--r-- 1 root root 593K Mar 17 05:28 adapter_model.bin
-rw-r--r-- 1 root root 3.2M Mar 17 05:28 pytorch_lora_weights.bin

We can push all these files nicely to the Hub thanks to huggingface_hub. The final repository on the Hub would look like this: https://huggingface.co/sayakpaul/fine-tuned-dreambooth/tree/main.
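
A minimal sketch of that push, using create_repo and upload_folder from huggingface_hub (the repo id is just the example one linked above):

from huggingface_hub import create_repo, upload_folder

repo_name = "sayakpaul/fine-tuned-dreambooth"  # example repo id from above
create_repo(repo_name, exist_ok=True)
# Uploads adapter_config.json, adapter_model.bin, and pytorch_lora_weights.bin.
upload_folder(repo_id=repo_name, folder_path=output_dir)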

We can then do the loading part like so:

from peft import PeftConfig, PeftModel
from diffusers import DiffusionPipeline

# `repo_name` is the Hub repo created above; `model_id` can be determined programmatically from it (see below).
pipe = DiffusionPipeline.from_pretrained(model_id)

config = PeftConfig.from_pretrained(repo_name)
text_encoder = pipe.text_encoder
pipe.text_encoder = PeftModel.from_pretrained(text_encoder, repo_name)
pipe.unet.load_attn_procs(repo_name)
image = pipe("A picture of a dog in a bucket", num_inference_steps=10).images[0]
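
Regarding the comment about determining `model_id` programmatically: peft records the base model in adapter_config.json, so, assuming the text encoder was loaded from the pipeline repo as in this PoC, something like the following should work (reusing the PeftConfig import above):

config = PeftConfig.from_pretrained(repo_name)
model_id = config.base_model_name_or_path  # the model the adapter was created from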

This is the Colab Notebook where the above PoC can be found in full.

One could argue that we can also use peft for the UNet part (it's possible). IIUC, in that case, we might have to maintain two separate repositories -- one for the UNet and another one for the text encoder. However, @pacman100 please correct me if I am wrong.
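
For reference, a rough sketch of what the peft-only route might look like for the UNet (reusing LoraConfig / get_peft_model from above). The target module names are an assumption based on the attention projection names in diffusers, and whether get_peft_model handles UNet2DConditionModel out of the box is exactly the part I would like confirmed:

# Hypothetical: apply peft directly to the UNet's attention projections.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    target_modules=["to_q", "to_v"],  # assumed projection names
    lora_dropout=0.0,
    bias="none",
)
lora_unet = get_peft_model(unet, unet_lora_config)
# Saving both adapters into the same folder would clash (adapter_config.json / adapter_model.bin),
# which is where the dual-repo concern above comes from.
lora_unet.save_pretrained(output_dir)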

In any case, if the above (dual repo creation) issue is sorted out on the peft side, we can potentially deprecate our LoRA-related utilities.

If the above design looks good to you, I will drop a PR to modify the LoRA DreamBooth example.

Cc: @patrickvonplaten @williamberman
