From 3d6b4dda4480a349ad2aaec80b299948dd47ece1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 08:58:55 +0530 Subject: [PATCH 1/3] hash computation. thanks to @lhoestq --- examples/controlnet/train_controlnet_sdxl.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 065e1c3f1ef7..264e31729b9a 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1001,7 +1001,12 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer proportion_empty_prompts=args.proportion_empty_prompts, ) with accelerator.main_process_first(): - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True) + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) del text_encoders, tokenizers gc.collect() From 0d39073bff9cdc5f9b23c12846db20e86d40b77c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 09:09:02 +0530 Subject: [PATCH 2/3] disable dtype casting. --- examples/controlnet/train_controlnet_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 264e31729b9a..ab19f245f6a3 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1118,8 +1118,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Convert images to latent space if args.pretrained_vae_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - if vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) + # if vae.dtype != weight_dtype: + # vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample() From 29c2133e31f8b9cd735cf61f4ba0beccbce21023 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 21 Jul 2023 09:55:29 +0530 Subject: [PATCH 3/3] remove comments. --- examples/controlnet/train_controlnet_sdxl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index ab19f245f6a3..b394565058b2 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1118,8 +1118,6 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Convert images to latent space if args.pretrained_vae_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - # if vae.dtype != weight_dtype: - # vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample()