diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 06a847e6ca61..e32656abe569 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -43,7 +43,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate -from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.import_utils import is_wandb_available, is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -297,6 +297,24 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=500, + help="Sample a validation image every X updates.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -353,7 +371,14 @@ def main(): project_config=accelerator_project_config, ) - # Make one log on every process with the configuration for debugging. + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training. You can do so by doing" + " `pip install wandb`" + ) + import wandb # Make one log on every process with the configuration for debugging. + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -756,6 +781,53 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if global_step % args.validation_steps == 0: + if accelerator.is_main_process: + if args.validation_prompt: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + prompt = [args.validation_prompt] + images = [] + with torch.autocast( + str(accelerator.device), enabled=accelerator.mixed_precision == "fp16" + ): + for _ in range(args.num_validation_images): + images.append(pipeline(prompt).images[0]) + + for i, image in enumerate(images): + image.save(os.path.join(args.output_dir, f"validation-{global_step}-{i}.jpg")) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + del pipeline + torch.cuda.empty_cache() + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -778,6 +850,33 @@ def collate_fn(examples): ) pipeline.save_pretrained(args.output_dir) + if args.validation_prompt: + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + prompt = [args.validation_prompt] + images = [] + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + for _ in range(args.num_validation_images): + images.append(pipeline(prompt).images[0]) + + for i, image in enumerate(images): + image.save(os.path.join(args.output_dir, f"test-{i}.jpg")) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)