Description
We have had many requests (rightfully so) to allow fine-tuning the text encoder with LoRA (such as #2683). This is quite useful for improving the quality of the generated samples. This issue thread aims to discuss a solution candidate that combines our LoRA attention processors with peft.
Proposed solution candidate
We stick to using AttnProcsLayers and LoRAAttnProcessor for handling the UNet:
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers import UNet2DConditionModel
model_id = "runwayml/stable-diffusion-v1-5"
output_dir = "demo-peft"
unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision=None
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # `attn1` processors do self-attention, so they have no cross-attention dimension.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    # Infer the hidden size of the block this processor belongs to.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
unet.save_attn_procs(output_dir)

For the text encoder, we do:
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel
text_encoder = CLIPTextModel.from_pretrained(
    model_id, subfolder="text_encoder", revision=None
)

config = LoraConfig(
    r=4,
    lora_alpha=32,
    # Apply LoRA to the query and value projections of CLIP's attention blocks.
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0,
    bias="none",
)
lora_text_encoder = get_peft_model(text_encoder, config)
lora_text_encoder.save_pretrained(output_dir)
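As a quick sanity check (not part of the snippet above), peft's print_trainable_parameters() can be used to confirm that only the LoRA parameters of the text encoder are trainable:

# Reports the number of trainable (LoRA) parameters vs. the total parameter count.
lora_text_encoder.print_trainable_parameters()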
Now, if we do:

ls -lh {output_dir}

It should give us:
total 3.8M
-rw-r--r-- 1 root root 363 Mar 17 05:28 adapter_config.json
-rw-r--r-- 1 root root 593K Mar 17 05:28 adapter_model.bin
-rw-r--r-- 1 root root 3.2M Mar 17 05:28 pytorch_lora_weights.bin
We can push all these files nicely to the Hub thanks to huggingface_hub. The final repository on the Hub would look like this: https://huggingface.co/sayakpaul/fine-tuned-dreambooth/tree/main.
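For reference, a minimal sketch of that push step with huggingface_hub's create_repo and upload_folder (the repo id below mirrors the example repository above and is only illustrative; adapter_config.json and adapter_model.bin come from peft's save_pretrained, while pytorch_lora_weights.bin comes from save_attn_procs):

from huggingface_hub import create_repo, upload_folder

# Illustrative repo id; any user/org namespace works.
repo_name = "sayakpaul/fine-tuned-dreambooth"

create_repo(repo_name, exist_ok=True)
upload_folder(
    repo_id=repo_name,
    folder_path=output_dir,
    commit_message="Add LoRA weights for the UNet and the text encoder",
)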
We can then do the loading part like so:
from peft import PeftConfig, PeftModel
from diffusers import DiffusionPipeline
# `model_id` can be determined programmatically from `repo_name`.
pipe = DiffusionPipeline.from_pretrained(model_id)
config = PeftConfig.from_pretrained(repo_name)
text_encoder = pipe.text_encoder
pipe.text_encoder = PeftModel.from_pretrained(text_encoder, repo_name)
pipe.unet.load_attn_procs(repo_name)
image = pipe("A picture of a dog in a bucket", num_inference_steps=10).images[0]

This is the Colab Notebook where the above PoC can be found in full.
One could argue that we can also use peft for the UNet part (it's possible). IIUC, in that case, we might have to maintain two separate repositories -- one for the UNet and another one for the text encoder. However, @pacman100 please correct me if I am wrong.
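For completeness, here is a rough sketch of what using peft for the UNet could look like, assuming we target the UNet's attention projection layers by name (to_q, to_v) and save the result to a separate directory/repository; both the module names and the directory below are assumptions for illustration, not a recommendation:

from peft import LoraConfig, get_peft_model
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# Assumption: these are the attention projection linears of UNet2DConditionModel.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    target_modules=["to_q", "to_v"],
    lora_dropout=0.0,
    bias="none",
)
lora_unet = get_peft_model(unet, unet_lora_config)

# Hypothetical separate output directory, hence the dual-repo concern above.
lora_unet.save_pretrained("demo-peft-unet")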
In any case, if the above (dual repo creation) issue is sorted out on the peft side, we can potentially deprecate our LoRA-related utilities.
If the above design looks good to you, I will drop a PR to modify the LoRA DreamBooth example.