From 2f97b6702ecb3363fac343f11836c2cdce2df79e Mon Sep 17 00:00:00 2001 From: Mohamad Zeina Date: Tue, 5 Dec 2023 13:01:31 +0000 Subject: [PATCH 1/3] Assign device to unet. Resolves #5897 --- examples/text_to_image/train_text_to_image_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 7d731c994bdd..275f4aac0294 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -536,6 +536,7 @@ def main(): unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers From ede94ebdfe48d1799cffa51661624c37fdb55fdd Mon Sep 17 00:00:00 2001 From: Mohamad Zeina Date: Tue, 5 Dec 2023 13:34:36 +0000 Subject: [PATCH 2/3] Remove redundant move unet to device --- examples/text_to_image/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 275f4aac0294..369823a56a03 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -480,7 +480,6 @@ def main(): weight_dtype = torch.bfloat16 # Move unet, vae and text_encoder to device and cast to weight_dtype - unet.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) @@ -536,6 +535,7 @@ def main(): unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + # Move unet and lora to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): From 0f5e2815274a8f2690bfbec8c901f9f4594f5979 Mon Sep 17 00:00:00 2001 From: Mohamad Zeina Date: Tue, 5 Dec 2023 13:40:19 +0000 Subject: [PATCH 3/3] Fix comment --- examples/text_to_image/train_text_to_image_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 369823a56a03..c71e7c29b023 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -479,7 +479,7 @@ def main(): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype + # Move vae and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype)