diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index c85e2c462b04..0feb1108027a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -165,6 +166,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -174,10 +176,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -353,8 +357,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -572,6 +577,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -710,6 +724,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -915,9 +973,10 @@ def main(args): ) # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -932,7 +991,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) @@ -1051,6 +1115,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1162,10 +1227,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1181,9 +1246,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 95e7b2dbaa27..fd8ce5f8cb51 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -51,6 +51,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -193,8 +194,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -396,6 +398,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -534,6 +545,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -776,10 +831,10 @@ def main(args): text_encoder_two.to(accelerator.device, dtype=weight_dtype) # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -794,7 +849,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet.add_adapter(lora_config) @@ -929,7 +989,8 @@ def load_model_hook(models, input_dir): ) # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + interpolation_mode = resolve_interpolation_mode(args.interpolation_type) + train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) @@ -1121,11 +1182,11 @@ def compute_time_ids(original_size, crops_coords_top_left): encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size pixel_values = pixel_values.to(dtype=vae.dtype) latents = [] - for i in range(0, pixel_values.shape[0], args.encode_batch_size): - latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1142,9 +1203,13 @@ def compute_time_ids(original_size, crops_coords_top_left): timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 75671c18c5e0..16d32c4280a6 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -62,6 +62,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -171,6 +172,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -187,10 +189,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -340,8 +344,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -546,6 +551,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -690,6 +704,50 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -929,9 +987,10 @@ def main(args): ) # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -946,7 +1005,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) @@ -1090,6 +1154,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1214,10 +1279,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1234,9 +1299,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index b2085b7044ba..b6bccfbb82b3 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -60,6 +60,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -147,6 +148,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -156,10 +158,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -330,8 +334,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -549,6 +554,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -690,6 +704,26 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1034,6 +1068,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1145,10 +1180,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1164,9 +1199,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index ee86def673fa..dab38c96a254 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -153,6 +154,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -169,10 +171,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -318,8 +322,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -568,6 +573,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -715,6 +729,26 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1118,6 +1152,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, @@ -1242,10 +1277,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1262,9 +1297,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 992ae7d1b194..9fb6fad3a267 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,6 +5,7 @@ import numpy as np import torch +from torchvision import transforms from .models import UNet2DConditionModel from .utils import deprecate, is_transformers_available @@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps): return snr +def resolve_interpolation_mode(interpolation_type: str): + """ + Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The + full list of supported enums is documented at + https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. + + Args: + interpolation_type (`str`): + A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`, + `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes + in torchvision. + + Returns: + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` + transform. + """ + if interpolation_type == "bilinear": + interpolation_mode = transforms.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = transforms.InterpolationMode.BICUBIC + elif interpolation_type == "box": + interpolation_mode = transforms.InterpolationMode.BOX + elif interpolation_type == "nearest": + interpolation_mode = transforms.InterpolationMode.NEAREST + elif interpolation_type == "nearest_exact": + interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT + elif interpolation_type == "hamming": + interpolation_mode = transforms.InterpolationMode.HAMMING + elif interpolation_type == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ) + + return interpolation_mode + + def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: r""" Returns: