⛰️ Reduce peak vram consumption with efficient selective log_softmax#2799
Merged
qgallouedec merged 12 commits intoFeb 7, 2025
Merged
Conversation
…log-softmax approach
Contributor
Author
|
See benchmarks here: #2773 (comment) (thanks @qgallouedec ) Notably, the most efficient approach in these benchmarks is not stable with bfloat16, and so we fall back to the approach that loops over log_softmax for bfloat16 and float16. |
Contributor
Author
tyler-romero
commented
Feb 7, 2025
|
|
||
| import requests | ||
| import torch | ||
| import wandb |
Contributor
Author
There was a problem hiding this comment.
changed by running precommit
tyler-romero
commented
Feb 7, 2025
qgallouedec
reviewed
Feb 7, 2025
Member
|
That's a super cool improvement! Thanks! |
qgallouedec
reviewed
Feb 7, 2025
qgallouedec
reviewed
Feb 7, 2025
| logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) | ||
| per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) | ||
| else: | ||
| # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach |
Contributor
Author
|
Ready for re-review! |
qgallouedec
approved these changes
Feb 7, 2025
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Member
|
Thanks again! |
yxliu-TAMU
pushed a commit
to mincheolseong/ECEN743-GRPO-Project-Proposal
that referenced
this pull request
Apr 20, 2025
…uggingface#2799) * Reduce mem consumption across many trainers with efficient selective log-softmax approach * rename * typo fix * precommit * Update tests/test_core.py * relocate * precommit * style * smaller values for test, and run on cpu * nit doc improvements * style * fix test --------- Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Many TRL Trainers use the same log_softmax -> gather operation to compute a selected set of logprobs. This approach is inefficient b/c it allocates a
bs*seqlen*vocab_sizetensor to hold the logprobs. For modest bs/seqlen/vocab_size this tensor can require >2GB vram. There are a variety of more memory efficient (and faster) approaches.This PR creates a utility function to hold a more efficient implementation of this operation and uses that utility function broadly across TRL.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.