diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 3db9ff65e441..568279d9be3e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -68,6 +68,7 @@ is_wandb_available, ) from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -1293,6 +1294,11 @@ def main(args): else: param.requires_grad = False + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1303,14 +1309,14 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): if args.train_text_encoder: text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): if args.train_text_encoder: text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) @@ -1338,11 +1344,11 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}")