Skip to content
41 changes: 32 additions & 9 deletions examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model_state_dict
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
Expand All @@ -52,7 +52,12 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available


Expand Down Expand Up @@ -858,11 +863,6 @@ def main(args):
)
unet.add_adapter(lora_config)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)

# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
Expand All @@ -887,13 +887,31 @@ def save_model_hook(models, weights, output_dir):
def load_model_hook(models, input_dir):
# load the LoRA into the model
unet_ = accelerator.unwrap_model(unet)
lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
unet_state_dict = {
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)

for _ in range(len(models)):
# pop models so that they are not loaded again
models.pop()

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
cast_training_params(unet_, dtype=torch.float32)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

Expand Down Expand Up @@ -1092,6 +1110,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
Expand Down