diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f50774ef8065..b54cc88e3926 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3280,6 +3280,54 @@ def save_pretrained(
         if self._auto_class is not None:
             custom_object_save(self, save_directory, config=self.config)
 
+        # Detect if tie_word_embeddings=True but the actual weights have diverged (e.g. after PEFT merge_and_unload)
+        # This mirrors the load-side check at tie_weights() which uses torch.equal
+        # Only auto-fix if the output embeddings are actually declared as tied to the input embeddings
+        # in _tied_weights_keys (some models like Pop2Piano have lm_head but don't tie it)
+        try:
+            if getattr(model_to_save.config, "tie_word_embeddings", False):
+                output_embeddings = model_to_save.get_output_embeddings()
+                input_embeddings = model_to_save.get_input_embeddings()
+                if output_embeddings is not None and input_embeddings is not None:
+                    out_w = getattr(output_embeddings, "weight", None)
+                    in_w = getattr(input_embeddings, "weight", None)
+                    if out_w is not None and in_w is not None:
+                        # Verify that the model actually declares these weights as tied
+                        # via _tied_weights_keys before attempting auto-fix
+                        tied_keys = getattr(model_to_save, "_tied_weights_keys", None) or {}
+                        out_names = {n for n, p in model_to_save.named_parameters() if p is out_w}
+                        in_names = {n for n, p in model_to_save.named_parameters() if p is in_w}
+                        embeddings_declared_tied = any(
+                            (k in out_names and v in in_names) or (k in in_names and v in out_names)
+                            for k, v in tied_keys.items()
+                        )
+                        if not embeddings_declared_tied:
+                            pass  # lm_head exists but is not tied to input embeddings - skip
+                        elif out_w is not in_w:
+                            # They are separate Python objects - check if values diverged
+                            weights_differ = False
+                            if out_w.shape != in_w.shape:
+                                weights_differ = True
+                            elif out_w.device != in_w.device:
+                                # Cross-device: skip comparison to avoid false positives
+                                # in model-parallel / offloading scenarios
+                                pass
+                            elif out_w.device.type != "meta" and not torch.equal(out_w, in_w):
+                                weights_differ = True
+
+                            if weights_differ:
+                                model_to_save.config.tie_word_embeddings = False
+                                logger.warning(
+                                    "Detected that the model config has `tie_word_embeddings=True` but the input "
+                                    "and output embeddings have different values (e.g. after PEFT merging or "
+                                    "vocabulary resizing). Setting `tie_word_embeddings=False` in the saved config "
+                                    "to prevent weight corruption on reload."
+                                )
+        except NotImplementedError:
+            pass
+        except Exception as e:
+            logger.debug(f"Could not check tied embeddings consistency during save: {e}")
+
         # Save the config
         if is_main_process:
             if not _hf_peft_config_loaded:
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 7366845c4d78..8e42ec3f4890 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1602,6 +1602,41 @@ def test_tied_weights_are_always_tied_from_config(self):
         model = LlamaForCausalLM._from_config(copy.deepcopy(config))
         self.assertTrue(model.lm_head.weight is not model.model.embed_tokens.weight)
 
+    def test_save_pretrained_auto_fixes_diverged_tied_embeddings(self):
+        """Test that save_pretrained sets tie_word_embeddings=False and emits a warning when weights have diverged."""
+        config = LlamaConfig(num_hidden_layers=2, hidden_size=32, intermediate_size=16, tie_word_embeddings=True)
+        model = LlamaForCausalLM(config)
+        # Sanity: weights should be tied
+        self.assertIs(model.lm_head.weight, model.model.embed_tokens.weight)
+
+        # Simulate PEFT merge_and_unload: manually untie and assign different values
+        with torch.no_grad():
+            model.lm_head.weight = nn.Parameter(model.lm_head.weight.clone())
+            model.lm_head.weight.fill_(0.42)
+            model.model.embed_tokens.weight.fill_(0.24)
+        # Sanity: weights are now separate objects with different values
+        self.assertIsNot(model.lm_head.weight, model.model.embed_tokens.weight)
+        self.assertFalse(torch.equal(model.lm_head.weight, model.model.embed_tokens.weight))
+
+        logger = logging.get_logger("transformers.modeling_utils")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with CaptureLogger(logger) as cl:
+                model.save_pretrained(tmp_dir)
+
+            # 1. The warning should have been emitted
+            self.assertIn("Setting `tie_word_embeddings=False`", cl.out)
+
+            # 2. The saved config should have tie_word_embeddings=False
+            with open(os.path.join(tmp_dir, "config.json")) as f:
+                saved_config = json.load(f)
+            self.assertFalse(saved_config["tie_word_embeddings"])
+
+            # 3. Reloading the model should preserve separate weights
+            reloaded = LlamaForCausalLM.from_pretrained(tmp_dir)
+            self.assertIsNot(reloaded.lm_head.weight, reloaded.model.embed_tokens.weight)
+            self.assertTrue(torch.allclose(reloaded.lm_head.weight, torch.tensor(0.42), atol=1e-6))
+            self.assertTrue(torch.allclose(reloaded.model.embed_tokens.weight, torch.tensor(0.24), atol=1e-6))
+
     def test_unexpected_keys_warnings(self):
         model = ModelWithHead(PreTrainedConfig(tie_word_embeddings=True))
         logger = logging.get_logger("transformers.modeling_utils")
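A minimal end-to-end sketch of the scenario this patch targets, assuming the peft library and a LoRA adapter trained on a base model with tied embeddings. The adapter path and output directory below are placeholders, and whether lm_head actually diverges from embed_tokens depends on the adapter's target modules; this is an illustration, not part of the patch.

    # Illustrative only: paths are hypothetical; divergence depends on the adapter.
    from peft import AutoPeftModelForCausalLM

    peft_model = AutoPeftModelForCausalLM.from_pretrained("path/to/lora-adapter")
    merged = peft_model.merge_and_unload()  # lm_head and embed_tokens may now hold different values

    # Without this patch, config.json keeps tie_word_embeddings=True, so a later
    # from_pretrained() re-ties lm_head to embed_tokens and silently discards the
    # merged lm_head values. With the patch, save_pretrained() detects the
    # divergence, emits the warning above, and writes tie_word_embeddings=False.
    merged.save_pretrained("merged-model")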