Fix Mamba2ForCausalLM weight tying #43207
Conversation
Add _tied_weights_keys mapping to enable proper weight tying when tie_word_embeddings=True. This is the standard pattern used by MambaForCausalLM, GPT2, LLaMA, and other models. Fixes huggingface#43206
vasqu left a comment
Can you add a fast test as a regression test?
Thanks! Enabled
Can we instead create a small test for this? It seems that the original model did not have tied weights, so this is more of an addition IMO, and we should check this as an explicit test (would be nice to link / mention your issue as well)
Replace ModelTester default with explicit test per reviewer feedback.
[For maintainers] Suggested jobs to run (before merge): run-slow: mamba2
@vasqu Added an explicit regression test that checks both `tie_word_embeddings=True` and `tie_word_embeddings=False` 🙏

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
vasqu left a comment
Thank you, just checking with run-slow once to be paranoid-safe, then merging :D
    check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

    def test_tied_weight_embeddings(self):
        """Regression test for https://github.com/huggingface/transformers/issues/43206."""
There was a problem hiding this comment.
Awesome, thanks for linking 🙏
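For context, the diff snippet above shows only the test's signature. A plausible shape for the body (a sketch assuming the surrounding Mamba2 test class, its `model_tester.get_config()` helper, and a top-level `Mamba2ForCausalLM` import — not the merged code) could be:

```python
def test_tied_weight_embeddings(self):
    """Regression test for https://github.com/huggingface/transformers/issues/43206."""
    config = self.model_tester.get_config()

    # Tied: lm_head must be the very same Parameter object as the input embeddings.
    config.tie_word_embeddings = True
    model = Mamba2ForCausalLM(config)
    self.assertIs(model.get_output_embeddings().weight, model.get_input_embeddings().weight)

    # Untied: the two weights must stay independent parameters.
    config.tie_word_embeddings = False
    model = Mamba2ForCausalLM(config)
    self.assertIsNot(model.get_output_embeddings().weight, model.get_input_embeddings().weight)
```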
run-slow: mamba2

This comment contains models: ["models/mamba2"]
CI Results: ✅ No failing test specific to this PR 🎉!
* Fix Mamba2ForCausalLM weight tying

  Add _tied_weights_keys mapping to enable proper weight tying when tie_word_embeddings=True. This is the standard pattern used by MambaForCausalLM, GPT2, LLaMA, and other models. Fixes huggingface#43206

* Enable weight tying in Mamba2ModelTester for regression testing

* Add explicit regression test for Mamba2 weight tying

  Replace ModelTester default with explicit test per reviewer feedback.
What does this PR do?
Fixes #43206
Adds the `_tied_weights_keys` mapping to `Mamba2ForCausalLM` to enable proper weight tying when `tie_word_embeddings=True`.

The Bug
When `tie_word_embeddings=True`, the embedding weights should be shared with the `lm_head`. However, `Mamba2ForCausalLM` was missing the `_tied_weights_keys` mapping, so the weights were never actually tied. This caused:
- `resize_token_embeddings()` not resizing `lm_head` properly

The Fix
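The change is essentially a single class attribute. Below is an excerpt-style sketch of the pattern, not a standalone script; the dict form and the exact key names (`lm_head.weight`, `backbone.embeddings.weight`) are assumptions based on Mamba2's module layout and the current `_tied_weights_keys` convention:

```python
class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):  # excerpt, not standalone
    # Maps the tied parameter to its source so PreTrainedModel.tie_weights()
    # and resize_token_embeddings() treat lm_head.weight as an alias of the
    # backbone embedding weight. Exact key names are an assumption here.
    _tied_weights_keys = {"lm_head.weight": "backbone.embeddings.weight"}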
This is the standard pattern used by `MambaForCausalLM` (v1), GPT2, LLaMA, and other models.

Verification
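The PR's verification output isn't preserved here. A quick manual check along these lines (a sketch: the config sizes are arbitrary small values chosen for illustration, and the `data_ptr()` comparison is just one way to confirm storage sharing) would exercise both the tying and the resize path:

```python
from transformers import Mamba2Config, Mamba2ForCausalLM

# Tiny illustrative config; num_heads * head_dim must equal expand * hidden_size.
config = Mamba2Config(
    vocab_size=128, hidden_size=64, num_hidden_layers=2,
    expand=2, num_heads=4, head_dim=32, tie_word_embeddings=True,
)
model = Mamba2ForCausalLM(config)

# With the fix, lm_head shares storage with the input embeddings.
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()

# Resizing the vocabulary must keep the tie and resize lm_head too.
model.resize_token_embeddings(160)
assert model.lm_head.weight.shape[0] == 160
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()
```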