Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,11 @@ def transform_images(examples):
if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
unet = accelerator.unwrap_model(model)

if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())

pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
Expand All @@ -641,6 +644,9 @@ def transform_images(examples):
output_type="numpy",
).images

if args.use_ema:
ema_model.restore(unet.parameters())

# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")

Expand All @@ -659,7 +665,22 @@ def transform_images(examples):

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
unet = accelerator.unwrap_model(model)

if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())

pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
)

pipeline.save_pretrained(args.output_dir)

if args.use_ema:
ema_model.restore(unet.parameters())

Comment on lines +668 to +683
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to make a separate pipeline here because it's not guaranteed that the pipeline set on line 633 is ever set. It's very likely that most people calling this script are never setting the save_model_epocs or save_images_epochs and they're always the same default values

if args.push_to_hub:
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)

Expand Down