-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[WIP] Sample images when checkpointing. #2157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6aa6402
09b8e21
c291cce
4a4821c
27c4fc3
ef77d09
e4b5b4f
d0560cd
ddf5bcc
e8dd09d
9f1da1d
b359a35
bc04c92
50d19e7
a719b45
5b536b0
a3300ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -43,7 +43,7 @@ | |||
| from diffusers.optimization import get_scheduler | ||||
| from diffusers.training_utils import EMAModel | ||||
| from diffusers.utils import check_min_version, deprecate | ||||
| from diffusers.utils.import_utils import is_xformers_available | ||||
| from diffusers.utils.import_utils import is_wandb_available, is_xformers_available | ||||
|
|
||||
|
|
||||
| # Will error if the minimal version of diffusers is not installed. Remove at your own risks. | ||||
|
|
@@ -297,6 +297,24 @@ def parse_args(): | |||
| parser.add_argument( | ||||
| "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." | ||||
| ) | ||||
| parser.add_argument( | ||||
| "--validation_prompt", | ||||
| type=str, | ||||
| default=None, | ||||
| help="A prompt that is used during validation to verify that the model is learning.", | ||||
| ) | ||||
| parser.add_argument( | ||||
| "--num_validation_images", | ||||
| type=int, | ||||
| default=4, | ||||
| help="Number of images that should be generated during validation with `validation_prompt`.", | ||||
| ) | ||||
| parser.add_argument( | ||||
| "--validation_steps", | ||||
| type=int, | ||||
| default=500, | ||||
| help="Sample a validation image every X updates.", | ||||
| ) | ||||
|
|
||||
| args = parser.parse_args() | ||||
| env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||||
|
|
@@ -353,7 +371,14 @@ def main(): | |||
| project_config=accelerator_project_config, | ||||
| ) | ||||
|
|
||||
| # Make one log on every process with the configuration for debugging. | ||||
| if args.report_to == "wandb": | ||||
| if not is_wandb_available(): | ||||
| raise ImportError( | ||||
| "Make sure to install wandb if you want to use it for logging during training. You can do so by doing" | ||||
| " `pip install wandb`" | ||||
| ) | ||||
| import wandb # Make one log on every process with the configuration for debugging. | ||||
|
|
||||
| logging.basicConfig( | ||||
| format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | ||||
| datefmt="%m/%d/%Y %H:%M:%S", | ||||
|
|
@@ -756,6 +781,53 @@ def collate_fn(examples): | |||
| accelerator.save_state(save_path) | ||||
| logger.info(f"Saved state to {save_path}") | ||||
|
|
||||
| if global_step % args.validation_steps == 0: | ||||
| if accelerator.is_main_process: | ||||
| if args.validation_prompt: | ||||
| if args.use_ema: | ||||
| # Store the UNet parameters temporarily and load the EMA parameters to perform inference. | ||||
| ema_unet.store(unet.parameters()) | ||||
| ema_unet.copy_to(unet.parameters()) | ||||
| pipeline = StableDiffusionPipeline.from_pretrained( | ||||
| args.pretrained_model_name_or_path, | ||||
| text_encoder=text_encoder, | ||||
| vae=vae, | ||||
| unet=unet, | ||||
| revision=args.revision, | ||||
| ) | ||||
| pipeline = pipeline.to(accelerator.device) | ||||
| pipeline.set_progress_bar_config(disable=True) | ||||
|
|
||||
| # run inference | ||||
| prompt = [args.validation_prompt] | ||||
| images = [] | ||||
| with torch.autocast( | ||||
| str(accelerator.device), enabled=accelerator.mixed_precision == "fp16" | ||||
| ): | ||||
| for _ in range(args.num_validation_images): | ||||
| images.append(pipeline(prompt).images[0]) | ||||
|
|
||||
| for i, image in enumerate(images): | ||||
| image.save(os.path.join(args.output_dir, f"validation-{global_step}-{i}.jpg")) | ||||
| for tracker in accelerator.trackers: | ||||
| if tracker.name == "tensorboard": | ||||
| np_images = np.stack([np.asarray(img) for img in images]) | ||||
| tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") | ||||
| if tracker.name == "wandb": | ||||
| tracker.log( | ||||
| { | ||||
| "validation": [ | ||||
| wandb.Image(image, caption=f"{i}: {args.validation_prompt}") | ||||
| for i, image in enumerate(images) | ||||
| ] | ||||
| } | ||||
| ) | ||||
| if args.use_ema: | ||||
| # Switch back to the original UNet parameters. | ||||
| ema_unet.restore(unet.parameters()) | ||||
| del pipeline | ||||
| torch.cuda.empty_cache() | ||||
|
|
||||
| logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we also want to add another final inference logging as done in the
|
||||
| progress_bar.set_postfix(**logs) | ||||
|
|
||||
|
|
@@ -778,6 +850,33 @@ def collate_fn(examples): | |||
| ) | ||||
| pipeline.save_pretrained(args.output_dir) | ||||
|
|
||||
| if args.validation_prompt: | ||||
| pipeline = pipeline.to(accelerator.device) | ||||
| pipeline.set_progress_bar_config(disable=True) | ||||
|
|
||||
| # run inference | ||||
| prompt = [args.validation_prompt] | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this required? We can simply pass the validation prompt, no? |
||||
| images = [] | ||||
| with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): | ||||
| for _ in range(args.num_validation_images): | ||||
| images.append(pipeline(prompt).images[0]) | ||||
|
|
||||
| for i, image in enumerate(images): | ||||
| image.save(os.path.join(args.output_dir, f"test-{i}.jpg")) | ||||
| for tracker in accelerator.trackers: | ||||
| if tracker.name == "tensorboard": | ||||
| np_images = np.stack([np.asarray(img) for img in images]) | ||||
| tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") | ||||
| if tracker.name == "wandb": | ||||
| tracker.log( | ||||
| { | ||||
| "test": [ | ||||
| wandb.Image(image, caption=f"{i}: {args.validation_prompt}") | ||||
| for i, image in enumerate(images) | ||||
| ] | ||||
| } | ||||
| ) | ||||
|
|
||||
| if args.push_to_hub: | ||||
| repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) | ||||
|
|
||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add
`validation_epochs` and generate according to that instead of generating after each loop. You could refer to this script to see how to do that:
diffusers/examples/dreambooth/train_dreambooth_lora.py
Line 926 in 9213d81
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And this should be wrapped under the main process condition
(`if accelerator.is_main_process:`) to handle situations for multi-GPU training.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm currently using the args.checkpointing_steps. Do we have a preference for # of epochs vs. # of global steps? I slightly favor # of global steps, since that's how we're controlling checkpointing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, for validation images, we prefer epochs since conceptually it's a bit simpler to think of when the inference is going to take place.