diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index 586ed8c8d397..cd120460c10f 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -21,7 +21,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -220,6 +220,7 @@ def parse_args(): help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", ) parser.add_argument("--ddpm_num_steps", type=int, default=1000) + parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000) parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") parser.add_argument( "--checkpointing_steps", @@ -271,6 +272,15 @@ def main(args): logging_dir=logging_dir, ) + if args.logger == "tensorboard": + if not is_tensorboard_available(): + raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") + + elif args.logger == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -552,7 +562,7 @@ def transform_images(examples): generator=generator, batch_size=args.eval_batch_size, output_type="numpy", - num_inference_steps=args.ddpm_num_steps, + num_inference_steps=args.ddpm_num_inference_steps, ).images # denormalize the images and save to tensorboard @@ -562,6 +572,11 @@ def transform_images(examples): accelerator.get_tracker("tensorboard").add_images( "test_samples", images_processed.transpose(0, 3, 1, 2), epoch ) + elif args.logger == "wandb": + accelerator.get_tracker("wandb").log( + {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, + step=global_step, + ) if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 64ba126d0cce..f76594a78c32 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -22,7 +22,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_tensorboard_available +from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -280,6 +280,15 @@ def main(args): logging_dir=logging_dir, ) + if args.logger == "tensorboard": + if not is_tensorboard_available(): + raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") + + elif args.logger == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -604,10 +613,15 @@ def transform_images(examples): # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") - if args.logger == "tensorboard" and is_tensorboard_available(): + if args.logger == "tensorboard": accelerator.get_tracker("tensorboard").add_images( "test_samples", images_processed.transpose(0, 3, 1, 2), epoch ) + elif args.logger == "wandb": + accelerator.get_tracker("wandb").log( + {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, + step=global_step, + ) if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model