diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index ad37363b7d30..2f1002012473 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -15,7 +15,6 @@ import argparse import gc -import hashlib import itertools import logging import math @@ -36,7 +35,10 @@ from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from safetensors.torch import save_file @@ -54,9 +56,8 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr, unet_lora_state_dict +from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -67,39 +68,6 @@ logger = get_logger(__name__) -# TODO: This function should be removed once training scripts are rewritten in PEFT -def text_encoder_lora_state_dict(text_encoder): - state_dict = {} - - def text_encoder_attn_modules(text_encoder): - from transformers import CLIPTextModel, CLIPTextModelWithProjection - - attn_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - name = f"text_model.encoder.layers.{i}.self_attn" - mod = layer.self_attn - attn_modules.append((name, mod)) - - return attn_modules - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - def save_model_card( repo_id: str, images=None, @@ -1123,7 +1091,7 @@ def main(args): images = pipeline(example["prompt"]).images for i, image in enumerate(images): - hash_image = hashlib.sha1(image.tobytes()).hexdigest() + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" image.save(image_filename) @@ -1262,76 +1230,36 @@ def main(args): text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( - text_encoder_one, dtype=torch.float32, rank=args.rank - ) - text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.rank + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) - # if we use textual inversion, we freeze all parameters except for the token embeddings - # in text encoder - elif args.train_text_encoder_ti: - text_lora_parameters_one = [] - for name, param in text_encoder_one.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_one.append(param) - else: - param.requires_grad = False - text_lora_parameters_two = [] - for name, param in text_encoder_two.named_parameters(): - if "token_embedding" in name: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param = param.to(dtype=torch.float32) - param.requires_grad = True - text_lora_parameters_two.append(param) - else: - param.requires_grad = False + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1344,11 +1272,11 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_lora_state_dict(model) + unet_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1405,25 +1333,47 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} if not freeze_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder - if args.adam_weight_decay_text_encoder - else args.adam_weight_decay, + "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } text_lora_parameters_two_with_lr = { "params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder - if args.adam_weight_decay_text_encoder - else args.adam_weight_decay, + "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ @@ -1839,9 +1789,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(noise_scheduler, timesteps) + + if args.with_prior_preservation: + # if we're using prior preservation, we calc snr for instance loss only - + # and hence only need timesteps corresponding to instance images + snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0) + else: + snr_timesteps = timesteps + + snr = compute_snr(noise_scheduler, snr_timesteps) base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": @@ -1995,13 +1953,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = unet_lora_state_dict(unet) + unet_lora_layers = get_peft_model_state_dict(unet) if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32)) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None @@ -2047,7 +2005,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # load attention processors pipeline.load_lora_weights(args.output_dir) - # run inference + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [