From b1bc9efb9793d1c890c3126e6a94e24e48229ae9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jun 2023 09:45:21 +0000 Subject: [PATCH] Correct multi gpu --- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ad03829fd1bc..97b7f334bc9f 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1211,7 +1211,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps": diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 49aef1cc4a99..ca25152fcb1c 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1156,7 +1156,7 @@ def compute_text_embeddings(prompt): text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, ) - if unet.config.in_channels == channels * 2: + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) if args.class_labels_conditioning == "timesteps":