diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index a53af7bcffd2..0ff15ed293e4 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -582,7 +582,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
-    if args.peft:
+    if args.use_peft:
         # Optimizer creation
         params_to_optimize = (
             itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -724,7 +724,7 @@ def collate_fn(examples):
     )
 
     # Prepare everything with our `accelerator`.
-    if args.peft:
+    if args.use_peft:
         if args.train_text_encoder:
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -842,7 +842,7 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    if args.peft:
+                    if args.use_peft:
                         params_to_clip = (
                             itertools.chain(unet.parameters(), text_encoder.parameters())
                             if args.train_text_encoder
@@ -922,18 +922,22 @@ def collate_fn(examples):
     if accelerator.is_main_process:
         if args.use_peft:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-            lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+            unwarpped_unet = accelerator.unwrap_model(unet)
+            state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+            lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
             if args.train_text_encoder:
+                unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
                 text_encoder_state_dict = get_peft_model_state_dict(
-                    text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                    unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
                 )
                 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
                 state_dict.update(text_encoder_state_dict)
-                lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+                lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
+                    inference=True
+                )
 
-            accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
-            with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
+            accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
                 json.dump(lora_config, f)
         else:
             unet = unet.to(torch.float32)
@@ -957,12 +961,12 @@ def collate_fn(examples):
 
     if args.use_peft:
 
-        def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
-            with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
+        def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
                 lora_config = json.load(f)
             print(lora_config)
 
-            checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
+            checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
             lora_checkpoint_sd = torch.load(checkpoint)
             unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
             text_encoder_lora_ds = {
@@ -985,9 +989,7 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
             pipe.to(device)
             return pipe
 
-        pipeline = load_and_set_lora_ckpt(
-            pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
-        )
+        pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
 
     else:
         pipeline = pipeline.to(accelerator.device)
@@ -995,7 +997,10 @@ def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
         pipeline.unet.load_attn_procs(args.output_dir)
 
     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    else:
+        generator = None
     images = []
     for _ in range(args.num_validation_images):
         images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
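
For reference, a minimal standalone sketch (not part of the patch) of how the checkpoint pair renamed by this change could be loaded and inspected after training. It mirrors the key split that load_and_set_lora_ckpt performs in the patched script; the output directory and step value below are hypothetical placeholders, not values taken from the PR.

import json
import os

import torch

output_dir = "sd-model-finetuned-lora"  # hypothetical --output_dir value
global_step = 15000  # hypothetical final training step

# The patch names both artifacts after global_step instead of the instance prompt.
lora_checkpoint_sd = torch.load(os.path.join(output_dir, f"{global_step}_lora.pt"), map_location="cpu")
with open(os.path.join(output_dir, f"{global_step}_lora_config.json"), "r") as f:
    lora_config = json.load(f)

# Same split as load_and_set_lora_ckpt: UNet keys vs. "text_encoder_"-prefixed keys.
unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
text_encoder_lora_ds = {
    k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
}
print(list(lora_config), len(unet_lora_ds), len(text_encoder_lora_ds))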