diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index e9c849b22b6b..3b784eda6a34 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -625,8 +625,11 @@ def transform_images(examples): if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: unet = accelerator.unwrap_model(model) + if args.use_ema: + ema_model.store(unet.parameters()) ema_model.copy_to(unet.parameters()) + pipeline = DDPMPipeline( unet=unet, scheduler=noise_scheduler, @@ -641,6 +644,9 @@ def transform_images(examples): output_type="numpy", ).images + if args.use_ema: + ema_model.restore(unet.parameters()) + # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") @@ -659,7 +665,22 @@ def transform_images(examples): if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model + unet = accelerator.unwrap_model(model) + + if args.use_ema: + ema_model.store(unet.parameters()) + ema_model.copy_to(unet.parameters()) + + pipeline = DDPMPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + pipeline.save_pretrained(args.output_dir) + + if args.use_ema: + ema_model.restore(unet.parameters()) + if args.push_to_hub: repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)