diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 44a58fa2a815..34de0d048e8b 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -36,7 +36,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig, get_peft_model_state_dict
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -52,7 +52,12 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -858,11 +863,6 @@ def main(args):
     )
     unet.add_adapter(lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(unet, dtype=torch.float32)
-
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -887,13 +887,31 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         # load the LoRA into the model
         unet_ = accelerator.unwrap_model(unet)
-        lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
-        StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+        unet_state_dict = {
+            f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+        }
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
         for _ in range(len(models)):
             # pop models so that they are not loaded again
             models.pop()
 
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            cast_training_params(unet_, dtype=torch.float32)
+
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
@@ -1092,6 +1110,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
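For context, the changed `load_model_hook` replaces the old `load_lora_into_unet` resume path with a manual PEFT reload, and the fp32 upcast of the LoRA parameters is repeated after loading because the base weights stay in `weight_dtype`. Below is a minimal standalone sketch of that flow, using the same utilities the diff imports; the checkpoint directory, base model id, and the LoRA rank/target modules are illustrative assumptions, not the script's exact arguments.

```python
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import LoraConfig, set_peft_model_state_dict

pretrained_model = "stabilityai/stable-diffusion-xl-base-1.0"  # assumption
checkpoint_dir = "lcm-lora-sdxl/checkpoint-500"                # assumption

# Recreate a PEFT-adapted UNet; rank and target modules here are illustrative.
unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet")
unet.add_adapter(LoraConfig(r=64, target_modules=["to_k", "to_q", "to_v", "to_out.0"]))

# 1. Read the serialized LoRA weights (diffusers layout, keys prefixed with "unet.").
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(checkpoint_dir)
unet_state_dict = {
    k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
}

# 2. Convert the diffusers-style LoRA keys to the PEFT naming scheme and load them
#    directly into the adapter, reporting only unexpected keys.
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet, unet_state_dict, adapter_name="default")
if incompatible_keys is not None and getattr(incompatible_keys, "unexpected_keys", None):
    print(f"Unexpected keys when loading adapter weights: {incompatible_keys.unexpected_keys}")

# 3. Upcast only the trainable (LoRA) parameters to fp32; the frozen base weights
#    remain in their half-precision dtype.
cast_training_params(unet, dtype=torch.float32)
```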