48 changes: 48 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3280,6 +3280,54 @@ def save_pretrained(
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self.config)

        # Detect if tie_word_embeddings=True but the actual weights have diverged (e.g. after PEFT merge_and_unload)
        # This mirrors the load-side check at tie_weights() which uses torch.equal
        # Only auto-fix if the output embeddings are actually declared as tied to the input embeddings
        # in _tied_weights_keys (some models like Pop2Piano have lm_head but don't tie it)
        try:
            if getattr(model_to_save.config, "tie_word_embeddings", False):
                output_embeddings = model_to_save.get_output_embeddings()
                input_embeddings = model_to_save.get_input_embeddings()
                if output_embeddings is not None and input_embeddings is not None:
                    out_w = getattr(output_embeddings, "weight", None)
                    in_w = getattr(input_embeddings, "weight", None)
                    if out_w is not None and in_w is not None:
                        # Verify that the model actually declares these weights as tied
                        # via _tied_weights_keys before attempting auto-fix
                        tied_keys = getattr(model_to_save, "_tied_weights_keys", None) or {}
                        out_names = {n for n, p in model_to_save.named_parameters() if p is out_w}
                        in_names = {n for n, p in model_to_save.named_parameters() if p is in_w}
                        embeddings_declared_tied = any(
                            (k in out_names and v in in_names) or (k in in_names and v in out_names)
                            for k, v in tied_keys.items()
                        )
                        if not embeddings_declared_tied:
                            pass  # lm_head exists but is not tied to input embeddings — skip
                        elif out_w is not in_w:
                            # They are separate Python objects — check if values diverged
                            weights_differ = False
                            if out_w.shape != in_w.shape:
                                weights_differ = True
                            elif out_w.device != in_w.device:
                                # Cross-device: skip comparison to avoid false positives
                                # in model-parallel / offloading scenarios
                                pass
                            elif out_w.device.type != "meta" and not torch.equal(out_w, in_w):
                                weights_differ = True

                            if weights_differ:
                                model_to_save.config.tie_word_embeddings = False
                                logger.warning(
                                    "Detected that the model config has `tie_word_embeddings=True` but the input "
                                    "and output embeddings have different values (e.g. after PEFT merging or "
                                    "vocabulary resizing). Setting `tie_word_embeddings=False` in the saved config "
                                    "to prevent weight corruption on reload."
                                )
        except NotImplementedError:
            pass
        except Exception as e:
            logger.debug(f"Could not check tied embeddings consistency during save: {e}")

        # Save the config
        if is_main_process:
            if not _hf_peft_config_loaded:
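For context on the failure mode this hunk guards against: a checkpoint whose config still says `tie_word_embeddings=True` but whose input and output embedding weights differ gets silently re-tied on reload. A minimal sketch of that scenario (the checkpoint path and fill value are illustrative, and the divergence is simulated by hand rather than through an actual PEFT merge):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=2, hidden_size=32, intermediate_size=16, tie_word_embeddings=True)
model = LlamaForCausalLM(config)

# Untie and diverge the weights, as a merged PEFT adapter or a vocabulary resize might do
model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())
with torch.no_grad():
    model.lm_head.weight.fill_(0.42)

model.save_pretrained("diverged-ckpt")  # illustrative path
# Without the check above, config.json keeps tie_word_embeddings=True, so from_pretrained()
# re-ties lm_head.weight to embed_tokens.weight and the diverged lm_head values are lost.
reloaded = LlamaForCausalLM.from_pretrained("diverged-ckpt")
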
35 changes: 35 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -1602,6 +1602,41 @@ def test_tied_weights_are_always_tied_from_config(self):
        model = LlamaForCausalLM._from_config(copy.deepcopy(config))
        self.assertTrue(model.lm_head.weight is not model.model.embed_tokens.weight)

    def test_save_pretrained_auto_fixes_diverged_tied_embeddings(self):
        """Test that save_pretrained sets tie_word_embeddings=False and emits a warning when weights have diverged."""
        config = LlamaConfig(num_hidden_layers=2, hidden_size=32, intermediate_size=16, tie_word_embeddings=True)
        model = LlamaForCausalLM(config)
        # Sanity: weights should be tied
        self.assertIs(model.lm_head.weight, model.model.embed_tokens.weight)

        # Simulate PEFT merge_and_unload: manually untie and assign different values
        with torch.no_grad():
            model.lm_head.weight = nn.Parameter(model.lm_head.weight.clone())
            model.lm_head.weight.fill_(0.42)
            model.model.embed_tokens.weight.fill_(0.24)
        # Sanity: weights are now separate objects with different values
        self.assertIsNot(model.lm_head.weight, model.model.embed_tokens.weight)
        self.assertFalse(torch.equal(model.lm_head.weight, model.model.embed_tokens.weight))

        logger = logging.get_logger("transformers.modeling_utils")
        with tempfile.TemporaryDirectory() as tmp_dir:
            with CaptureLogger(logger) as cl:
                model.save_pretrained(tmp_dir)

            # 1. The warning should have been emitted
            self.assertIn("Setting `tie_word_embeddings=False`", cl.out)

            # 2. The saved config should have tie_word_embeddings=False
            with open(os.path.join(tmp_dir, "config.json")) as f:
                saved_config = json.load(f)
            self.assertFalse(saved_config["tie_word_embeddings"])

            # 3. Reloading the model should preserve separate weights
            reloaded = LlamaForCausalLM.from_pretrained(tmp_dir)
            self.assertIsNot(reloaded.lm_head.weight, reloaded.model.embed_tokens.weight)
            self.assertTrue(torch.allclose(reloaded.lm_head.weight, torch.tensor(0.42), atol=1e-6))
            self.assertTrue(torch.allclose(reloaded.model.embed_tokens.weight, torch.tensor(0.24), atol=1e-6))

    def test_unexpected_keys_warnings(self):
        model = ModelWithHead(PreTrainedConfig(tie_word_embeddings=True))
        logger = logging.get_logger("transformers.modeling_utils")
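The new test can be run on its own with pytest's keyword filter, for example:

python -m pytest tests/utils/test_modeling_utils.py -k test_save_pretrained_auto_fixes_diverged_tied_embeddings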