Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 81 additions & 12 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import resolve_interpolation_mode
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -165,6 +166,7 @@ def __init__(
global_batch_size: int,
num_workers: int,
resolution: int = 512,
interpolation_type: str = "bilinear",
shuffle_buffer_size: int = 1000,
pin_memory: bool = False,
persistent_workers: bool = False,
Expand All @@ -174,10 +176,12 @@ def __init__(
# flatten list using itertools
train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))

interpolation_mode = resolve_interpolation_mode(interpolation_type)

def transform(example):
# resize image
image = example["image"]
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
image = TF.resize(image, resolution, interpolation=interpolation_mode)

# get crop coordinates and crop image
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
Expand Down Expand Up @@ -353,8 +357,9 @@ def append_dims(x, target_dims):

# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the ``(c_skip, c_out)`` boundary-condition scalings for ``timestep``.

    The timestep is first multiplied by ``timestep_scaling``; both coefficients
    are then computed from that scaled value and ``sigma_data``. At
    ``timestep == 0`` (with a non-zero ``sigma_data``) this gives
    ``c_skip == 1`` and ``c_out == 0``.

    Args:
        timestep: Timestep value(s). Only multiplication, power and division
            are applied, so a Python number works; callers in this script
            appear to pass tensors -- TODO confirm.
        sigma_data: Data scale constant used in both coefficients.
        timestep_scaling: Multiplicative factor applied to ``timestep``. The
            default of 10.0 matches the previously hard-coded ``timestep / 0.1``.

    Returns:
        A ``(c_skip, c_out)`` tuple.
    """
    scaled_timestep = timestep_scaling * timestep
    # Both coefficients share the same denominator term.
    denominator = scaled_timestep**2 + sigma_data**2
    c_skip = sigma_data**2 / denominator
    c_out = scaled_timestep / denominator**0.5
    return c_skip, c_out


Expand Down Expand Up @@ -572,6 +577,15 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--interpolation_type",
type=str,
default="bilinear",
help=(
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
),
)
parser.add_argument(
"--center_crop",
default=False,
Expand Down Expand Up @@ -710,6 +724,50 @@ def parse_args():
default=64,
help="The rank of the LoRA projection matrix.",
)
parser.add_argument(
"--lora_alpha",
type=int,
default=64,
help=(
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
),
)
parser.add_argument(
"--lora_dropout",
type=float,
default=0.0,
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default=None,
help=(
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
" be used. By default, LoRA will be applied to all conv and linear layers."
),
)
parser.add_argument(
"--vae_encode_batch_size",
type=int,
default=32,
required=False,
help=(
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
" Encoding or decoding the whole batch at once may run into OOM issues."
),
)
parser.add_argument(
"--timestep_scaling_factor",
type=float,
default=10.0,
help=(
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
" suffice."
),
)
# ----Mixed Precision----
parser.add_argument(
"--mixed_precision",
Expand Down Expand Up @@ -915,9 +973,10 @@ def main(args):
)

# 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
lora_config = LoraConfig(
r=args.lora_rank,
target_modules=[
if args.lora_target_modules is not None:
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
else:
lora_target_modules = [
"to_q",
"to_k",
"to_v",
Expand All @@ -932,7 +991,12 @@ def main(args):
"downsamplers.0.conv",
"upsamplers.0.conv",
"time_emb_proj",
],
]
lora_config = LoraConfig(
r=args.lora_rank,
target_modules=lora_target_modules,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
unet = get_peft_model(unet, lora_config)

Expand Down Expand Up @@ -1051,6 +1115,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
global_batch_size=args.train_batch_size * accelerator.num_processes,
num_workers=args.dataloader_num_workers,
resolution=args.resolution,
interpolation_type=args.interpolation_type,
shuffle_buffer_size=1000,
pin_memory=True,
persistent_workers=True,
Expand Down Expand Up @@ -1162,10 +1227,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
if vae.dtype != weight_dtype:
vae.to(dtype=weight_dtype)

# encode pixel values with batch size of at most 32
# encode pixel values with batch size of at most args.vae_encode_batch_size
latents = []
for i in range(0, pixel_values.shape[0], 32):
latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
latents = torch.cat(latents, dim=0)

latents = latents * vae.config.scaling_factor
Expand All @@ -1181,9 +1246,13 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

# 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = scalings_for_boundary_conditions(
start_timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = scalings_for_boundary_conditions(
timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
Expand Down
91 changes: 78 additions & 13 deletions examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import resolve_interpolation_mode
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

Expand Down Expand Up @@ -193,8 +194,9 @@ def append_dims(x, target_dims):

# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the ``(c_skip, c_out)`` boundary-condition scalings for ``timestep``.

    The timestep is first multiplied by ``timestep_scaling``; both coefficients
    are then computed from that scaled value and ``sigma_data``. At
    ``timestep == 0`` (with a non-zero ``sigma_data``) this gives
    ``c_skip == 1`` and ``c_out == 0``.

    Args:
        timestep: Timestep value(s). Only multiplication, power and division
            are applied, so a Python number works; callers in this script
            appear to pass tensors -- TODO confirm.
        sigma_data: Data scale constant used in both coefficients.
        timestep_scaling: Multiplicative factor applied to ``timestep``. The
            default of 10.0 matches the previously hard-coded ``timestep / 0.1``.

    Returns:
        A ``(c_skip, c_out)`` tuple.
    """
    scaled_timestep = timestep_scaling * timestep
    # Both coefficients share the same denominator term.
    denominator = scaled_timestep**2 + sigma_data**2
    c_skip = sigma_data**2 / denominator
    c_out = scaled_timestep / denominator**0.5
    return c_skip, c_out


Expand Down Expand Up @@ -396,6 +398,15 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--interpolation_type",
type=str,
default="bilinear",
help=(
"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
),
)
parser.add_argument(
"--center_crop",
default=False,
Expand Down Expand Up @@ -534,6 +545,50 @@ def parse_args():
default=64,
help="The rank of the LoRA projection matrix.",
)
parser.add_argument(
"--lora_alpha",
type=int,
default=64,
help=(
"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
" update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
),
)
parser.add_argument(
"--lora_dropout",
type=float,
default=0.0,
help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default=None,
help=(
"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
" be used. By default, LoRA will be applied to all conv and linear layers."
),
)
parser.add_argument(
"--vae_encode_batch_size",
type=int,
default=8,
required=False,
help=(
"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
" Encoding or decoding the whole batch at once may run into OOM issues."
),
)
parser.add_argument(
"--timestep_scaling_factor",
type=float,
default=10.0,
help=(
"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
" suffice."
),
)
# ----Mixed Precision----
parser.add_argument(
"--mixed_precision",
Expand Down Expand Up @@ -776,10 +831,10 @@ def main(args):
text_encoder_two.to(accelerator.device, dtype=weight_dtype)

# 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
lora_config = LoraConfig(
r=args.lora_rank,
lora_alpha=args.lora_rank,
target_modules=[
if args.lora_target_modules is not None:
lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
else:
lora_target_modules = [
"to_q",
"to_k",
"to_v",
Expand All @@ -794,7 +849,12 @@ def main(args):
"downsamplers.0.conv",
"upsamplers.0.conv",
"time_emb_proj",
],
]
lora_config = LoraConfig(
r=args.lora_rank,
target_modules=lora_target_modules,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
unet.add_adapter(lora_config)

Expand Down Expand Up @@ -929,7 +989,8 @@ def load_model_hook(models, input_dir):
)

# Preprocessing the datasets.
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
interpolation_mode = resolve_interpolation_mode(args.interpolation_type)
train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
Expand Down Expand Up @@ -1121,11 +1182,11 @@ def compute_time_ids(original_size, crops_coords_top_left):

encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)

# encode pixel values with batch size of at most 8
# encode pixel values with batch size of at most args.vae_encode_batch_size
pixel_values = pixel_values.to(dtype=vae.dtype)
latents = []
for i in range(0, pixel_values.shape[0], args.encode_batch_size):
latents.append(vae.encode(pixel_values[i : i + args.encode_batch_size]).latent_dist.sample())
for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):
latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())
latents = torch.cat(latents, dim=0)

latents = latents * vae.config.scaling_factor
Expand All @@ -1142,9 +1203,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

# 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = scalings_for_boundary_conditions(
start_timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = scalings_for_boundary_conditions(
timesteps, timestep_scaling=args.timestep_scaling_factor
)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
Expand Down
Loading