Merged
9 changes: 6 additions & 3 deletions examples/controlnet/train_controlnet_sdxl.py
@@ -1001,7 +1001,12 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
         proportion_empty_prompts=args.proportion_empty_prompts,
     )
     with accelerator.main_process_first():
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True)
+        from datasets.fingerprint import Hasher
+
+        # fingerprint used by the cache for the other processes to load the result
+        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+        new_fingerprint = Hasher.hash(args)
@williamberman (Contributor), Jul 21, 2023:

Are the args actually going to be good enough to create a hash? Multiple runs of the script might have the same args. Is that ok? I'm not sure; I haven't thought it through well enough.

Ideally we can get the PID of the parent process if the parent process is accelerate and hash that. If the parent process is not accelerate, we don't have to pass any additional fingerprint.
Member Author:

Resorting to @lhoestq again.
Member Author:

> Multiple runs of the script might have the same args. Is that ok? I'm not sure; I haven't thought it through well enough.

If that is the case, we would want to avoid the execution of the map fn and instead load from the cache, no? Or will there be undesired consequences of that?

In any case, coming to your suggestion:

> Ideally we can get the PID of the parent process if the parent process is accelerate and hash that. If the parent process is not accelerate, we don't have to pass any additional fingerprint.

Are you thinking of something like:

    with accelerator.main_process_first():
        from datasets.fingerprint import Hasher

        # fingerprint used by the cache for the other processes to load the result
        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
        if accelerator.is_main_process:
            pid = os.getpid()
        new_fingerprint = Hasher.hash(pid)
        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

@williamberman (Contributor):

Up to you if it's ok to reload from the cache when calling map; I'm not as familiar with the script :) If it is ok, a comment in the code would be nice on under what circumstances and why.

I'm not familiar with the accelerate forking model, and whether one of the scripts itself ends up being the parent process or there's a separate accelerate script that forks into the children. If you need the parent pid (again, depending on who actually gets forked), you would call os.getppid.

Member Author:

@lhoestq WDYT? IMO, it should be okay to:

> If that is the case, we would want to avoid the execution of the map fn and instead load from the cache, no? Or will there be undesired consequences of that?

@lhoestq (Member), Jul 21, 2023:

> If that is the case, we would want to avoid the execution of the map fn and instead load from the cache, no? Or will there be undesired consequences of that?

Since the args are the same, it will reload from the cache in subsequent runs instead of reprocessing the data :)
The resulting dataset doesn't depend on whether accelerate is used, no?

+        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

     del text_encoders, tokenizers
     gc.collect()

@@ -1113,8 +1118,6 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
     # Convert images to latent space
     if args.pretrained_vae_model_name_or_path is not None:
         pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
-        if vae.dtype != weight_dtype:
-            vae.to(dtype=weight_dtype)
     else:
         pixel_values = batch["pixel_values"]
     latents = vae.encode(pixel_values).latent_dist.sample()
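The caching behavior the merged change relies on can be sketched with the stdlib alone. In the sketch below, `fingerprint` and `cached_map` are hypothetical stand-ins for datasets' `Hasher.hash` and `Dataset.map(..., new_fingerprint=...)`, not the real API: identical args produce an identical fingerprint, so a second run (or another process) reloads from cache instead of recomputing.

```python
import hashlib
import json
import os
import tempfile

def fingerprint(args: dict) -> str:
    # Deterministic hash of the run arguments: identical args always give
    # an identical fingerprint, so every process derives the same cache key.
    payload = json.dumps(args, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:16]

def cached_map(cache_dir: str, args: dict, fn, data):
    # Recompute on a cache miss, reload the stored result on a hit.
    path = os.path.join(cache_dir, fingerprint(args) + ".json")
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f), "cache hit"
    result = [fn(x) for x in data]
    with open(path, "w") as f:
        json.dump(result, f)
    return result, "computed"

cache_dir = tempfile.mkdtemp()
args = {"proportion_empty_prompts": 0.1, "resolution": 1024}
first, status1 = cached_map(cache_dir, args, lambda x: x * 2, [1, 2, 3])
second, status2 = cached_map(cache_dir, args, lambda x: x * 2, [1, 2, 3])
print(status1, status2)  # computed cache hit
```

This mirrors the resolution of the thread: because the fingerprint is derived from the args rather than from a per-process value, non-main processes inside `main_process_first()` find the main process's cached result instead of recomputing the embeddings.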