From 29e2a8031f73a650a5f9e67db38af5fe23215c4f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 12:31:22 -0800 Subject: [PATCH 1/7] Make WDS pipeline interpolation type configurable. --- .../train_lcm_distill_lora_sd_wds.py | 33 ++++++++++++++++++- .../train_lcm_distill_lora_sdxl_wds.py | 33 ++++++++++++++++++- .../train_lcm_distill_sd_wds.py | 33 ++++++++++++++++++- .../train_lcm_distill_sdxl_wds.py | 33 ++++++++++++++++++- 4 files changed, 128 insertions(+), 4 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index c85e2c462b04..7621bfc22dbc 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -136,6 +136,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -165,6 +183,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -174,10 +193,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -572,6 +593,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." 
+ ), + ) parser.add_argument( "--center_crop", default=False, @@ -1051,6 +1081,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 75671c18c5e0..09a340c440bb 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -142,6 +142,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -171,6 +189,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -187,10 +206,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -546,6 +567,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." 
+ ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -1090,6 +1120,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index b2085b7044ba..9663588feb97 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -118,6 +118,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -147,6 +165,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -156,10 +175,12 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -549,6 +570,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." 
+ ), + ) parser.add_argument( "--center_crop", default=False, @@ -1034,6 +1064,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index ee86def673fa..51480216eb42 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -124,6 +124,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -153,6 +171,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -169,10 +188,12 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -568,6 +589,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -1118,6 +1148,7 @@ def compute_embeddings( global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, From 5abbbea9268b888ff31d8d8fef895ec9abdd4a82 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 12:47:41 -0800 Subject: [PATCH 2/7] Make the VAE encoding batch size configurable. 
--- .../train_lcm_distill_lora_sd_wds.py | 16 +++++++++++++--- .../train_lcm_distill_lora_sdxl.py | 16 +++++++++++++--- .../train_lcm_distill_lora_sdxl_wds.py | 16 +++++++++++++--- .../train_lcm_distill_sd_wds.py | 16 +++++++++++++--- .../train_lcm_distill_sdxl_wds.py | 16 +++++++++++++--- 5 files changed, 65 insertions(+), 15 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 7621bfc22dbc..1e7149df2ed5 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -740,6 +740,16 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1193,10 +1203,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 95e7b2dbaa27..ace9cdfb7d94 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -534,6 +534,16 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." 
+ ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1121,11 +1131,11 @@ def compute_time_ids(original_size, crops_coords_top_left): encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size pixel_values = pixel_values.to(dtype=vae.dtype) latents = [] - for i in range(0, pixel_values.shape[0], args.encode_batch_size): - latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 09a340c440bb..d5e1be62aa35 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -720,6 +720,16 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1245,10 +1255,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 9663588feb97..e190fea194dd 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -720,6 +720,16 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." 
+ ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1176,10 +1186,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 51480216eb42..b669025807b4 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -745,6 +745,16 @@ def parse_args(): " does not have `time_cond_proj_dim` set." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1273,10 +1283,10 @@ def compute_embeddings( else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor From f23743c4070b97ab34e34c58498c70fe30001c0e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 13:11:26 -0800 Subject: [PATCH 3/7] Make lora_alpha and lora_dropout configurable for LCM LoRA scripts. --- .../train_lcm_distill_lora_sd_wds.py | 17 +++++++++++++++++ .../train_lcm_distill_lora_sdxl.py | 18 +++++++++++++++++- .../train_lcm_distill_lora_sdxl_wds.py | 17 +++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 1e7149df2ed5..464caa5a6de0 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -740,6 +740,21 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." 
+ ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -957,6 +972,8 @@ def main(args): # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. lora_config = LoraConfig( r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, target_modules=[ "to_q", "to_k", diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index ace9cdfb7d94..80474895b436 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -534,6 +534,21 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -788,7 +803,8 @@ def main(args): # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. lora_config = LoraConfig( r=args.lora_rank, - lora_alpha=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, target_modules=[ "to_q", "to_k", diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index d5e1be62aa35..259a39068d2a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -720,6 +720,21 @@ def parse_args(): default=64, help="The rank of the LoRA projection matrix.", ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -971,6 +986,8 @@ def main(args): # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. lora_config = LoraConfig( r=args.lora_rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, target_modules=[ "to_q", "to_k", From f06df327b6d3e1a172b316e3766d35485b0dc792 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 13:28:01 -0800 Subject: [PATCH 4/7] Generalize scalings_for_boundary_conditions function and make the timestep scaling configurable. 
--- .../train_lcm_distill_lora_sd_wds.py | 23 +++++++++++++++---- .../train_lcm_distill_lora_sdxl.py | 23 +++++++++++++++---- .../train_lcm_distill_lora_sdxl_wds.py | 23 +++++++++++++++---- .../train_lcm_distill_sd_wds.py | 23 +++++++++++++++---- .../train_lcm_distill_sdxl_wds.py | 23 +++++++++++++++---- 5 files changed, 95 insertions(+), 20 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 464caa5a6de0..ea240fe698c4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -374,8 +374,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -765,6 +766,16 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1239,9 +1250,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. 
Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 80474895b436..ca527c052f35 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -193,8 +193,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -559,6 +560,16 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1168,9 +1179,13 @@ def compute_time_ids(original_size, crops_coords_top_left): timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 259a39068d2a..1ef944bebaf4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -361,8 +361,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -745,6 +746,16 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." 
), ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1292,9 +1303,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index e190fea194dd..2e0f2cfb75db 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -351,8 +351,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -730,6 +731,16 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1205,9 +1216,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. 
Sample noise from the prior and add it to the latents according to the noise magnitude at each diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index b669025807b4..01ee838e1679 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -339,8 +339,9 @@ def append_dims(x, target_dims): # From LCMScheduler.get_scalings_for_boundary_condition_discrete def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): - c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) - c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 return c_skip, c_out @@ -755,6 +756,16 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1303,9 +1314,13 @@ def compute_embeddings( timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) # 3. Get boundary scalings for start_timesteps and (end) timesteps. - c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] - c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each From 6e5b2be2fa880fb3b5ebbf3030953ee8703f20fa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 2 Jan 2024 12:55:01 -0800 Subject: [PATCH 5/7] Make LoRA target modules configurable for LCM-LoRA scripts. --- .../train_lcm_distill_lora_sd_wds.py | 25 ++++++++++++++----- .../train_lcm_distill_lora_sdxl.py | 25 ++++++++++++++----- .../train_lcm_distill_lora_sdxl_wds.py | 25 ++++++++++++++----- 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index ea240fe698c4..6be45c0e2278 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -756,6 +756,15 @@ def parse_args(): default=0.0, help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used." 
+ ), + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -981,11 +990,10 @@ def main(args): ) # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -1000,7 +1008,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index ca527c052f35..39bd657e2142 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -550,6 +550,15 @@ def parse_args(): default=0.0, help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used." + ), + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -812,11 +821,10 @@ def main(args): text_encoder_two.to(accelerator.device, dtype=weight_dtype) # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. - lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -831,7 +839,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet.add_adapter(lora_config) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 1ef944bebaf4..19b4c14ba2c3 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -736,6 +736,15 @@ def parse_args(): default=0.0, help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used." + ), + ) parser.add_argument( "--vae_encode_batch_size", type=int, @@ -995,11 +1004,10 @@ def main(args): ) # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. 
- lora_config = LoraConfig( - r=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - target_modules=[ + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ "to_q", "to_k", "to_v", @@ -1014,7 +1022,12 @@ def main(args): "downsamplers.0.conv", "upsamplers.0.conv", "time_emb_proj", - ], + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, ) unet = get_peft_model(unet, lora_config) From f1c131b734f5df12557f498a3e3c814ada69c131 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 2 Jan 2024 13:48:38 -0800 Subject: [PATCH 6/7] Move resolve_interpolation_mode to src/diffusers/training_utils.py and make interpolation type configurable in non-WDS script. --- .../train_lcm_distill_lora_sd_wds.py | 21 +--------- .../train_lcm_distill_lora_sdxl.py | 13 +++++- .../train_lcm_distill_lora_sdxl_wds.py | 21 +--------- .../train_lcm_distill_sd_wds.py | 21 +--------- .../train_lcm_distill_sdxl_wds.py | 21 +--------- src/diffusers/training_utils.py | 40 +++++++++++++++++++ 6 files changed, 60 insertions(+), 77 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 6be45c0e2278..cb574372fd96 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -136,24 +137,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -600,7 +583,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." 
), ) parser.add_argument( diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 39bd657e2142..43491e38f89a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -51,6 +51,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -397,6 +398,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -979,7 +989,8 @@ def load_model_hook(models, input_dir): ) # Preprocessing the datasets. - train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + interpolation_mode = resolve_interpolation_mode(args.interpolation_type) + train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode) train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 19b4c14ba2c3..0d962ee7dd49 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -62,6 +62,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -142,24 +143,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -574,7 +557,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." 
), ) parser.add_argument( diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 2e0f2cfb75db..b6bccfbb82b3 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -60,6 +60,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -118,24 +119,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -577,7 +560,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 01ee838e1679..dab38c96a254 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -61,6 +61,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import resolve_interpolation_mode from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -124,24 +125,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -596,7 +579,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." 
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 992ae7d1b194..9fb6fad3a267 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,6 +5,7 @@ import numpy as np import torch +from torchvision import transforms from .models import UNet2DConditionModel from .utils import deprecate, is_transformers_available @@ -53,6 +54,45 @@ def compute_snr(noise_scheduler, timesteps): return snr +def resolve_interpolation_mode(interpolation_type: str): + """ + Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The + full list of supported enums is documented at + https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. + + Args: + interpolation_type (`str`): + A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`, + `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes + in torchvision. + + Returns: + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` + transform. + """ + if interpolation_type == "bilinear": + interpolation_mode = transforms.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = transforms.InterpolationMode.BICUBIC + elif interpolation_type == "box": + interpolation_mode = transforms.InterpolationMode.BOX + elif interpolation_type == "nearest": + interpolation_mode = transforms.InterpolationMode.NEAREST + elif interpolation_type == "nearest_exact": + interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT + elif interpolation_type == "hamming": + interpolation_mode = transforms.InterpolationMode.HAMMING + elif interpolation_type == "lanczos": + interpolation_mode = transforms.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ) + + return interpolation_mode + + def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: r""" Returns: From f3b7595e8adeca4c79fc04af2e88a67a536eddab Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 4 Jan 2024 10:32:49 -0800 Subject: [PATCH 7/7] apply suggestions from review --- .../consistency_distillation/train_lcm_distill_lora_sd_wds.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl.py | 2 +- .../consistency_distillation/train_lcm_distill_lora_sdxl_wds.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index cb574372fd96..0feb1108027a 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -745,7 +745,7 @@ def parse_args(): default=None, help=( "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" - " be used." + " be used. By default, LoRA will be applied to all conv and linear layers." 
), ) parser.add_argument( diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 43491e38f89a..fd8ce5f8cb51 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -566,7 +566,7 @@ def parse_args(): default=None, help=( "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" - " be used." + " be used. By default, LoRA will be applied to all conv and linear layers." ), ) parser.add_argument( diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 0d962ee7dd49..16d32c4280a6 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -725,7 +725,7 @@ def parse_args(): default=None, help=( "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" - " be used." + " be used. By default, LoRA will be applied to all conv and linear layers." ), ) parser.add_argument(