diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 74fcf71cb22c..467e710222de 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import random
@@ -147,6 +146,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +387,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +459,10 @@ def main():
         revision=args.revision,
     )
 
+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
@@ -474,15 +477,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
 
     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -541,9 +541,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -609,12 +610,11 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -669,8 +669,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
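A minimal sketch (not part of the diff) of the freezing pattern the change switches to: `nn.Module.requires_grad_(False)` flips `requires_grad` on every parameter of a submodule in place, which is what the removed `freeze_params()` loop did with explicit iteration. The toy `text_model` below is a hypothetical stand-in for the CLIP text encoder; its layer names and sizes are made up for illustration, and only the token embedding is left trainable.

```python
from torch import nn

# Toy stand-in for the CLIP text encoder; names and sizes are illustrative only.
text_model = nn.ModuleDict(
    {
        "token_embedding": nn.Embedding(49408, 768),  # stays trainable
        "encoder": nn.Linear(768, 768),               # frozen below
        "final_layer_norm": nn.LayerNorm(768),        # frozen below
    }
)

# requires_grad_(False) recursively sets requires_grad on all parameters of the
# submodule, replacing `for param in params: param.requires_grad = False`.
text_model["encoder"].requires_grad_(False)
text_model["final_layer_norm"].requires_grad_(False)

trainable = [name for name, p in text_model.named_parameters() if p.requires_grad]
print(trainable)  # ['token_embedding.weight']
```

On the gradient-checkpointing side, the diff keeps the (frozen) UNet in `train()` mode when `--gradient_checkpointing` is set because the UNet blocks only take the checkpointing path while the module is in training mode; as the added comment notes, dropout is 0, so train vs. eval does not change the computation.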