Skip to content
16 changes: 10 additions & 6 deletions examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,11 @@ def collate_fn(examples):
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
with torch.cuda.amp.autocast():
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down Expand Up @@ -913,8 +914,11 @@ def collate_fn(examples):
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
with torch.cuda.amp.autocast():
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)

for tracker in accelerator.trackers:
if len(images) != 0:
Expand Down