From a1248d6c5127cf658d54924c0f6fee8b2e708d2f Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Tue, 21 Mar 2023 19:39:09 +0800 Subject: [PATCH 1/5] Update train_text_to_image_lora.py --- examples/research_projects/lora/train_text_to_image_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index a53af7bcffd2..f5a3997c0188 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -582,7 +582,7 @@ def main(): else: optimizer_cls = torch.optim.AdamW - if args.peft: + if args.use_peft: # Optimizer creation params_to_optimize = ( itertools.chain(unet.parameters(), text_encoder.parameters()) @@ -724,7 +724,7 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - if args.peft: + if args.use_peft: if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler @@ -842,7 +842,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - if args.peft: + if args.use_peft: params_to_clip = ( itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder From c3d26545d9c86cf455e2e7512ff135b03c007a67 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Tue, 21 Mar 2023 19:58:19 +0800 Subject: [PATCH 2/5] Update train_text_to_image_lora.py --- .../research_projects/lora/train_text_to_image_lora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index f5a3997c0188..a44867285755 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -932,8 +932,8 @@ def collate_fn(examples): state_dict.update(text_encoder_state_dict) lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) - accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt")) - with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f: + accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt")) + with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f: json.dump(lora_config, f) else: unet = unet.to(torch.float32) @@ -957,12 +957,12 @@ def collate_fn(examples): if args.use_peft: - def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype): + def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f: lora_config = json.load(f) print(lora_config) - checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt" + checkpoint = f"{ckpt_dir}{global_step}_lora.pt" lora_checkpoint_sd = torch.load(checkpoint) unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} text_encoder_lora_ds = { @@ -986,7 +986,7 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype): return pipe pipeline = load_and_set_lora_ckpt( - pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype + pipeline, args.output_dir, global_step, accelerator.device, weight_dtype ) else: From 0cd2b58589d89d3648a3abae85e97cc25d1dc2ca Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Tue, 21 Mar 2023 19:59:11 +0800 Subject: [PATCH 3/5] Update train_text_to_image_lora.py --- examples/research_projects/lora/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index a44867285755..d8c3dc8249f8 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -958,7 +958,7 @@ def collate_fn(examples): if args.use_peft: def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): - with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f: + with open(f"{ckpt_dir}{global_step}_lora_config.json", "r") as f: lora_config = json.load(f) print(lora_config) From ab68e1d5e26aa2ea922f11ddddad19d938249456 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Tue, 21 Mar 2023 21:35:10 +0800 Subject: [PATCH 4/5] Update train_text_to_image_lora.py --- .../lora/train_text_to_image_lora.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index d8c3dc8249f8..28bf09df4cb8 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -922,15 +922,17 @@ def collate_fn(examples): if accelerator.is_main_process: if args.use_peft: lora_config = {} - state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) - lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) + unwarpped_unet = accelerator.unwrap_model(unet) + state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet)) + lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True) if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) text_encoder_state_dict = get_peft_model_state_dict( - text_encoder, state_dict=accelerator.get_state_dict(text_encoder) + unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder) ) text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) + lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(inference=True) accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt")) with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f: @@ -958,11 +960,11 @@ def collate_fn(examples): if args.use_peft: def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): - with open(f"{ckpt_dir}{global_step}_lora_config.json", "r") as f: + with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f: lora_config = json.load(f) print(lora_config) - checkpoint = f"{ckpt_dir}{global_step}_lora.pt" + checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt") lora_checkpoint_sd = torch.load(checkpoint) unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} text_encoder_lora_ds = { @@ -995,7 +997,10 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): pipeline.unet.load_attn_procs(args.output_dir) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None images = [] for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) From 1dac8ae787e6d91c323e648e19d76c2a32630a09 Mon Sep 17 00:00:00 2001 From: haofanwang Date: Wed, 22 Mar 2023 02:26:33 +0800 Subject: [PATCH 5/5] format --- .../research_projects/lora/train_text_to_image_lora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index 28bf09df4cb8..0ff15ed293e4 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -932,7 +932,9 @@ def collate_fn(examples): ) text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(inference=True) + lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict( + inference=True + ) accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt")) with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f: @@ -987,9 +989,7 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): pipe.to(device) return pipe - pipeline = load_and_set_lora_ckpt( - pipeline, args.output_dir, global_step, accelerator.device, weight_dtype - ) + pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype) else: pipeline = pipeline.to(accelerator.device)