Support casting to fp32 when word embeddings are tied to lm_head #4446
Conversation
Profiling when `lm_head` and the word embeddings share the same object: the additional memory for the FP32 copy is 1.401 GB - 1.110 GB = 296 MB (approx). 2 is the factor of additional bytes (FP32/FP16) in the forward pass 🤔
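A rough sanity check of that factor-of-2 reasoning (the dimensions below are illustrative placeholders, not the profiled model's actual sizes):

```python
# Back-of-the-envelope estimate of the extra memory from keeping an FP32
# copy of lm_head alongside FP16 weights. Dimensions are hypothetical.
vocab_size = 152_064   # illustrative vocab size
hidden_size = 1_024    # illustrative hidden size

fp16_bytes = vocab_size * hidden_size * 2  # 2 bytes per FP16 element
fp32_bytes = vocab_size * hidden_size * 4  # 4 bytes per FP32 element

# The FP32 copy costs FP32/FP16 = 2x the FP16 weight, so the *extra*
# memory equals one full FP16 weight matrix.
extra_bytes = fp32_bytes - fp16_bytes
print(f"extra memory: {extra_bytes / 2**20:.1f} MiB")
```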
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
```python
@pytest.mark.parametrize(
    "model_name", ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"]
)
def test_training_with_cast_lm_head_to_fp32(self):
```
Qwen3 has tied word embeddings and Gemma 2 doesn't, correct? If so, I'd just add a small comment so that we remember why we test these two cases.
It's the other way around: Qwen3 has untied and Gemma 2 has tied.
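A minimal PyTorch sketch of what tied vs. untied means here, using plain `nn` modules rather than the actual test models:

```python
import torch
import torch.nn as nn

vocab, hidden = 64, 16

embed = nn.Embedding(vocab, hidden).half()
lm_head = nn.Linear(hidden, vocab, bias=False).half()

# Untied (the Qwen3 case): independent weight tensors.
assert lm_head.weight is not embed.weight

# Tied (the Gemma 2 case): lm_head reuses the embedding matrix.
lm_head.weight = embed.weight
assert lm_head.weight is embed.weight

# Casting lm_head now also casts the embeddings, since both modules
# hold the very same Parameter -- the situation this PR has to handle.
lm_head.float()
assert embed.weight.dtype == torch.float32
```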
```python
return (inputs[0].to(torch.float32),) + inputs[1:]
```
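That line reads like the body of a forward pre-hook. A self-contained sketch of how such a hook upcasts the hidden states entering `lm_head` (the hook name and wiring here are assumptions, not the PR's exact code):

```python
import torch
import torch.nn as nn

lm_head = nn.Linear(8, 32, bias=False)  # FP32 head

def cast_inputs_to_fp32(module, inputs):
    # Upcast the first positional input (the hidden states) to FP32 and
    # pass any remaining inputs through unchanged.
    return (inputs[0].to(torch.float32),) + inputs[1:]

lm_head.register_forward_pre_hook(cast_inputs_to_fp32)

hidden = torch.randn(2, 8, dtype=torch.float16)
logits = lm_head(hidden)  # hook upcasts before the matmul
assert logits.dtype == torch.float32
```

Without the hook, feeding FP16 hidden states into an FP32 linear layer would raise a dtype mismatch; the hook makes the upcast transparent to the caller.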
```python
original_dtype_local = target_model.lm_head.weight.dtype
target_model.lm_head = target_model.lm_head.float()
```
For the record, `float()` is in-place, so in theory you could just have `target_model.lm_head.float()`.

It happens that `.float()` returns `self`, so `target_model.lm_head = target_model.lm_head.float()` works as well. I personally prefer the current way; the assignment makes it more explicit.
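The claim is easy to verify: `nn.Module.float()` converts floating-point parameters in place and returns `self`, so both spellings touch the same module (a minimal check, not TRL code):

```python
import torch
import torch.nn as nn

lm_head = nn.Linear(4, 8).half()

returned = lm_head.float()
assert returned is lm_head                    # .float() returns self
assert lm_head.weight.dtype == torch.float32  # parameters cast in place
```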
qgallouedec
left a comment
Nice! Just a small comment on the test.
commit 4677cf2, Harras Mansoor, Wed Nov 5 04:06:13 2025 +0500: Removed Sentiment Tuning Examples (#4424)
commit 7a9592b, Quentin Gallouédec, Tue Nov 4 14:32:04 2025 -0700: 🐍 Drop Python 3.9 (#4183)
commit 7f15a7f, Harras Mansoor, Wed Nov 5 02:06:31 2025 +0500: Removed outdated warning about batch contamination (#4423)
commit 8b0a3ce, Albert Villanova del Moral, Tue Nov 4 21:37:39 2025 +0100: Update tokenizer apply_chat_template with return_dict=True default (#4448)
commit d9f9e2b, Pramodith Ballapuram, Tue Nov 4 19:56:58 2025 +0000: Support casting to fp32 when word embeddings are tied to lm_head (#4446)
commit 4e138ab, Sergio Paniego Blanco, Tue Nov 4 15:15:23 2025 +0100: Upload notebook with T4 selected (#4449)
commit 43253b2, Pramodith Ballapuram, Mon Nov 3 21:07:31 2025 +0000: Add On-Policy Distillation from thinking labs to paper index (#4410), co-authored by Quentin Gallouédec
commit 6f41b18, Behrooz Azarkhalili, Mon Nov 3 10:57:51 2025 -0800: fix: Remove chat template setting from non-SFT trainer scripts (#4437), co-authored by Quentin Gallouédec
What does this PR do?
Per slack discussion https://huggingface.slack.com/archives/C089Q56GPMM/p1761949724954879.
TinyGemma has `tie_word_embeddings=True`; Qwen3 has `tie_word_embeddings=False`.

Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.