diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3b0177402c4e..3ca92717f8fa 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,5 +1,4 @@ import argparse -import copy import inspect import logging import math @@ -530,7 +529,7 @@ def transforms(examples): # Generate sample images for visual inspection if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: - unet = copy.deepcopy(accelerator.unwrap_model(model)) + unet = accelerator.unwrap_model(model) if args.use_ema: ema_model.copy_to(unet.parameters()) pipeline = DDPMPipeline(