From f4ed5e84382b27744f0979a5fa9b0bf55c018ecf Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 16 Aug 2023 15:28:26 +0530
Subject: [PATCH 1/4] fix: casting issues.

---
 examples/text_to_image/train_text_to_image_lora_sdxl.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index d7c2d07be431..7a72ed22580e 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -1002,9 +1002,12 @@ def collate_fn(examples):
                 continue
 
             with accelerator.accumulate(unet):
-                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
-
                 # Convert images to latent space
+                if args.pretrained_vae_model_name_or_path is not None:
+                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                else:
+                    pixel_values = batch["pixel_values"]
+
                 model_input = vae.encode(pixel_values).latent_dist.sample()
                 model_input = model_input * vae.config.scaling_factor
                 if args.pretrained_vae_model_name_or_path is None:

From b7d23955d8aca22c3466a010f6461d3d5eecdb55 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 16 Aug 2023 16:21:12 +0530
Subject: [PATCH 2/4] fix checkpointing.

---
 examples/text_to_image/train_text_to_image_lora_sdxl.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 7a72ed22580e..5497e25d7103 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -724,11 +724,15 @@ def load_model_hook(models, input_dir):
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
         )
+
+        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
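Note on PATCH 1 and PATCH 2 (an annotation, not part of the applyable patches):

PATCH 1 makes the pixel-value dtype follow the VAE. Under mixed precision these SDXL scripts generally keep the stock VAE in fp32 (the SDXL VAE is known to be numerically unstable in fp16), so its input has to stay fp32 as well; only when a user-supplied `--pretrained_vae_model_name_or_path` (e.g. an fp16-safe VAE) is in use are the pixel values cast to `weight_dtype`.

PATCH 2 stops handing the full combined LoRA state dict to both `load_lora_into_text_encoder` calls and slices it per component instead. Below is a minimal standalone sketch of that filtering; the keys are hypothetical stand-ins for real LoRA entries, and nothing here imports diffusers:

    # Hypothetical keys mimicking a combined SDXL LoRA state dict
    # (stand-ins; real keys map to tensors).
    lora_state_dict = {
        "unet.up_blocks.0.attentions.0.to_q_lora.down.weight": "unet tensor",
        "text_encoder.text_model.encoder.layers.0.q_proj.lora.weight": "te1 tensor",
        "text_encoder_2.text_model.encoder.layers.0.q_proj.lora.weight": "te2 tensor",
    }

    # The trailing dot keeps the two slices disjoint: "text_encoder." never
    # matches a key starting with "text_encoder_2.", since the character
    # after "text_encoder" there is "_", not ".".
    text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
    text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}

    assert list(text_encoder_state_dict) == ["text_encoder.text_model.encoder.layers.0.q_proj.lora.weight"]
    assert list(text_encoder_2_state_dict) == ["text_encoder_2.text_model.encoder.layers.0.q_proj.lora.weight"]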
From dc898919e1c10ee1c8348261860dcd41d56e040b Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 16 Aug 2023 16:37:46 +0530
Subject: [PATCH 3/4] tests

---
 examples/test_examples.py | 81 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/examples/test_examples.py b/examples/test_examples.py
index 4fd2e485cd0f..b13442a2c1b6 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -827,6 +827,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
             {"checkpoint-4", "checkpoint-6"},
         )
 
+    def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --train_text_encoder
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
     def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
         prompt = "a prompt"
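Note on PATCH 3 (an annotation, not part of the applyable patches): with `--checkpointing_steps=2`, `--max_train_steps 7` and `--checkpoints_total_limit=2`, the script saves at steps 2, 4 and 6 and prunes the oldest checkpoint directory once the limit would be exceeded, so only `checkpoint-4` and `checkpoint-6` survive. A simplified sketch of that rotation policy (an approximation, not the training script's exact code):

    import os
    import shutil

    def prune_checkpoints(output_dir: str, total_limit: int) -> None:
        """Delete the oldest checkpoint dirs so the next save stays within the limit."""
        checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        checkpoints.sort(key=lambda d: int(d.split("-")[1]))  # dirs are named checkpoint-<step>
        # Leave room for the checkpoint that is about to be written.
        while len(checkpoints) >= total_limit:
            shutil.rmtree(os.path.join(output_dir, checkpoints.pop(0)))

Run before each save at steps 2, 4 and 6 with total_limit=2, this deletes `checkpoint-2` just before `checkpoint-6` is written, which is exactly the set both tests assert on.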
From 54f30153f0abee404bd5ed6edebcac21062a0499 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 17 Aug 2023 12:50:55 +0530
Subject: [PATCH 4/4] fix: bugs

---
 .../train_text_to_image_lora_sdxl.py | 17 -----------------
 tests/models/test_lora_layers.py     |  1 +
 2 files changed, 1 insertion(+), 17 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 5497e25d7103..fe8bdc594b38 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -396,16 +396,6 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
-    parser.add_argument(
-        "--prior_generation_precision",
-        type=str,
-        default=None,
-        choices=["no", "fp32", "fp16", "bf16"],
-        help=(
-            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
-        ),
-    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -1154,13 +1144,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
                     f" {args.validation_prompt}."
                 )
                 # create pipeline
-                if not args.train_text_encoder:
-                    text_encoder_one = text_encoder_cls_one.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-                    )
-                    text_encoder_two = text_encoder_cls_two.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
-                    )
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,

diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 7f06da81ba38..c2fe98993d00 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -664,6 +664,7 @@ def test_load_lora_locally(self):
             unet_lora_layers=lora_components["unet_lora_layers"],
             text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
             text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+            safe_serialization=False,
         )
         self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
         sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
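Note on PATCH 4 (an annotation, not part of the applyable patches): the two removals are cleanups; `--prior_generation_precision` is never read by this LoRA script (it looks like a leftover from the DreamBooth examples), and validation now reuses the text encoders already in memory instead of re-instantiating them from the hub. The one-line test change pins the serialization format: the test asserts that `pytorch_lora_weights.bin` exists, and that name is produced only when safetensors serialization is turned off; with the default, the file would presumably be written as `pytorch_lora_weights.safetensors` and the `isfile` check would fail. A tiny sketch of the naming convention this relies on (the helper is illustrative, not a diffusers API):

    # Illustrative helper, not a diffusers API: maps the serialization flag to
    # the LoRA weight filename that the test asserts on.
    def lora_weight_filename(safe_serialization: bool) -> str:
        return "pytorch_lora_weights.safetensors" if safe_serialization else "pytorch_lora_weights.bin"

    assert lora_weight_filename(False) == "pytorch_lora_weights.bin"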