add: logging to text2image. #2173
Changes from all commits: 938fa28, 69af6ce, 67c6c56, ef2abd8, 6582c80, 26bc0d7, 5a33617, 37a35c1, f2a143f, a0e844d
```diff
@@ -35,8 +35,9 @@
 from datasets import load_dataset
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from torchvision import transforms
```
```diff
@@ -101,6 +102,24 @@ def parse_args():
         default="text",
         help="The column of the dataset containing a caption or a list of captions.",
     )
+    parser.add_argument(
+        "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+    )
+    parser.add_argument(
+        "--num_validation_images",
+        type=int,
+        default=4,
+        help="Number of images that should be generated during validation with `validation_prompt`.",
+    )
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=1,
+        help=(
+            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`."
+        ),
+    )
     parser.add_argument(
         "--max_train_samples",
         type=int,
```
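For context, here is a minimal, self-contained sketch (not part of the PR) of how the three new flags interact: validation only fires when `--validation_prompt` is set and the current epoch is a multiple of `--validation_epochs`, and it then generates `--num_validation_images` images. The argv values below are made up for illustration.

```python
import argparse

# Standalone illustration of the new flags' semantics; the argument names mirror
# the diff, but the prompt and values here are made-up examples.
parser = argparse.ArgumentParser()
parser.add_argument("--validation_prompt", type=str, default=None)
parser.add_argument("--num_validation_images", type=int, default=4)
parser.add_argument("--validation_epochs", type=int, default=1)
args = parser.parse_args(
    ["--validation_prompt", "a photo of an astronaut", "--num_validation_images", "2", "--validation_epochs", "5"]
)

for epoch in range(12):
    # Same gate as the training-loop code added further down in this diff.
    if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
        print(f"epoch {epoch}: would log {args.num_validation_images} images for {args.validation_prompt!r}")
```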
```diff
@@ -328,6 +347,11 @@ def main():
         logging_dir=logging_dir,
     )
 
+    if args.report_to == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+        import wandb
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
```
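As a rough sketch of where the trackers used by the new validation code come from: the training script is expected to pass `--report_to` into `Accelerator(log_with=...)` and call `init_trackers`, as the existing examples do. The project name and config below are illustrative, not taken from this PR.

```python
from accelerate import Accelerator

# Illustrative wiring only: "wandb" could equally be "tensorboard", mirroring --report_to.
accelerator = Accelerator(log_with="wandb")
if accelerator.is_main_process:
    accelerator.init_trackers("text2image-fine-tune", config={"validation_epochs": 1})

# Each tracker exposes a `name` attribute ("tensorboard", "wandb", ...), which is what
# the diff checks before deciding how to log the generated images.
for tracker in accelerator.trackers:
    print(tracker.name)

accelerator.end_training()
```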
```diff
@@ -697,6 +721,51 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break
 
+        if accelerator.is_main_process:
+            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+                logger.info(
+                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+                    f" {args.validation_prompt}."
+                )
+                # create pipeline
+                safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision
+                )
```
Contributor, on lines +731 to +733: "Also, I'm not sure what difference it makes when we load …"
```diff
+                # safety_checker.to(accelerator.device, dtype=weight_dtype)
+                pipeline = StableDiffusionPipeline.from_pretrained(
+                    args.pretrained_model_name_or_path,
+                    unet=accelerator.unwrap_model(unet),
```
Contributor: "When doing …"

Contributor: "… to do that, we'll need to … For that we'll need to add the … Happy to take care of this if you want :)"
```diff
+                    safety_checker=safety_checker,
+                    revision=args.revision,
```
Contributor, on lines +738 to +739: "We could also directly pass …"
```diff
+                )
+                pipeline = pipeline.to(accelerator.device)
+                pipeline.set_progress_bar_config(disable=True)
+
+                # run inference
+                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                images = []
+                for _ in range(args.num_validation_images):
+                    images.append(
+                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                    )
+
+                for tracker in accelerator.trackers:
+                    if tracker.name == "tensorboard":
+                        np_images = np.stack([np.asarray(img) for img in images])
+                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+                    if tracker.name == "wandb":
+                        tracker.log(
+                            {
+                                "validation": [
+                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                                    for i, image in enumerate(images)
+                                ]
+                            }
+                        )
+                del safety_checker
+                del pipeline
+                torch.cuda.empty_cache()
+
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
```
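A small side note on the TensorBoard call above: stacking PIL images with `np.asarray` produces a batch shaped (N, H, W, C), which is why the diff passes `dataformats="NHWC"` instead of relying on the default (N, C, H, W). A minimal sketch with dummy images (standing in for the pipeline outputs) makes this concrete.

```python
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

# Dummy images in place of pipeline outputs; real validation images are PIL images too.
images = [Image.new("RGB", (64, 64), color=(i * 60, 0, 0)) for i in range(4)]
np_images = np.stack([np.asarray(img) for img in images])
print(np_images.shape)  # (4, 64, 64, 3) -> batch is NHWC, not NCHW

writer = SummaryWriter(log_dir="logs/validation-demo")  # illustrative log dir
writer.add_images("validation", np_images, 0, dataformats="NHWC")
writer.close()
```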
```diff
@@ -716,6 +785,27 @@ def collate_fn(examples):
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
+        if args.validation_prompt is not None:
+            # run inference
+            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+            images = []
+            for _ in range(args.num_validation_images):
+                images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+            for tracker in accelerator.trackers:
+                if tracker.name == "tensorboard":
+                    np_images = np.stack([np.asarray(img) for img in images])
+                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+                if tracker.name == "wandb":
+                    tracker.log(
+                        {
+                            "test": [
+                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                                for i, image in enumerate(images)
+                            ]
+                        }
+                    )
+
     accelerator.end_training()
```
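After training finishes, the same kind of test images can be reproduced from the saved pipeline. A minimal sketch, assuming the script saved the pipeline to `--output_dir` as the existing example does; the path, prompt, and seed below are placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder path/prompt/seed: substitute your actual --output_dir, --validation_prompt, and --seed.
pipeline = StableDiffusionPipeline.from_pretrained("path/to/output_dir")
pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator(device=pipeline.device).manual_seed(42)
image = pipeline("a photo of an astronaut", num_inference_steps=30, generator=generator).images[0]
image.save("test_sample.png")
```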