From b6de725f10560b514bc3335fc29aecd563c9224b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 10 Dec 2023 10:24:19 +0530 Subject: [PATCH 1/4] fix: unscale fp16 gradient problem --- .../text_to_image/train_text_to_image_lora.py | 67 ++++++++++--------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b63500f906a8..2f9ad759e8b7 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -493,7 +493,11 @@ def main(): vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) + # Add adapter and make sure the trainable params are in float32. unet.add_adapter(unet_lora_config) + for param in unet.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -919,39 +923,42 @@ def collate_fn(examples): ignore_patterns=["step_*", "epoch_*"], ) - # Final inference - # Load previous pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype - ) - pipeline = pipeline.to(accelerator.device) + # Final inference + # Load previous pipeline + if args.validation_prompt is not None: + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) - # load attention processors - pipeline.unet.load_attn_procs(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) - # run inference - generator = torch.Generator(device=accelerator.device) - if args.seed is not None: - generator = generator.manual_seed(args.seed) - images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + # run inference + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - if accelerator.is_main_process: - for tracker in accelerator.trackers: - if len(images) != 0: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if len(images) != 0: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) accelerator.end_training() From 32bd473e629b1ef227d77fa604c683b24c78a05a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 11 Dec 2023 08:47:35 +0530 Subject: [PATCH 2/4] fix for dreambooth lora sdxl --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c8a9a6ad4812..991f9e5cf6d5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1024,6 +1024,12 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) + # Make sure the trainable params are in float32. + for model in [unet, text_encoder_one, text_encoder_two]: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: From 8ac462ba8552c9cdd48b1111da616c8bb26c6583 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 18 Dec 2023 15:37:46 +0530 Subject: [PATCH 3/4] make the type-casting conditional. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 12 ++++++++---- examples/text_to_image/train_text_to_image_lora.py | 7 ++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 991f9e5cf6d5..a155fcabe5cf 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1025,10 +1025,14 @@ def main(args): text_encoder_two.add_adapter(text_lora_config) # Make sure the trainable params are in float32. - for model in [unet, text_encoder_one, text_encoder_two]: - for param in model.parameters(): - if param.requires_grad: - param.data = param.to(torch.float32) + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 2f9ad759e8b7..304116dded52 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -495,9 +495,10 @@ def main(): # Add adapter and make sure the trainable params are in float32. unet.add_adapter(unet_lora_config) - for param in unet.parameters(): - if param.requires_grad: - param.data = param.to(torch.float32) + if args.mixed_precision == "fp16": + for param in unet.parameters(): + if param.requires_grad: + param.data = param.to(torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): From 18e6bf748b693ae86db4b3cc8efdadb0e310032d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 19 Dec 2023 08:17:59 +0530 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 + examples/text_to_image/train_text_to_image_lora.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index a155fcabe5cf..e306847dc36b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1031,6 +1031,7 @@ def main(args): models.extend([text_encoder_one, text_encoder_two]) for model in models: for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 if param.requires_grad: param.data = param.to(torch.float32) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 304116dded52..7cfb231c82a2 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -497,6 +497,7 @@ def main(): unet.add_adapter(unet_lora_config) if args.mixed_precision == "fp16": for param in unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 if param.requires_grad: param.data = param.to(torch.float32)