diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 7d731c994bdd..989e70593102 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -700,8 +700,8 @@ def collate_fn(examples):
     )
 
     # Prepare everything with our `accelerator`.
-    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -899,11 +899,12 @@ def collate_fn(examples):
                 generator = torch.Generator(device=accelerator.device)
                 if args.seed is not None:
                     generator = generator.manual_seed(args.seed)
-                images = []
-                for _ in range(args.num_validation_images):
-                    images.append(
-                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                    )
+                with torch.cuda.amp.autocast():
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        images.append(
+                            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                        )
 
                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
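
The two hunks make two related changes: `accelerator.prepare` now receives the `unet` module itself instead of the bare `unet_lora_parameters` list, and validation inference runs under `torch.cuda.amp.autocast()`. Below is a minimal, self-contained sketch of both patterns, not code from the patch: a toy `nn.Linear` stands in for the UNet, and the optimizer, dataloader, and scheduler are placeholders, assuming a standard `accelerate` setup.

```python
# Sketch only -- a tiny nn.Linear stands in for the UNet; only the
# prepare()/autocast pattern from the diff matters here, not the model.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)                    # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# Pass the module itself, not a bare parameter list: prepare() then returns
# a wrapped model that Accelerate places on the right device and, in
# distributed runs, wraps for gradient synchronization.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

# ... training loop elided ...

# Validation-style inference under autocast: on CUDA, eligible ops run in
# half precision, so fp32 inputs and fp16-cast weights can mix without
# dtype-mismatch errors.
generator = torch.Generator(device=accelerator.device).manual_seed(0)
with torch.cuda.amp.autocast():
    sample = torch.randn(2, 4, device=accelerator.device, generator=generator)
    output = model(sample)
```

Preparing the module (rather than its parameters) is what allows Accelerate to hand back a properly wrapped model, and autocast is what lets the half-precision validation pipeline consume fp32 inputs such as the sampled latents.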