From 25f5e066309762cef0562f9bafa3a2d5a1aa273b Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Thu, 20 Feb 2025 16:18:48 +0200
Subject: [PATCH 1/2] fix t5 training bug

---
 .../train_dreambooth_lora_flux_advanced.py   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 235113d6a348..3392f187b6b6 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -881,7 +881,7 @@ def save_embeddings(self, file_path: str):
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
             embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             )
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
@@ -905,7 +905,7 @@ def device(self):
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
             embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             )
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
@@ -1749,7 +1749,7 @@ def load_model_hook(models, input_dir):
         if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well
             text_lora_parameters_two = []
             for name, param in text_encoder_two.named_parameters():
-                if "token_embedding" in name:
+                if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
                     param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True

From 8c1751ef654f19b331ce25dc34dbea65b41b51d2 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Wed, 26 Feb 2025 13:26:14 +0000
Subject: [PATCH 2/2] Apply style fixes

---
 .../train_dreambooth_lora_flux_advanced.py   | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 3392f187b6b6..7a546c33bebb 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -880,9 +880,7 @@ def save_embeddings(self, file_path: str):
         idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
         for idx, text_encoder in enumerate(self.text_encoders):
             train_ids = self.train_ids if idx == 0 else self.train_ids_t5
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
             new_token_embeddings = embeds.weight.data[train_ids]
@@ -904,9 +902,7 @@ def device(self):
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
-            embeds = (
-                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
-            )
+            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
             index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
             embeds.weight.data[index_no_updates] = (
                 self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
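
Note on why the renames work (context, not part of the patches): in transformers, the CLIP text encoder nests its token-embedding table under text_model.embeddings.token_embedding, while T5EncoderModel registers the table as the top-level `shared` module, so its parameter name contains "shared" rather than "token_embedding". A minimal sketch follows, assuming the CLIP-plus-T5 pairing implied by the script's idx_to_text_encoder_name mapping; tiny random configs stand in for the real Flux encoders:

    # Minimal sketch (not from the patches): where each text encoder keeps its
    # token-embedding table. Tiny random configs stand in for the real encoders.
    from transformers import CLIPTextConfig, CLIPTextModel, T5Config, T5EncoderModel

    clip = CLIPTextModel(
        CLIPTextConfig(vocab_size=64, hidden_size=32, intermediate_size=64, num_hidden_layers=1, num_attention_heads=2)
    )
    t5 = T5EncoderModel(T5Config(vocab_size=64, d_model=32, d_kv=8, d_ff=64, num_layers=1, num_heads=2))

    clip_embeds = clip.text_model.embeddings.token_embedding  # CLIP: nested attribute
    t5_embeds = t5.shared                                      # T5: top-level module

    # The patch's filter matches "shared" because that is how the T5 table shows up
    # in named_parameters() (typically just "shared.weight"); CLIP's appears as
    # "text_model.embeddings.token_embedding.weight", hence the original check only
    # ever matched the CLIP encoder.
    print([n for n, _ in t5.named_parameters() if "shared" in n])
    print([n for n, _ in clip.named_parameters() if "token_embedding" in n])

Both classes also expose get_input_embeddings(), which returns these same modules, so either spelling reaches the table that the pivotal-tuning code trains, saves, and later retracts.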