From d4e99393b32e91136bd561f636c3675aa258823a Mon Sep 17 00:00:00 2001
From: Cursx <33718736+Cursx@users.noreply.github.com>
Date: Tue, 31 Mar 2026 10:13:26 +0800
Subject: [PATCH] Fix save_pretrained() to set tie_word_embeddings=False when
 weights are independently modified outside of Transformers (e.g., via PEFT)

---
 src/transformers/modeling_utils.py | 37 ++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f50774ef8065..0abc2b50d645 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3280,6 +3280,43 @@ def save_pretrained(
         if self._auto_class is not None:
             custom_object_save(self, save_directory, config=self.config)
 
+        # Detect whether the embeddings have been untied at runtime (e.g. after a PEFT merge or a
+        # vocabulary resize with independent training). If the config still claims tying but the
+        # weights actually differ, set tie_word_embeddings=False to prevent corruption on reload.
+        if getattr(model_to_save.config, "tie_word_embeddings", False):
+            try:
+                input_embeddings = model_to_save.get_input_embeddings()
+                output_embeddings = model_to_save.get_output_embeddings()
+                if (
+                    input_embeddings is not None
+                    and output_embeddings is not None
+                    and hasattr(input_embeddings, "weight")
+                    and hasattr(output_embeddings, "weight")
+                ):
+                    in_weight = input_embeddings.weight
+                    out_weight = output_embeddings.weight
+                    # Distinct storage does not imply untied values: the weights may have been cloned
+                    # and still be identical. Only if the values differ (e.g. after a PEFT merge) are they untied.
+                    if in_weight.data_ptr() != out_weight.data_ptr():
+                        if in_weight.device != out_weight.device:
+                            is_tied = torch.equal(in_weight.to(out_weight.device), out_weight)
+                        else:
+                            is_tied = torch.equal(in_weight, out_weight)
+
+                        if not is_tied:
+                            logger.warning(
+                                "The model config specifies `tie_word_embeddings=True` but the input and output embeddings"
+                                " do not share the same weights (they may have been untied after PEFT adapter merging or"
+                                " vocabulary resizing). Setting `tie_word_embeddings=False` in the saved config to prevent"
+                                " weight corruption on reload."
+                            )
+                            model_to_save.config.tie_word_embeddings = False
+            except NotImplementedError:
+                pass
+            except Exception as e:
+                # Catch any device/meta tensor related errors gracefully
+                logger.debug("Could not check tied embeddings during save: %s", e)
+
         # Save the config
         if is_main_process:
             if not _hf_peft_config_loaded:
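
A minimal reproduction sketch of the scenario this patch guards against (not part of
the diff; the checkpoint name and the weight perturbation are illustrative assumptions,
and only public Transformers APIs are used):

    import torch
    from transformers import AutoModelForCausalLM

    # Any checkpoint that ties input/output embeddings works; "gpt2" is one example.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    assert model.config.tie_word_embeddings

    # Simulate an lm_head that was modified independently of the input embeddings,
    # as merging a PEFT adapter into lm_head (but not into the embedding layer) can do.
    lm_head = model.get_output_embeddings()
    lm_head.weight = torch.nn.Parameter(lm_head.weight.detach().clone() + 0.01)

    # Before this patch: the saved config keeps tie_word_embeddings=True, so
    # from_pretrained() re-ties the weights and silently discards the modified lm_head.
    # With this patch: the value mismatch is detected, tie_word_embeddings is set to
    # False in the saved config, and both tensors round-trip intact.
    model.save_pretrained("./untied-checkpoint")
    reloaded = AutoModelForCausalLM.from_pretrained("./untied-checkpoint")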