From 18d3fc51e37567ce8176c7e3b73342b9b5993eca Mon Sep 17 00:00:00 2001
From: Philippe
Date: Mon, 27 Nov 2023 14:48:20 -0500
Subject: [PATCH 1/2] fix: changed prepare input

---
 examples/text_to_image/train_text_to_image_lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 7d731c994bdd..824efe9d010c 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -700,8 +700,8 @@ def collate_fn(examples):
     )
 
     # Prepare everything with our `accelerator`.
-    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.

From e81f14f70db008671bc66f300c06a14412da45f6 Mon Sep 17 00:00:00 2001
From: Philippe
Date: Wed, 29 Nov 2023 18:08:51 -0500
Subject: [PATCH 2/2] fix: autocast added

---
 examples/text_to_image/train_text_to_image_lora.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 824efe9d010c..989e70593102 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -899,11 +899,12 @@ def collate_fn(examples):
                 generator = torch.Generator(device=accelerator.device)
                 if args.seed is not None:
                     generator = generator.manual_seed(args.seed)
-                images = []
-                for _ in range(args.num_validation_images):
-                    images.append(
-                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
-                    )
+                with torch.cuda.amp.autocast():
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        images.append(
+                            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                        )
 
                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
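
The two patches follow a common Accelerate pattern: pass the `unet` module itself to `accelerator.prepare` (not its LoRA parameter list), and run validation inference under an autocast context. Below is a minimal, self-contained sketch of that pattern; the toy `TinyUNet` module, the AdamW settings, and the dummy forward pass are illustrative stand-ins and not code from `train_text_to_image_lora.py`.

# Minimal sketch of the pattern used by the patches above. Assumptions:
# `TinyUNet`, the optimizer settings, and the dummy input are placeholders.
import torch
from accelerate import Accelerator


class TinyUNet(torch.nn.Module):
    """Stand-in for the real UNet2DConditionModel."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


accelerator = Accelerator()
unet = TinyUNet()

# The optimizer is still built from the trainable (e.g. LoRA) parameters ...
trainable_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

# ... but the module itself is what gets passed to `accelerator.prepare`,
# so Accelerate can handle device placement, wrapping, and mixed precision
# (this is what PATCH 1/2 changes).
unet, optimizer = accelerator.prepare(unet, optimizer)

# Validation-style inference under autocast, as in PATCH 2/2. The patch uses
# `torch.cuda.amp.autocast()`; `torch.autocast` is the device-generic form.
with torch.autocast(accelerator.device.type):
    with torch.no_grad():
        sample = unet(torch.randn(1, 8, device=accelerator.device))

print(tuple(sample.shape))  # expected: (1, 8)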