From 7a295feee28ff9a00d4020ee4214ba8f3eb04887 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 28 Dec 2022 15:13:03 +0100
Subject: [PATCH 1/3] update TI script

---
 .../textual_inversion/textual_inversion.py | 44 +++++++++----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 74fcf71cb22c..5633cfd223f0 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -147,6 +147,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -383,11 +388,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
         return f"{organization}/{model_id}"
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -460,6 +460,10 @@ def main():
         revision=args.revision,
     )
 
+    if args.gradient_checkpointing:
+        text_encoder.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
@@ -474,15 +478,12 @@ def main():
     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
 
     # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -541,9 +542,10 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep vae and unet in eval model as we don't train these
-    vae.eval()
-    unet.eval()
+    # Keep unet in train model if we are using gradient checkpointing to save memory.
+    # The dropout is 0 so it doesn't matter if we are in eval or train mode.
+    if args.gradient_checkpointing:
+        unet.train()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -609,12 +611,11 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -669,8 +670,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        accelerator.wait_for_everyone()
-
+    accelerator.wait_for_everyone()
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         if args.push_to_hub and args.only_save_embeds:

From 8db26b586457cf8f908abf47308a8abf9fd10b0c Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 28 Dec 2022 15:19:38 +0100
Subject: [PATCH 2/3] make flake happy

---
 examples/textual_inversion/textual_inversion.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 5633cfd223f0..ac57a3708063 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import random

From 816209ae4f3340ad9a39301557b74f98e559fb6e Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 28 Dec 2022 16:57:29 +0100
Subject: [PATCH 3/3] fix typ

---
 examples/textual_inversion/textual_inversion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index ac57a3708063..467e710222de 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -541,7 +541,7 @@ def main():
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    # Keep unet in train model if we are using gradient checkpointing to save memory.
+    # Keep unet in train mode if we are using gradient checkpointing to save memory.
     # The dropout is 0 so it doesn't matter if we are in eval or train mode.
     if args.gradient_checkpointing:
         unet.train()
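
Editor's note: the pattern this series moves the script to can be sketched in isolation, outside the patch. The snippet below is a minimal illustration under stated assumptions and is not part of the diff: the checkpoint id, the top-level layout, and the gradient_checkpointing flag are placeholders chosen for the example, while requires_grad_, gradient_checkpointing_enable, enable_gradient_checkpointing, and torch.randn_like are the calls the diff itself introduces.

# Minimal sketch of the pattern adopted by PATCH 1/3 (illustrative, not the training script).
# Assumption: diffusers and transformers are installed; the checkpoint id below is a
# placeholder for the example, not one named by the patch.
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

pretrained = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint
text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")

# Freeze whole modules with one requires_grad_(False) call each, replacing the
# removed freeze_params() helper that looped over .parameters().
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

# Optional memory saving via gradient checkpointing. The unet then stays in train()
# mode so checkpointing is active during the forward pass; per the patch comment,
# its dropout is 0, so train() vs eval() does not change the outputs.
gradient_checkpointing = True  # stands in for args.gradient_checkpointing
if gradient_checkpointing:
    text_encoder.gradient_checkpointing_enable()
    unet.enable_gradient_checkpointing()
    unet.train()

# Noise sampling: torch.randn_like(latents) inherits device and dtype from the
# latents, replacing the explicit .to(device).to(dtype=...) chain.
latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(latents)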