From 938fa2805591998e9f5a9fa22296638fe6182f36 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 15:04:35 +0530 Subject: [PATCH 01/10] add: logging to text2image. --- examples/text_to_image/train_text_to_image.py | 86 ++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 7f6ddeaee135..986a2d151605 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -36,7 +36,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, create_repo, whoami from torchvision import transforms @@ -101,6 +101,24 @@ def parse_args(): default="text", help="The column of the dataset containing a caption or a list of captions.", ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) parser.add_argument( "--max_train_samples", type=int, @@ -328,6 +346,11 @@ def main(): logging_dir=logging_dir, ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -697,6 +720,45 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -716,6 +778,28 @@ def collate_fn(examples): if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + if args.validation_prompt is not None: + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + + if accelerator.is_main_process: + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + accelerator.end_training() From 69af6cecd74131fe963524df0319dfbded869f05 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 15:09:40 +0530 Subject: [PATCH 02/10] add: autocast block. --- examples/text_to_image/train_text_to_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 986a2d151605..a35b4868889d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -739,7 +739,8 @@ def collate_fn(examples): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.autocast(device_type="cuda", dtype=weight_dtype): + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) if accelerator.is_main_process: for tracker in accelerator.trackers: From 67c6c568e32c79da856a054e572300cbd447fffe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 15:23:36 +0530 Subject: [PATCH 03/10] apply make style/. --- examples/text_to_image/train_text_to_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a35b4868889d..6bf89583c792 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -740,7 +740,9 @@ def collate_fn(examples): images = [] for _ in range(args.num_validation_images): with torch.autocast(device_type="cuda", dtype=weight_dtype): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) if accelerator.is_main_process: for tracker in accelerator.trackers: From ef2abd8e912452a24d7d2b00a9c85de41cf86f64 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 15:43:35 +0530 Subject: [PATCH 04/10] disable unwrapping. --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6bf89583c792..4d9f402c32e4 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -728,9 +728,9 @@ def collate_fn(examples): # create pipeline pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), + unet=unet, revision=args.revision, - torch_dtype=weight_dtype, + # torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) From 6582c802b6394f0bf0f5a77e77c2d81ffb0bf19a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 15:54:12 +0530 Subject: [PATCH 05/10] autocast context manager before final inference. --- examples/text_to_image/train_text_to_image.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4d9f402c32e4..f358260a45b8 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -738,8 +738,8 @@ def collate_fn(examples): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - with torch.autocast(device_type="cuda", dtype=weight_dtype): + with torch.autocast(device_type="cuda", dtype=weight_dtype): + for _ in range(args.num_validation_images): images.append( pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] ) @@ -785,8 +785,11 @@ def collate_fn(examples): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + with torch.autocast(device_type="cuda", dtype=weight_dtype): + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) if accelerator.is_main_process: for tracker in accelerator.trackers: From 26bc0d79a5a38d5732658a282c8492aba1245d24 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 16:06:30 +0530 Subject: [PATCH 06/10] remove autocasts. --- examples/text_to_image/train_text_to_image.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index f358260a45b8..33ba82121bd2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -728,9 +728,8 @@ def collate_fn(examples): # create pipeline pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=unet, + unet=accelerator.unwrap_model(unet), revision=args.revision, - # torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) @@ -738,11 +737,10 @@ def collate_fn(examples): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - with torch.autocast(device_type="cuda", dtype=weight_dtype): - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) if accelerator.is_main_process: for tracker in accelerator.trackers: @@ -785,11 +783,10 @@ def collate_fn(examples): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - with torch.autocast(device_type="cuda", dtype=weight_dtype): - for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) if accelerator.is_main_process: for tracker in accelerator.trackers: From 5a33617da9f5171e9b9ca83fa102d8c7c3ed07b2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 16:10:20 +0530 Subject: [PATCH 07/10] make style. --- examples/text_to_image/train_text_to_image.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 33ba82121bd2..eb1cd083e07e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -738,9 +738,7 @@ def collate_fn(examples): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) if accelerator.is_main_process: for tracker in accelerator.trackers: @@ -784,9 +782,7 @@ def collate_fn(examples): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): - images.append( - pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] - ) + images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) if accelerator.is_main_process: for tracker in accelerator.trackers: From 37a35c11b41a2c78332997eeabe441cfc8a6a36b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 17:12:24 +0530 Subject: [PATCH 08/10] add: safety checker. --- examples/text_to_image/train_text_to_image.py | 77 ++++++++++--------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index eb1cd083e07e..eb4f3b1c5f99 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -35,6 +35,7 @@ from datasets import load_dataset from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -720,27 +721,34 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # create pipeline - pipeline = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - revision=args.revision, - ) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained( + args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_version + ) + pt_safety_checker.to(accelerator.device, dtype=weight_dtype) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + safety_checker=pt_safety_checker, + revision=args.revision, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - images = [] - for _ in range(args.num_validation_images): - images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + images.append( + pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] + ) - if accelerator.is_main_process: for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) @@ -755,8 +763,8 @@ def collate_fn(examples): } ) - del pipeline - torch.cuda.empty_cache() + del pipeline + torch.cuda.empty_cache() # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() @@ -784,20 +792,19 @@ def collate_fn(examples): for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) - if accelerator.is_main_process: - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) accelerator.end_training() From f2a143fc362040778008730a661a294ebc20ac1c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 17:31:25 +0530 Subject: [PATCH 09/10] fix: cli arg. --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index eb4f3b1c5f99..5c44fde214f9 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -729,7 +729,7 @@ def collate_fn(examples): ) # create pipeline pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained( - args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_version + args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision ) pt_safety_checker.to(accelerator.device, dtype=weight_dtype) pipeline = StableDiffusionPipeline.from_pretrained( From a0e844d045f9c37c7276109640d8367b033484db Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 31 Jan 2023 17:43:35 +0530 Subject: [PATCH 10/10] disable casting for safety checker. --- examples/text_to_image/train_text_to_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 5c44fde214f9..d7bad9d8adc7 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -728,14 +728,14 @@ def collate_fn(examples): f" {args.validation_prompt}." ) # create pipeline - pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained( + safety_checker = StableDiffusionSafetyChecker.from_pretrained( args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision ) - pt_safety_checker.to(accelerator.device, dtype=weight_dtype) + # safety_checker.to(accelerator.device, dtype=weight_dtype) pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - safety_checker=pt_safety_checker, + safety_checker=safety_checker, revision=args.revision, ) pipeline = pipeline.to(accelerator.device) @@ -762,7 +762,7 @@ def collate_fn(examples): ] } ) - + del safety_checker del pipeline torch.cuda.empty_cache()