diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index c8efbddd0b44..afbf80404872 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -844,10 +844,11 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -913,8 +914,11 @@ def collate_fn(examples): if args.seed is not None: generator = generator.manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.cuda.amp.autocast(): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) for tracker in accelerator.trackers: if len(images) != 0: