diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 065e1c3f1ef7..b394565058b2 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1001,7 +1001,12 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer proportion_empty_prompts=args.proportion_empty_prompts, ) with accelerator.main_process_first(): - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True) + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) del text_encoders, tokenizers gc.collect() @@ -1113,8 +1118,6 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Convert images to latent space if args.pretrained_vae_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) - if vae.dtype != weight_dtype: - vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample()