From dba0b97c8914a1132b9e07160c6334c5eed87245 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 12 Dec 2023 02:05:18 -0800 Subject: [PATCH 1/4] Clean up comments in LCM(-LoRA) distillation scripts. --- .../train_lcm_distill_lora_sd_wds.py | 63 ++++++++++------ .../train_lcm_distill_lora_sdxl_wds.py | 52 +++++++++----- .../train_lcm_distill_sd_wds.py | 72 +++++++++++-------- .../train_lcm_distill_sdxl_wds.py | 58 +++++++++------ 4 files changed, 152 insertions(+), 93 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index c96733f0425e..fd86c6eaeb4e 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -156,7 +156,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -835,16 +835,17 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-XL checkpoint. + # 2. Load tokenizers from SD-1.5 checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) @@ -855,14 +856,14 @@ def main(args): args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + # 4. Load VAE from SD-1.5 checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-XL checkpoint + # 5. Load teacher U-Net from SD-1.5 checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -872,7 +873,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Nets. + # 7. Create online (`unet`) student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -935,6 +936,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1011,13 +1013,14 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + # 13. Dataset creation and data processing # Here, we compute not just the text embeddings but also the additional embeddings # needed for the SD XL UNet to operate. 
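+    # For this SD (non-XL) script only the CLIP text encoder prompt embeddings are required, so
+    # compute_embeddings below simply returns {"prompt_embeds": ...} with no pooled or micro-conditioning embeddings.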
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) return {"prompt_embeds": prompt_embeds} - dataset = Text2ImageDataset( + dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1037,6 +1040,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok tokenizer=tokenizer, ) + # 14. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) @@ -1051,6 +1055,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok num_training_steps=args.max_train_steps, ) + # 15. Prepare for training # Prepare everything with our `accelerator`. unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) @@ -1072,7 +1077,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] - # Train! + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") @@ -1123,6 +1128,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning image, text = batch image = image.to(accelerator.device, non_blocking=True) @@ -1141,36 +1147,38 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - # Sample noise that we'll add to the latents + # 2. Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 4. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # 5. 
Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 6. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 7. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1190,11 +1198,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1209,7 +1219,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1224,12 +1234,19 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation + # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly + # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 10. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = unet( @@ -1248,7 +1265,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 11. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1256,7 +1273,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 12. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 2ecd6f43dcde..f28e190e18a5 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -162,7 +162,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDXLText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -830,9 +830,10 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, @@ -886,7 +887,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Nets. + # 7. Create online (`unet`) student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -950,6 +951,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1057,7 +1059,7 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - dataset = Text2ImageDataset( + dataset = SDXLText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1175,6 +1177,7 @@ def compute_embeddings( for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. 
Load and process the image, text, and micro-conditioning (original image size, crop coordinates) image, text, orig_size, crop_coords = batch image = image.to(accelerator.device, non_blocking=True) @@ -1197,36 +1200,38 @@ def compute_embeddings( if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - # Sample noise that we'll add to the latents + # 2. Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 4. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # 5. Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 6. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 7. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1246,11 +1251,13 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1266,7 +1273,7 @@ def compute_embeddings( sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = teacher_unet( @@ -1284,12 +1291,19 @@ def compute_embeddings( sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation + # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly + # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): with torch.autocast("cuda", enabled=True, dtype=weight_dtype): target_noise_pred = unet( @@ -1309,7 +1323,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 11. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1317,7 +1331,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 12. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 1dfac0464271..b685ce3be967 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -138,7 +138,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -823,16 +823,17 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. 
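+    # DDIMSolver (defined above) precomputes the skipped DDIM timestep schedule and the corresponding
+    # alphas_cumprod values, and its ddim_step method performs one deterministic DDIM update from a
+    # predicted x_0 and predicted noise (see the comments in the training loop below).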
solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-XL checkpoint. + # 2. Load tokenizers from SD-1.5 checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) @@ -843,14 +844,14 @@ def main(args): args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + # 4. Load VAE from SD-1.5 checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-XL checkpoint + # 5. Load teacher U-Net from SD-1.5 checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -860,7 +861,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. Create online (`unet`) student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -869,8 +870,8 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging). - # Initialize from unet + # 8. Create target (`target_unet`) student U-Net. This will be updated via EMA updates (polyak averaging). + # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() @@ -887,7 +888,7 @@ def main(args): f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) - # 10. Handle mixed precision and device placement + # 9. Handle mixed precision and device placement # For mixed precision training we cast all non-trainable weigths to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 @@ -914,7 +915,7 @@ def main(args): sigma_schedule = sigma_schedule.to(accelerator.device) solver = solver.to(accelerator.device) - # 11. Handle saving and loading of checkpoints + # 10. Handle saving and loading of checkpoints # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -948,7 +949,7 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - # 12. Enable optimizations + # 11. Enable optimizations if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers @@ -994,13 +995,14 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + # 13. 
Dataset creation and data processing # Here, we compute not just the text embeddings but also the additional embeddings # needed for the SD XL UNet to operate. def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) return {"prompt_embeds": prompt_embeds} - dataset = Text2ImageDataset( + dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1020,6 +1022,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok tokenizer=tokenizer, ) + # 14. LR Scheduler creation # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) @@ -1034,6 +1037,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok num_training_steps=args.max_train_steps, ) + # 15. Prepare for training # Prepare everything with our `accelerator`. unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler) @@ -1055,7 +1059,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ).input_ids.to(accelerator.device) uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] - # Train! + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") @@ -1106,6 +1110,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning image, text = batch image = image.to(accelerator.device, non_blocking=True) @@ -1124,28 +1129,29 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - # Sample noise that we'll add to the latents + # 2. Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 4. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # 5. 
Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 6. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) @@ -1153,10 +1159,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok w = w.to(device=latents.device, dtype=latents.dtype) w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 7. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1176,11 +1182,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1195,7 +1203,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1210,12 +1218,18 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation + # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly + # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. 
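+                        # Concretely, ddim_step computes x_prev = sqrt(alpha_prod_prev) * pred_x0 + sqrt(1 - alpha_prod_prev) * pred_noise,
+                        # where alpha_prod_prev is alphas_cumprod evaluated at the previous timestep t_n on the DDIM schedule.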
x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1234,7 +1248,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 11. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1242,7 +1256,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 12. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1252,7 +1266,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 20.4.15. Make EMA update to target student model parameters + # 13. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1 diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 952bec67d148..b519f3a1bd71 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -144,7 +144,7 @@ def __call__(self, x): return False -class Text2ImageDataset: +class SDXLText2ImageDataset: def __init__( self, train_shards_path_or_url: Union[str, List[str]], @@ -863,9 +863,10 @@ def main(args): args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - # The scheduler calculates the alpha and sigma schedule for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. solver = DDIMSolver( noise_scheduler.alphas_cumprod.numpy(), timesteps=noise_scheduler.config.num_train_timesteps, @@ -919,7 +920,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. Create online (`unet`) student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -928,8 +929,8 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging). - # Initialize from unet + # 8. Create target (`target_unet`) student U-Net. This will be updated via EMA updates (polyak averaging). 
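+    # After each optimizer step, update_ema sets target_param <- ema_decay * target_param + (1 - ema_decay) * online_param.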
+ # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() @@ -971,6 +972,7 @@ def main(args): # Also move the alpha and sigma noise schedules to accelerator.device. alpha_schedule = alpha_schedule.to(accelerator.device) sigma_schedule = sigma_schedule.to(accelerator.device) + # Move the ODE solver to accelerator.device. solver = solver.to(accelerator.device) # 10. Handle saving and loading of checkpoints @@ -1084,7 +1086,7 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - dataset = Text2ImageDataset( + dataset = SDXLText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, per_gpu_batch_size=args.train_batch_size, @@ -1202,6 +1204,7 @@ def compute_embeddings( for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) image, text, orig_size, crop_coords = batch image = image.to(accelerator.device, non_blocking=True) @@ -1224,37 +1227,40 @@ def compute_embeddings( if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - # Sample noise that we'll add to the latents + # 2. Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] - # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. + # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() start_timesteps = solver.ddim_timesteps[index] timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps. + # 4. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep + # 5. Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 6. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype w = w.to(device=latents.device, dtype=latents.dtype) + w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 20.4.8. Prepare prompt embeds and unet_added_conditions + # 7. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 20.4.9. 
Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1274,11 +1280,13 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after - # noisy_latents with both the conditioning embedding c and unconditional embedding 0 - # Get teacher model prediction on noisy_latents and conditional embedding + # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. with torch.no_grad(): with torch.autocast("cuda"): + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c cond_teacher_output = teacher_unet( noisy_model_input.to(weight_dtype), start_timesteps, @@ -1294,7 +1302,7 @@ def compute_embeddings( sigma_schedule, ) - # Get teacher model prediction on noisy_latents and unconditional embedding + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds uncond_teacher_output = teacher_unet( @@ -1312,12 +1320,18 @@ def compute_embeddings( sigma_schedule, ) - # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation) + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation + # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly + # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n + # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1337,7 +1351,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 20.4.13. Calculate loss + # 11. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1345,7 +1359,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 20.4.14. Backpropagate on the online student model (`unet`) + # 12. 
Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1355,7 +1369,7 @@ def compute_embeddings( # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 20.4.15. Make EMA update to target student model parameters + # 13. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1 From a63db273885b22007b4f0996edfee07ef5fab1fa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 12 Dec 2023 02:39:56 -0800 Subject: [PATCH 2/4] Calculate predicted source noise noise_pred correctly for all prediction_types. --- .../train_lcm_distill_lora_sd_wds.py | 52 ++++++++++++++++--- .../train_lcm_distill_lora_sdxl_wds.py | 52 ++++++++++++++++--- .../train_lcm_distill_sd_wds.py | 52 ++++++++++++++++--- .../train_lcm_distill_sdxl_wds.py | 52 ++++++++++++++++--- 4 files changed, 180 insertions(+), 28 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index fd86c6eaeb4e..5e20ff618580 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -360,18 +360,42 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." 
+ ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -1218,6 +1242,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) + cond_pred_noise = predicted_source_noise( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( @@ -1233,13 +1265,19 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) + uncond_pred_noise = predicted_source_noise( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation - # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly - # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) # 4. Run one step of the ODE solver to estimate the next point x_prev on the # augmented PF-ODE trajectory (solving backward in time) # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index f28e190e18a5..2511ff512dbb 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -347,18 +347,42 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." 
+ ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -1272,6 +1296,14 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) + cond_pred_noise = predicted_source_noise( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) @@ -1290,13 +1322,19 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) + uncond_pred_noise = predicted_source_noise( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation - # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly - # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) # 4. Run one step of the ODE solver to estimate the next point x_prev on the # augmented PF-ODE trajectory (solving backward in time) # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. 
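+                        # Using the converted pred_noise (rather than the raw teacher output) keeps this DDIM step correct
+                        # when the teacher's prediction_type is "sample" or "v_prediction" instead of "epsilon".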
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index b685ce3be967..31170620ea6d 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -337,18 +337,42 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -1202,6 +1226,14 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) + cond_pred_noise = predicted_source_noise( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_teacher_output = teacher_unet( @@ -1217,13 +1249,19 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) + uncond_pred_noise = predicted_source_noise( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation - # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly - # use the output of teacher_unet. May want to fix at some point (e.g. following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) # 4. 
Run one step of the ODE solver to estimate the next point x_prev on the # augmented PF-ODE trajectory (solving backward in time) # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index b519f3a1bd71..bf92846414ad 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -325,18 +325,42 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": - sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) - alphas = extract_into_tensor(alphas, timesteps, sample.shape) pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output elif prediction_type == "v_prediction": - pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output + pred_x_0 = alphas * sample - sigmas * model_output else: - raise ValueError(f"Prediction type {prediction_type} currently not supported.") + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) return pred_x_0 +# Based on step 4 in DDIMScheduler.step +def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) @@ -1301,6 +1325,14 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) + cond_pred_noise = predicted_source_noise( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 uncond_added_conditions = copy.deepcopy(encoded_text) @@ -1319,13 +1351,19 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) + uncond_pred_noise = predicted_source_noise( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation - # NOTE: this currently assumes that the teacher prediction_type is "epsilon", since we directly - # use the output of teacher_unet. May want to fix at some point (e.g. 
following DDIMScheduler) pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) - pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) # 4. Run one step of the ODE solver to estimate the next point x_prev on the # augmented PF-ODE trajectory (solving backward in time) # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. From f89853574389f6a0bc39c4209dccaf8901f054cf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 12 Dec 2023 02:40:39 -0800 Subject: [PATCH 3/4] make style --- .../consistency_distillation/train_lcm_distill_lora_sd_wds.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sd_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sdxl_wds.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 5e20ff618580..ccd027e7ccbf 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -392,7 +392,7 @@ def predicted_source_noise(model_output, timesteps, sample, prediction_type, alp f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" f" are supported." ) - + return pred_epsilon diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 2511ff512dbb..bcec116decde 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -379,7 +379,7 @@ def predicted_source_noise(model_output, timesteps, sample, prediction_type, alp f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" f" are supported." ) - + return pred_epsilon diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 31170620ea6d..7085f51792fb 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -369,7 +369,7 @@ def predicted_source_noise(model_output, timesteps, sample, prediction_type, alp f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" f" are supported." ) - + return pred_epsilon diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index bf92846414ad..902adb38fd95 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -357,7 +357,7 @@ def predicted_source_noise(model_output, timesteps, sample, prediction_type, alp f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" f" are supported." 
) - + return pred_epsilon From 527dcdbb2dacb7c08b26737dd49ce595808dfd83 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 15 Dec 2023 02:07:40 -0800 Subject: [PATCH 4/4] apply suggestions from review --- .../train_lcm_distill_lora_sd_wds.py | 52 +++++++++-------- .../train_lcm_distill_lora_sdxl_wds.py | 44 +++++++-------- .../train_lcm_distill_sd_wds.py | 56 +++++++++---------- .../train_lcm_distill_sdxl_wds.py | 48 ++++++++-------- 4 files changed, 96 insertions(+), 104 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index ccd027e7ccbf..05689b71fa04 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -359,7 +359,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -378,7 +378,7 @@ def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, s # Based on step 4 in DDIMScheduler.step -def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -869,25 +869,25 @@ def main(args): ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-1.5 checkpoint. + # 2. Load tokenizers from SD 1.X/2.X checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) - # 3. Load text encoders from SD-1.5 checkpoint. + # 3. Load text encoders from SD 1.X/2.X checkpoint. # import correct text encoder classes text_encoder = CLIPTextModel.from_pretrained( args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-1.5 checkpoint + # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-1.5 checkpoint + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -897,7 +897,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Net. + # 7. Create online student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -1170,12 +1170,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - - # 2. Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # 2. 
Sample a random timestep for each image t_n from the ODE solver timesteps without bias. # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() @@ -1183,26 +1180,27 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 6. Sample a random guidance scale w from U[w_min, w_max] + # 5. Sample a random guidance scale w from U[w_min, w_max] # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 7. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1211,7 +1209,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1222,7 +1220,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. 
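Taken together, the two renamed helpers supply the teacher-side x_0 and eps_0 estimates that the DDIM solver consumes. Below is a small standalone sketch, not part of the patch, of the LCM-paper CFG combination and the deterministic DDIM update it feeds; the tensor shapes and the helper names `cfg_combine` and `ddim_step` are illustrative, and the step assumes the usual sqrt(alpha_bar) / sqrt(1 - alpha_bar) coefficients used by the scripts' alpha_schedule / sigma_schedule.

import torch

def cfg_combine(cond, uncond, w):
    # LCM-paper CFG formulation: cond + w * (cond - uncond)
    return cond + w * (cond - uncond)

def ddim_step(pred_x0, pred_noise, alpha_bar_prev):
    # Deterministic DDIM update to the previous solver timestep:
    # x_prev = sqrt(alpha_bar_prev) * x_0 + sqrt(1 - alpha_bar_prev) * eps_0
    return alpha_bar_prev.sqrt() * pred_x0 + (1 - alpha_bar_prev).sqrt() * pred_noise

# Toy usage with made-up tensors
w = 7.5
pred_x0 = cfg_combine(torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8), w)
pred_noise = cfg_combine(torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8), w)
x_prev = ddim_step(pred_x0, pred_noise, alpha_bar_prev=torch.tensor(0.98))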
@@ -1234,7 +1232,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1242,7 +1240,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) - cond_pred_noise = predicted_source_noise( + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1257,7 +1255,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok start_timesteps, encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1265,7 +1263,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) - uncond_pred_noise = predicted_source_noise( + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1283,7 +1281,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): @@ -1293,7 +1291,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timestep_cond=None, encoder_hidden_states=prompt_embeds.float(), ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1303,7 +1301,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 11. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1311,7 +1309,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 12. Backpropagate on the online student model (`unet`) + # 11. 
Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index bcec116decde..014a770fa0ba 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -346,7 +346,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -365,7 +365,7 @@ def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, s # Based on step 4 in DDIMScheduler.step -def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -911,7 +911,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Net. + # 7. Create online student U-Net. unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -1223,12 +1223,9 @@ def compute_embeddings( latents = latents * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - - # 2. Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() @@ -1236,26 +1233,27 @@ def compute_embeddings( timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. 
Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 6. Sample a random guidance scale w from U[w_min, w_max] + # 5. Sample a random guidance scale w from U[w_min, w_max] # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) - # 7. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1264,7 +1262,7 @@ def compute_embeddings( added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1275,7 +1273,7 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. @@ -1288,7 +1286,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1296,7 +1294,7 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) - cond_pred_noise = predicted_source_noise( + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1314,7 +1312,7 @@ def compute_embeddings( encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1322,7 +1320,7 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) - uncond_pred_noise = predicted_source_noise( + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1340,7 +1338,7 @@ def compute_embeddings( # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) # Note that we do not use a separate target network for LCM-LoRA distillation. 
with torch.no_grad(): with torch.autocast("cuda", enabled=True, dtype=weight_dtype): @@ -1351,7 +1349,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1361,7 +1359,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 11. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1369,7 +1367,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 12. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 7085f51792fb..54d05bb5ea26 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -336,7 +336,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -355,7 +355,7 @@ def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, s # Based on step 4 in DDIMScheduler.step -def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -857,25 +857,25 @@ def main(args): ddim_timesteps=args.num_ddim_timesteps, ) - # 2. Load tokenizers from SD-1.5 checkpoint. + # 2. Load tokenizers from SD 1.X/2.X checkpoint. tokenizer = AutoTokenizer.from_pretrained( args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False ) - # 3. Load text encoders from SD-1.5 checkpoint. + # 3. Load text encoders from SD 1.X/2.X checkpoint. # import correct text encoder classes text_encoder = CLIPTextModel.from_pretrained( args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) - # 4. Load VAE from SD-1.5 checkpoint + # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, subfolder="vae", revision=args.teacher_revision, ) - # 5. Load teacher U-Net from SD-1.5 checkpoint + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint teacher_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) @@ -885,7 +885,7 @@ def main(args): text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. 
Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -894,7 +894,7 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 8. Create target (`target_unet`) student U-Net. This will be updated via EMA updates (polyak averaging). + # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) @@ -1152,12 +1152,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok latents = latents * vae.config.scaling_factor latents = latents.to(weight_dtype) - - # 2. Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() @@ -1165,17 +1162,18 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) @@ -1183,10 +1181,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok w = w.to(device=latents.device, dtype=latents.dtype) w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 7. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) + # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1195,7 +1193,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1206,7 +1204,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. @@ -1218,7 +1216,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok start_timesteps, encoder_hidden_states=prompt_embeds.to(weight_dtype), ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1226,7 +1224,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) - cond_pred_noise = predicted_source_noise( + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1241,7 +1239,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok start_timesteps, encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1249,7 +1247,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok alpha_schedule, sigma_schedule, ) - uncond_pred_noise = predicted_source_noise( + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1267,7 +1265,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1276,7 +1274,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds.float(), ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1286,7 +1284,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ) target = c_skip * x_prev + c_out * pred_x_0 - # 11. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1294,7 +1292,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 12. 
Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1304,7 +1302,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 13. Make EMA update to target student model parameters (`target_unet`) + # 12. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1 diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 902adb38fd95..e58db46c9811 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -324,7 +324,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling= # Compare LCMScheduler.step, Step 4 -def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -343,7 +343,7 @@ def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, s # Based on step 4 in DDIMScheduler.step -def predicted_source_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): alphas = extract_into_tensor(alphas, timesteps, sample.shape) sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) if prediction_type == "epsilon": @@ -944,7 +944,7 @@ def main(args): text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) - # 7. Create online (`unet`) student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) + # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim @@ -953,7 +953,7 @@ def main(args): unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() - # 8. Create target (`target_unet`) student U-Net. This will be updated via EMA updates (polyak averaging). + # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet target_unet = UNet2DConditionModel(**teacher_unet.config) target_unet.load_state_dict(unet.state_dict()) @@ -1250,12 +1250,9 @@ def compute_embeddings( latents = latents * vae.config.scaling_factor if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) - - # 2. Sample noise that we'll add to the latents - noise = torch.randn_like(latents) bsz = latents.shape[0] - # 3. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. 
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() @@ -1263,17 +1260,18 @@ def compute_embeddings( timesteps = start_timesteps - topk timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) - # 4. Get boundary scalings for start_timesteps and (end) timesteps. + # 3. Get boundary scalings for start_timesteps and (end) timesteps. c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] c_skip, c_out = scalings_for_boundary_conditions(timesteps) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] - # 5. Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) - # 6. Sample a random guidance scale w from U[w_min, w_max] and embed it + # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) @@ -1281,10 +1279,10 @@ def compute_embeddings( w = w.to(device=latents.device, dtype=latents.dtype) w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) - # 7. Prepare prompt embeds and unet_added_conditions + # 6. Prepare prompt embeds and unet_added_conditions prompt_embeds = encoded_text.pop("prompt_embeds") - # 8. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) noise_pred = unet( noisy_model_input, start_timesteps, @@ -1293,7 +1291,7 @@ def compute_embeddings( added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( noise_pred, start_timesteps, noisy_model_input, @@ -1304,7 +1302,7 @@ def compute_embeddings( model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 - # 9. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE # solver timestep. 
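For reference, a minimal sketch, not part of the patch, of how the (t_{n + k}, t_n) timestep pairs are drawn; it assumes the solver spaces its DDIM timesteps as arange(1, N + 1) * (T // N) - 1, which matches the [T - 1, T - k - 1, T - 2 * k - 1, ...] schedule described in the comment, and the example values of T and N are illustrative.

import torch

T, N = 1000, 50          # num_train_timesteps, num_ddim_timesteps (example values)
k = T // N               # spacing between consecutive solver timesteps ("topk")
ddim_timesteps = torch.arange(1, N + 1) * k - 1   # {k - 1, 2k - 1, ..., T - 1}

bsz = 4
index = torch.randint(0, N, (bsz,)).long()
start_timesteps = ddim_timesteps[index]           # t_{n + k}
timesteps = (start_timesteps - k).clamp(min=0)    # t_n, clipped at 0 as in the scripts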
@@ -1317,7 +1315,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample - cond_pred_x0 = predicted_origin( + cond_pred_x0 = get_predicted_original_sample( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1325,7 +1323,7 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) - cond_pred_noise = predicted_source_noise( + cond_pred_noise = get_predicted_noise( cond_teacher_output, start_timesteps, noisy_model_input, @@ -1343,7 +1341,7 @@ def compute_embeddings( encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()}, ).sample - uncond_pred_x0 = predicted_origin( + uncond_pred_x0 = get_predicted_original_sample( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1351,7 +1349,7 @@ def compute_embeddings( alpha_schedule, sigma_schedule, ) - uncond_pred_noise = predicted_source_noise( + uncond_pred_noise = get_predicted_noise( uncond_teacher_output, start_timesteps, noisy_model_input, @@ -1369,7 +1367,7 @@ def compute_embeddings( # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. x_prev = solver.ddim_step(pred_x0, pred_noise, index) - # 10. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) with torch.no_grad(): with torch.autocast("cuda", dtype=weight_dtype): target_noise_pred = target_unet( @@ -1379,7 +1377,7 @@ def compute_embeddings( encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample - pred_x_0 = predicted_origin( + pred_x_0 = get_predicted_original_sample( target_noise_pred, timesteps, x_prev, @@ -1389,7 +1387,7 @@ def compute_embeddings( ) target = c_skip * x_prev + c_out * pred_x_0 - # 11. Calculate loss + # 10. Calculate loss if args.loss_type == "l2": loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") elif args.loss_type == "huber": @@ -1397,7 +1395,7 @@ def compute_embeddings( torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c ) - # 12. Backpropagate on the online student model (`unet`) + # 11. Backpropagate on the online student model (`unet`) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) @@ -1407,7 +1405,7 @@ def compute_embeddings( # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - # 13. Make EMA update to target student model parameters (`target_unet`) + # 12. Make EMA update to target student model parameters (`target_unet`) update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay) progress_bar.update(1) global_step += 1
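For completeness, a hedged sketch, not part of the patch, of the two training-loop pieces referenced above: the pseudo-Huber loss used when --loss_type is "huber", and the EMA (Polyak averaging) update applied to the target student U-Net after each optimizer step. The `huber_c` and `decay` defaults are illustrative; the scripts take these from args.huber_c and args.ema_decay.

import torch

def pseudo_huber_loss(model_pred, target, huber_c=0.001):
    # Smooth L1-like loss: sqrt((x - y)^2 + c^2) - c, averaged over all elements.
    return torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c**2) - huber_c)

@torch.no_grad()
def update_ema(target_params, online_params, decay=0.95):
    # target <- decay * target + (1 - decay) * online, applied parameter-wise in place.
    for t_p, o_p in zip(target_params, online_params):
        t_p.mul_(decay).add_(o_p, alpha=1 - decay)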