From 3a8c9c92874e97cf532617ad2888d2a6b9748559 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 29 Dec 2022 17:50:08 +0100 Subject: [PATCH] update loss computation --- examples/dreambooth/train_dreambooth.py | 2 +- examples/textual_inversion/textual_inversion.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ddc3a608767b..30f5e0ccae0b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -732,7 +732,7 @@ def main(args): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 467e710222de..2a765e47a20b 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -634,7 +634,8 @@ def main(): else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + accelerator.backward(loss) optimizer.step()