From 6aa64029ece04d6a039eea1c34b226fe8f64976b Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sun, 29 Jan 2023 13:50:26 -0800 Subject: [PATCH 01/17] Sample images when checkpointing. --- examples/text_to_image/train_text_to_image.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 06a847e6ca61..80b189a882cc 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -297,6 +297,18 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -756,6 +768,28 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if args.validation_prompt: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + prompt = [args.validation_prompt] + images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images + + for i, image in enumerate(images): + image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.jpg")) + del pipeline + torch.cuda.empty_cache() + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 09b8e216e0a9d77661a3b6071a286a720a759259 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sun, 29 Jan 2023 14:09:24 -0800 Subject: [PATCH 02/17] Formatting. --- examples/text_to_image/train_text_to_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 80b189a882cc..139be95f50e9 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -789,7 +789,6 @@ def collate_fn(examples): del pipeline torch.cuda.empty_cache() - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From c291cceb06e8a09f7738cabb9ffcce9eed9aa31b Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Tue, 31 Jan 2023 17:49:00 -0800 Subject: [PATCH 03/17] Use torch.autocast(). --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 139be95f50e9..a1b63fcfa12c 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -775,14 +775,14 @@ def collate_fn(examples): vae=vae, unet=accelerator.unwrap_model(unet), revision=args.revision, - torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference prompt = [args.validation_prompt] - images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images + with torch.autocast("cuda"): + images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images for i, image in enumerate(images): image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.jpg")) From 4a4821c940ab0bf05758f692359a8268464c409c Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Tue, 31 Jan 2023 20:49:36 -0800 Subject: [PATCH 04/17] Don't unwrap the unet. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a1b63fcfa12c..533caa47a0c4 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -773,7 +773,7 @@ def collate_fn(examples): args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, - unet=accelerator.unwrap_model(unet), + unet=unet, revision=args.revision, ) pipeline = pipeline.to(accelerator.device) From 27c4fc3434b28a6ea589baf69ff026b395c5ae4d Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Tue, 31 Jan 2023 21:18:39 -0800 Subject: [PATCH 05/17] Log to tensorboard and wandb. --- examples/text_to_image/train_text_to_image.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 533caa47a0c4..655325aa811d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -32,6 +32,11 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, create_repo, whoami from packaging import version from torchvision import transforms @@ -365,7 +370,11 @@ def main(): project_config=accelerator_project_config, ) - # Make one log on every process with the configuration for debugging. + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training. You can do so by doing `pip install wandb`") + import wandb # Make one log on every process with the configuration for debugging. + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -786,6 +795,19 @@ def collate_fn(examples): for i, image in enumerate(images): image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.jpg")) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) del pipeline torch.cuda.empty_cache() From ef77d09c4c5c896dfcdf0bd8a3a05a4c9b901be0 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Tue, 31 Jan 2023 21:21:21 -0800 Subject: [PATCH 06/17] Run autoformatter. --- examples/text_to_image/train_text_to_image.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 655325aa811d..8519c4003240 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -372,8 +372,11 @@ def main(): if args.report_to == "wandb": if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training. You can do so by doing `pip install wandb`") - import wandb # Make one log on every process with the configuration for debugging. + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training. You can do so by doing" + " `pip install wandb`" + ) + import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -790,7 +793,7 @@ def collate_fn(examples): # run inference prompt = [args.validation_prompt] - with torch.autocast("cuda"): + with torch.autocast("cuda"): images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images for i, image in enumerate(images): @@ -807,7 +810,7 @@ def collate_fn(examples): for i, image in enumerate(images) ] } - ) + ) del pipeline torch.cuda.empty_cache() From e4b5b4f731f9b99fe9da46916ec36633453e0282 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Thu, 2 Feb 2023 20:05:27 -0800 Subject: [PATCH 07/17] Nits from code review. Use batch size 1 and iterate over num_validation_images to avoid OOM. Set autocast device from accelerator.device. --- examples/text_to_image/train_text_to_image.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8519c4003240..a552d26b7bff 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -793,8 +793,10 @@ def collate_fn(examples): # run inference prompt = [args.validation_prompt] - with torch.autocast("cuda"): - images = pipeline(prompt, num_images_per_prompt=args.num_validation_images).images + images = [] + for _ in range(args.num_validation_images): + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + images.append(pipeline(prompt).images[0]) for i, image in enumerate(images): image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.jpg")) From d0560cdc1a4572d0e05d4ec95f29c0a5bfae63b5 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Thu, 2 Feb 2023 20:11:15 -0800 Subject: [PATCH 08/17] Add final (test) inference. --- examples/text_to_image/train_text_to_image.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a552d26b7bff..6e5f5dcf53b9 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -799,7 +799,7 @@ def collate_fn(examples): images.append(pipeline(prompt).images[0]) for i, image in enumerate(images): - image.save(os.path.join(args.output_dir, f"sample-{global_step}-{i}.jpg")) + image.save(os.path.join(args.output_dir, f"validation-{global_step}-{i}.jpg")) for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) @@ -838,6 +838,34 @@ def collate_fn(examples): ) pipeline.save_pretrained(args.output_dir) + if args.validation_prompt: + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + prompt = [args.validation_prompt] + images = [] + for _ in range(args.num_validation_images): + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + images.append(pipeline(prompt, num_images_per_prompt=args.num_validation_images).images[0]) + + for i, image in enumerate(images): + image.save(os.path.join(args.output_dir, f"test-{i}.jpg")) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) From ddf5bcc9a4aa585fe15c83d8f801c2f6dc164ddb Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Thu, 2 Feb 2023 20:14:02 -0800 Subject: [PATCH 09/17] Control the validation steps separately from the checkpointing steps. --- examples/text_to_image/train_text_to_image.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6e5f5dcf53b9..3a5edc6401e6 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -314,6 +314,14 @@ def parse_args(): default=4, help="Number of images that should be generated during validation with `validation_prompt`.", ) + parser.add_argument( + "--validation_steps", + type=int, + default=500, + help=( + "Sample a validation image every X updates." + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -780,6 +788,8 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if global_step % args.validation_steps == 0: + if accelerator.is_main_process: if args.validation_prompt: pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, From e8dd09d14b0aa6a344dd9be4cbe144fab3b5ef26 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Thu, 2 Feb 2023 20:26:43 -0800 Subject: [PATCH 10/17] Autoformatter. --- examples/text_to_image/train_text_to_image.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 3a5edc6401e6..6a20a9e00db2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -318,9 +318,7 @@ def parse_args(): "--validation_steps", type=int, default=500, - help=( - "Sample a validation image every X updates." - ), + help="Sample a validation image every X updates.", ) args = parser.parse_args() @@ -805,7 +803,9 @@ def collate_fn(examples): prompt = [args.validation_prompt] images = [] for _ in range(args.num_validation_images): - with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + with torch.autocast( + str(accelerator.device), enabled=accelerator.mixed_precision == "fp16" + ): images.append(pipeline(prompt).images[0]) for i, image in enumerate(images): @@ -875,7 +875,6 @@ def collate_fn(examples): } ) - if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) From 9f1da1d6335dc068fd14a515cf9bfa98688efdaf Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Thu, 2 Feb 2023 23:37:52 -0800 Subject: [PATCH 11/17] Hoist the autocast out of the loop. --- examples/text_to_image/train_text_to_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6a20a9e00db2..0e44eab6a5bf 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -802,10 +802,10 @@ def collate_fn(examples): # run inference prompt = [args.validation_prompt] images = [] - for _ in range(args.num_validation_images): - with torch.autocast( - str(accelerator.device), enabled=accelerator.mixed_precision == "fp16" - ): + with torch.autocast( + str(accelerator.device), enabled=accelerator.mixed_precision == "fp16" + ): + for _ in range(args.num_validation_images): images.append(pipeline(prompt).images[0]) for i, image in enumerate(images): From b359a35f65dced834031318b9f1c62f1ed88c291 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Fri, 3 Feb 2023 00:33:13 -0800 Subject: [PATCH 12/17] Apply fixes to end of training sampling. --- examples/text_to_image/train_text_to_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 0e44eab6a5bf..8f70b128229d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -855,9 +855,9 @@ def collate_fn(examples): # run inference prompt = [args.validation_prompt] images = [] - for _ in range(args.num_validation_images): - with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): - images.append(pipeline(prompt, num_images_per_prompt=args.num_validation_images).images[0]) + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + for _ in range(args.num_validation_images): + images.append(pipeline(prompt).images[0]) for i, image in enumerate(images): image.save(os.path.join(args.output_dir, f"test-{i}.jpg")) From bc04c92071ac866f71c765b1bf596b0f44545043 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sat, 18 Feb 2023 10:23:11 -0800 Subject: [PATCH 13/17] Use EMA params for validation images. --- examples/text_to_image/train_text_to_image.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8f70b128229d..022f48becb70 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -789,6 +789,10 @@ def collate_fn(examples): if global_step % args.validation_steps == 0: if accelerator.is_main_process: if args.validation_prompt: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=text_encoder, @@ -823,6 +827,9 @@ def collate_fn(examples): ] } ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) del pipeline torch.cuda.empty_cache() From 50d19e7d70f364823a3fb2bee7833852ecc2b3f2 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sat, 18 Feb 2023 12:06:22 -0800 Subject: [PATCH 14/17] run formatter --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 022f48becb70..6e82f0ee90a0 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -791,7 +791,7 @@ def collate_fn(examples): if args.validation_prompt: if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) + ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, From a719b45c466d2cd1d2523b025a425d0bd845605c Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sat, 18 Feb 2023 12:12:22 -0800 Subject: [PATCH 15/17] Fix import issue caused by bad merge --- examples/text_to_image/train_text_to_image.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6e82f0ee90a0..bfcea8afe4f3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -32,11 +32,6 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, create_repo, whoami from packaging import version from torchvision import transforms From 5b536b0926b16fe1de979a43209ef159096f9b4a Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sat, 18 Feb 2023 12:14:20 -0800 Subject: [PATCH 16/17] More fixes --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index bfcea8afe4f3..d93a94ac73e5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -43,7 +43,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate -from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.import_utils import is_xformers_available, is_wandb_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. From a3300baed04bf0e8beb81d3648c69f758c2e6710 Mon Sep 17 00:00:00 2001 From: Lucas Sloan Date: Sat, 18 Feb 2023 12:19:18 -0800 Subject: [PATCH 17/17] More fixes --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d93a94ac73e5..e32656abe569 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -43,7 +43,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, deprecate -from diffusers.utils.import_utils import is_xformers_available, is_wandb_available +from diffusers.utils.import_utils import is_wandb_available, is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks.