From f213720a40f5bd7d35a6e2fb58af70eac69674cc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 24 Aug 2023 08:49:05 +0530 Subject: [PATCH 1/2] fix sdxl dreambooth lora checkpointing. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 8 ++- examples/test_examples.py | 71 +++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 81250db78412..0c6a4c0f14d5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -843,11 +843,15 @@ def load_model_hook(models, input_dir): lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + + text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ ) + + text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ + text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook) diff --git a/examples/test_examples.py b/examples/test_examples.py index c06c9417d594..7016cdd0e0c9 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -420,6 +420,77 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self): k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys ) self.assertTrue(starts_with_unet) + + def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path {pipeline_path} + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + pipe = DiffusionPipeline.from_pretrained(pipeline_path) + pipe.load_lora_weights(tmpdir) + pipe("a prompt", num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self): + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path {pipeline_path} + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 7 + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --train_text_encoder + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + pipe = DiffusionPipeline.from_pretrained(pipeline_path) + pipe.load_lora_weights(tmpdir) + pipe("a prompt", num_inference_steps=2) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + # checkpoint-2 should have been deleted + {"checkpoint-4", "checkpoint-6"}, + ) def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: From 9a407daffe013679302c7d12c8c2da0fe7568c53 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 24 Aug 2023 16:20:04 +0530 Subject: [PATCH 2/2] style --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- examples/test_examples.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0c6a4c0f14d5..247d111c06e2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -843,7 +843,7 @@ def load_model_hook(models, input_dir): lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - + text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} LoraLoaderMixin.load_lora_into_text_encoder( text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ diff --git a/examples/test_examples.py b/examples/test_examples.py index 7016cdd0e0c9..2b4f49dd6bdd 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -420,9 +420,9 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self): k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys ) self.assertTrue(starts_with_unet) - + def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): - pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -457,7 +457,7 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): ) def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self): - pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" + pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" with tempfile.TemporaryDirectory() as tmpdir: test_args = f"""