diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 278c25900a3a..30f4f4e5d219 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1715,7 +1715,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) # handle guidance - if transformer.config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.full([1], args.guidance_scale, device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 28cbaf8f72e7..7edf8c0f194d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -1682,7 +1682,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) # handle guidance - if transformer.config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.full([1], args.guidance_scale, device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: