From 97fa1cac4e0e613780b5d9a5bf79ba688a78794a Mon Sep 17 00:00:00 2001
From: Sajad Norouzi
Date: Mon, 29 Jan 2024 12:18:41 +0000
Subject: [PATCH 1/2] Fix mixed precision fine-tuning for text-to-image-lora-sdxl example.

---
 .../train_text_to_image_lora_sdxl.py | 62 ++++++++++++-------
 1 file changed, 41 insertions(+), 21 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 24e182c895d4..1a42bcc74230 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -35,7 +35,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
@@ -51,8 +51,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -629,14 +634,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -693,18 +690,34 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models, dtype=torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -725,6 +738,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:

From 83fbca1fec14f1048bacf17e9c03feb02d31c104 Mon Sep 17 00:00:00 2001
From: Sajad Norouzi
Date: Mon, 29 Jan 2024 17:08:39 +0000
Subject: [PATCH 2/2] fix text_encoder_two bug.

---
 examples/text_to_image/train_text_to_image_lora_sdxl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 1a42bcc74230..78a412f1818f 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -707,7 +707,7 @@ def load_model_hook(models, input_dir):
             _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
 
             _set_state_dict_into_text_encoder(
-                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
            )
 
         # Make sure the trainable params are in float32. This is again needed since the base models
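Note on the change above: with --mixed_precision="fp16" the frozen base weights stay in half
precision (weight_dtype), but every parameter that will receive gradients (the LoRA adapters)
must be upcast to float32, both before the optimizer is created and again inside
load_model_hook, because the adapters restored from a checkpoint otherwise come back in
weight_dtype. The sketch below illustrates that upcasting pattern on a toy module;
upcast_trainable_params is a hypothetical stand-in for
diffusers.training_utils.cast_training_params, not the library's exact code.

    # Minimal sketch (not part of the patch): upcast only trainable parameters to fp32,
    # leaving frozen fp16 base weights untouched.
    import torch
    from torch import nn


    def upcast_trainable_params(models, dtype=torch.float32):
        # Illustrative stand-in for cast_training_params: only parameters that
        # require gradients (e.g. LoRA weights) are moved to `dtype`.
        for model in models:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.data.to(dtype)


    base = nn.Linear(8, 8).half()                  # frozen base layer kept in fp16
    base.requires_grad_(False)
    adapter = nn.Linear(8, 8, bias=False).half()   # stand-in for a LoRA adapter
    model = nn.Sequential(base, adapter)

    upcast_trainable_params([model])
    print(base.weight.dtype)     # torch.float16 -- untouched
    print(adapter.weight.dtype)  # torch.float32 -- trainable, upcast

Keeping only the trainable parameters in fp32 is what lets the optimizer update the LoRA
weights stably under fp16 mixed precision while the rest of the model stays in half precision.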