diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ad03829fd1bc..97b7f334bc9f 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1211,7 +1211,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 49aef1cc4a99..ca25152fcb1c 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1156,7 +1156,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps":