From 89f8d3fda4b727159b29314af40010d88519b249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucca=20Zen=C3=B3bio?= Date: Mon, 17 Apr 2023 21:00:37 -0300 Subject: [PATCH] cast to weight dtype --- .../dreambooth_inpaint/train_dreambooth_inpaint_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index 07df6f201175..821c66b7237a 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -735,7 +735,7 @@ def collate_fn(examples): torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) for mask in masks ] - ) + ).to(dtype=weight_dtype) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents