From 5232e7f40855bce4baa018b04c471848bdfd0a3e Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Fri, 23 Aug 2024 14:07:44 +0300
Subject: [PATCH 1/5] fix shape

---
 examples/dreambooth/train_dreambooth_flux.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index ece12e289e0c..dcd9e7dae065 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1586,8 +1586,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2]),
-                    width=int(model_input.shape[3]),
+                    height=int(model_input.shape[2] * vae_scale_factor / 2),
+                    width=int(model_input.shape[3] * vae_scale_factor / 2),
                     vae_scale_factor=vae_scale_factor,
                 )
 
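Note on the fix: FluxPipeline._pack_latents folds every 2x2 patch of the (B, C, H, W) VAE latent into a single sequence token, and FluxPipeline._unpack_latents expects pixel-space height/width that it divides back down by vae_scale_factor to recover that patch grid. Passing the bare latent dimensions (model_input.shape[2]/shape[3]) therefore fed it numbers too small by a factor of vae_scale_factor / 2. Below is a minimal sketch of the round trip, not the training script itself; the pack/unpack helpers are condensed re-implementations for illustration, and the concrete numbers (a 1024 px image, vae_scale_factor = 16 from 2 ** len(block_out_channels)) are assumptions:

    import torch

    def pack_latents(latents):
        # (B, C, H, W) -> (B, (H//2)*(W//2), C*4): every 2x2 latent patch becomes one token
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(b, (h // 2) * (w // 2), c * 4)

    def unpack_latents(latents, height, width, vae_scale_factor):
        # height/width arrive in pixel space; dividing by vae_scale_factor
        # recovers the 2x2 patch grid used by pack_latents
        b, _, channels = latents.shape
        height = height // vae_scale_factor
        width = width // vae_scale_factor
        latents = latents.view(b, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        return latents.reshape(b, channels // 4, height * 2, width * 2)

    # Assumed numbers: 1024 px training resolution and 8x VAE downsampling.
    model_input = torch.randn(1, 16, 128, 128)  # latent grid: 1024 / 8 = 128
    vae_scale_factor = 16
    packed = pack_latents(model_input)          # -> (1, 4096, 64)
    unpacked = unpack_latents(
        packed,
        height=int(model_input.shape[2] * vae_scale_factor / 2),  # 128 * 16 / 2 = 1024
        width=int(model_input.shape[3] * vae_scale_factor / 2),
        vae_scale_factor=vae_scale_factor,
    )
    assert unpacked.shape == model_input.shape  # the bare shape[2]/shape[3] broke this view

With those assumed numbers, height = int(128 * 16 / 2) = 1024 lands back in pixel space, which unpack_latents then reduces to the 64x64 patch grid; the old bare latent dimensions produced a view-size mismatch instead.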
From d4491f037cf262e97f5576575562819429d3800f Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Fri, 23 Aug 2024 14:16:51 +0300
Subject: [PATCH 2/5] fix prompt encoding

---
 examples/dreambooth/train_dreambooth_flux.py | 99 +++++++++++---------
 1 file changed, 56 insertions(+), 43 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index dcd9e7dae065..2d18b09a1a79 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -842,7 +842,7 @@ def __getitem__(self, index):
 
         return example
 
-def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
     text_inputs = tokenizer(
         prompt,
         padding="max_length",
@@ -863,20 +863,26 @@ def _encode_prompt_with_t5(
     prompt=None,
     num_images_per_prompt=1,
     device=None,
+    text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
     dtype = text_encoder.dtype
@@ -896,22 +902,28 @@ def _encode_prompt_with_clip(
     tokenizer,
     prompt: str,
     device=None,
+    text_input_ids=None,
     num_images_per_prompt: int = 1,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_overflowing_tokens=False,
-        return_length=False,
-        return_tensors="pt",
-    )
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
 
-    text_input_ids = text_inputs.input_ids
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
     # Use pooled output of CLIPTextModel
@@ -932,6 +944,7 @@ def encode_prompt(
     max_sequence_length,
     device=None,
     num_images_per_prompt: int = 1,
+    text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
@@ -943,6 +956,7 @@ def encode_prompt(
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
     prompt_embeds = _encode_prompt_with_t5(
@@ -952,6 +966,7 @@ def encode_prompt(
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
     text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
@@ -1499,7 +1514,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         )
                     else:
                         tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
-                        tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
+                        tokens_two = tokenize_prompt(
+                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+                        )
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=prompts,
+                        )
+                else:
+                    if args.train_text_encoder:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                            text_encoders=[text_encoder_one, text_encoder_two],
+                            tokenizers=[None, None],
+                            text_input_ids_list=[tokens_one, tokens_two],
+                            max_sequence_length=args.max_sequence_length,
+                            prompt=args.instance_prompt,
+                        )
 
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1553,8 +1586,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     guidance = None
 
                 # Predict the noise residual
-                if not args.train_text_encoder:
-                    model_pred = transformer(
+                model_pred = transformer(
                         hidden_states=packed_noisy_model_input,
                         # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                         timestep=timesteps / 1000,
@@ -1565,25 +1597,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         img_ids=latent_image_ids,
                         return_dict=False,
                     )[0]
-                else:
-                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
-                        text_encoders=[text_encoder_one, text_encoder_two],
-                        tokenizers=None,
-                        prompt=None,
-                        text_input_ids_list=[tokens_one, tokens_two],
-                    )
-                    model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
-
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
                     height=int(model_input.shape[2] * vae_scale_factor / 2),
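Note on the refactor: the encoding helpers now dispatch on whether a tokenizer is supplied. When the text encoders are frozen, embeddings are precomputed once; when --train_text_encoder is set, prompts are tokenized once up front and the cached ids are re-encoded every step (tokenizers=[None, None]) so the embeddings track the changing encoder weights. A minimal sketch of that guard, modeled on the T5 path; encode_once is a hypothetical name, not a function in the script:

    import torch

    def encode_once(text_encoder, tokenizer, prompt, max_sequence_length,
                    text_input_ids=None, device=None):
        # Mirrors the guard added to _encode_prompt_with_t5/_encode_prompt_with_clip:
        # tokenize here only if a tokenizer is supplied, otherwise require cached ids.
        if tokenizer is not None:
            text_input_ids = tokenizer(
                prompt,
                padding="max_length",
                max_length=max_sequence_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids
        elif text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
        return text_encoder(text_input_ids.to(device))[0]

The script applies this pattern twice, with _encode_prompt_with_clip additionally taking the pooled output rather than the last hidden state; keeping the token ids cached means each training step still runs a fresh forward pass through both encoders, so gradients flow to their weights.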
From f9b33ed3f5ca80a01a8e6a4d8a8769eddcc89d9b Mon Sep 17 00:00:00 2001
From: Linoy
Date: Fri, 23 Aug 2024 11:20:35 +0000
Subject: [PATCH 3/5] style

---
 examples/dreambooth/train_dreambooth_flux.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 2d18b09a1a79..003f155368ba 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1587,16 +1587,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # Predict the noise residual
                 model_pred = transformer(
-                        hidden_states=packed_noisy_model_input,
-                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                        timestep=timesteps / 1000,
-                        guidance=guidance,
-                        pooled_projections=pooled_prompt_embeds,
-                        encoder_hidden_states=prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        return_dict=False,
-                    )[0]
+                    hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    timestep=timesteps / 1000,
+                    guidance=guidance,
+                    pooled_projections=pooled_prompt_embeds,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    return_dict=False,
+                )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
                     height=int(model_input.shape[2] * vae_scale_factor / 2),

From badea9da00a3bccad30bf00afdde33d3287a15c1 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Mon, 26 Aug 2024 14:18:20 +0300
Subject: [PATCH 4/5] fix device

---
 examples/dreambooth/train_dreambooth_flux.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 003f155368ba..63112065a742 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -949,12 +949,12 @@ def encode_prompt(
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
     dtype = text_encoders[0].dtype
-
+    device = device if device is not None else text_encoders[1].device
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
-        device=device if device is not None else text_encoders[0].device,
+        device=device,
         num_images_per_prompt=num_images_per_prompt,
         text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
@@ -965,7 +965,7 @@ def encode_prompt(
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
-        device=device if device is not None else text_encoders[1].device,
+        device=device,
         text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 

From 2aa29defd743db76b8d6adc1447dab29aadccd64 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Mon, 26 Aug 2024 14:37:06 +0300
Subject: [PATCH 5/5] add comment

---
 examples/dreambooth/train_dreambooth_flux.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 63112065a742..da571cc46c57 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1597,6 +1597,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     img_ids=latent_image_ids,
                     return_dict=False,
                 )[0]
+                # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
                     height=int(model_input.shape[2] * vae_scale_factor / 2),
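Note on the device fix: with tokenizers=[None, None], encode_prompt previously resolved a separate fallback device per helper (text_encoders[0].device for CLIP, text_encoders[1].device for T5); under accelerate's device placement those need not agree, so patch 4 resolves the device once and reuses it for both calls. A toy, runnable sketch of the hoisted resolution; ToyEncoder is an assumption standing in for the real text encoders, which report their parameters' device in a similar way:

    import torch

    class ToyEncoder(torch.nn.Module):
        # Stand-in for a text encoder; exposes .device like transformers models do.
        def __init__(self, device):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8, device=device)

        @property
        def device(self):
            return next(self.parameters()).device

        def forward(self, input_ids):
            return self.proj(torch.nn.functional.one_hot(input_ids, num_classes=8).float())

    # Both on CPU here; under multi-GPU placement these could be "cuda:0"/"cuda:1".
    text_encoders = [ToyEncoder("cpu"), ToyEncoder("cpu")]

    # After patch 4: one resolution up front, reused for both encoders, so the
    # pre-tokenized ids always land on a consistent device.
    device = text_encoders[1].device
    text_input_ids = torch.randint(0, 8, (1, 4))
    pooled_prompt_embeds = text_encoders[0](text_input_ids.to(device))
    prompt_embeds = text_encoders[1](text_input_ids.to(device))
    print(pooled_prompt_embeds.shape, prompt_embeds.shape)  # torch.Size([1, 4, 8]) twice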