From b5322915dbe6f50284b2eb9e63ffb261c399048c Mon Sep 17 00:00:00 2001 From: asrimanth Date: Sun, 4 Feb 2024 05:07:20 +0000 Subject: [PATCH 1/5] Fix: training resume from fp16 for lcm distill lora sdxl --- .../train_lcm_distill_lora_sdxl.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 44a58fa2a815..f5d5477fb4d1 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -36,7 +36,7 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig, get_peft_model_state_dict +from peft import LoraConfig, set_peft_model_state_dict, get_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -52,7 +52,12 @@ ) from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, resolve_interpolation_mode -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) from diffusers.utils.import_utils import is_xformers_available @@ -889,11 +894,29 @@ def load_model_hook(models, input_dir): unet_ = accelerator.unwrap_model(unet) lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir) StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) for _ in range(len(models)): # pop models so that they are not loaded again models.pop() + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models_to_load_in_fp32 = [unet_] + cast_training_params(models_to_load_in_fp32, dtype=torch.float32) + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) From cbea2b13cf758eb2bb0aae9a8582683a6164cfe0 Mon Sep 17 00:00:00 2001 From: asrimanth Date: Sun, 4 Feb 2024 17:12:59 +0000 Subject: [PATCH 2/5] Fix coding quality - run linter --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index f5d5477fb4d1..036ca1da8130 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -36,7 +36,7 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version -from peft import LoraConfig, set_peft_model_state_dict, get_peft_model_state_dict +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -894,7 +894,9 @@ def load_model_hook(models, input_dir): unet_ = accelerator.unwrap_model(unet) lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir) StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = { + f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") if incompatible_keys is not None: From d5ed3352d38cf10d57949f74868e533fd10ff50f Mon Sep 17 00:00:00 2001 From: asrimanth Date: Wed, 7 Feb 2024 00:00:43 +0000 Subject: [PATCH 3/5] Fix 1 - shift mixed precision cast before optimizer --- .../train_lcm_distill_lora_sdxl.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 036ca1da8130..75848a7ded98 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -863,11 +863,6 @@ def main(args): ) unet.add_adapter(lora_config) - # Make sure the trainable params are in float32. - if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(unet, dtype=torch.float32) - # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) @@ -916,8 +911,8 @@ def load_model_hook(models, input_dir): # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if args.mixed_precision == "fp16": - models_to_load_in_fp32 = [unet_] - cast_training_params(models_to_load_in_fp32, dtype=torch.float32) + models = [unet_] + cast_training_params(models, dtype=torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1117,6 +1112,11 @@ def compute_time_ids(original_size, crops_coords_top_left): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) + lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, From e6a1f827dc2782b19d1fe99ceabc8893d53de097 Mon Sep 17 00:00:00 2001 From: asrimanth Date: Thu, 8 Feb 2024 01:42:14 +0000 Subject: [PATCH 4/5] Fix 2 - State dict errors by removing load_lora_into_unet --- .../train_lcm_distill_lora_sdxl.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 75848a7ded98..82bee0b0a5e2 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -310,7 +310,7 @@ def parse_args(): parser.add_argument( "--cache_dir", type=str, - default=None, + default="/workspace/cache", help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -887,8 +887,7 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): # load the LoRA into the model unet_ = accelerator.unwrap_model(unet) - lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir) - StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir) unet_state_dict = { f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") } @@ -911,8 +910,7 @@ def load_model_hook(models, input_dir): # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if args.mixed_precision == "fp16": - models = [unet_] - cast_training_params(models, dtype=torch.float32) + cast_training_params(unet_, dtype=torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) From 99e529024a6909fa56b5c84254af02f6e735ebdc Mon Sep 17 00:00:00 2001 From: Srimanth Agastyaraju <30816357+asrimanth@users.noreply.github.com> Date: Wed, 7 Feb 2024 23:46:39 -0500 Subject: [PATCH 5/5] Update train_lcm_distill_lora_sdxl.py - Revert default cache dir to None --- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 82bee0b0a5e2..34de0d048e8b 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -310,7 +310,7 @@ def parse_args(): parser.add_argument( "--cache_dir", type=str, - default="/workspace/cache", + default=None, help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")