From 3ef531e775a41d5558b3e5d86da1ff4b8e7646d6 Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 13 Oct 2022 15:09:29 +0200 Subject: [PATCH] Fix dreambooth loss type with prior preservation --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index fe4741d5e2db..ca39aeff236b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -544,7 +544,7 @@ def collate_fn(examples): noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")