Delay float32 upcast in ForCausalLMLoss after filtering ignore_index#40065
starcatmeow wants to merge 1 commit into huggingface:main
Conversation
Delay float32 upcast in ForCausalLMLoss after filtering ignore_index (huggingface#38452). This avoids upcasting logits corresponding to ignore_index positions, reducing unnecessary memory usage during loss computation. Particularly useful when fine-tuning causal LMs with prompt tokens set to ignore_index (e.g., -100).
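The idea can be sketched roughly as follows (a minimal, hypothetical illustration, not the actual `ForCausalLMLoss` implementation in transformers):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss_delayed_upcast(logits, labels, ignore_index=-100):
    """Cross-entropy for causal LMs, upcasting only the kept positions."""
    # Shift so tokens < n predict token n (standard causal LM setup).
    logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    labels = labels[:, 1:].reshape(-1)
    # Filter out ignored positions *before* the float32 upcast, so
    # memory is only spent on tokens that contribute to the loss.
    keep = labels != ignore_index
    logits = logits[keep].float()  # upcast happens after masking
    labels = labels[keep]
    return F.cross_entropy(logits, labels)
```

Since `F.cross_entropy` with `ignore_index` already excludes those positions from the mean, filtering before the upcast gives the same loss value while never materializing float32 logits for ignored tokens.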
Force-pushed from 2235288 to 2b9e9cf
Hi @starcatmeow, I just took a look and I'm not sure we can accept this! Although I thought it was a good optimization in the issue at #38452, this is because I misread - I thought the code was already selecting masked labels and was just doing it after casting to float32. However, adding a mask-select step when we didn't use one before introduces some problems - in particular, it makes the sizes of the output tensors data-dependent, which can force recompilations with less efficient dynamic shapes. I'm not sure it's worth it for the memory saving here! cc core maintainers @ArthurZucker @Cyrilvallez for their opinion
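The data-dependent-shape concern can be illustrated with a small sketch (hypothetical tensors, assuming the usual -100 ignore_index):

```python
import torch

ignore_index = -100
# Two hypothetical label tensors with different numbers of ignored positions.
labels_a = torch.tensor([1, -100, 2, -100])
labels_b = torch.tensor([1, 2, 3, -100])

vocab = 8
logits = torch.randn(4, vocab)
# Mask-selecting yields tensors whose leading dimension depends on the
# data, not just on the batch and sequence lengths, which is what can
# force a compiler into dynamic shapes or recompilation.
print(logits[labels_a != ignore_index].shape)  # torch.Size([2, 8])
print(logits[labels_b != ignore_index].shape)  # torch.Size([3, 8])
```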
Interesting idea, though I'm not sure how often we expect to see this in practice. Mathematically, it's fully equivalent.
Thanks @starcatmeow for picking this change up. @Cyrilvallez, my main motivation behind proposing the change was to account for fine-tuning cases where one doesn't want to fine-tune on prompt tokens. In addition, won't this also help with getting rid of padding tokens before the upcast, which is a common scenario?
Ported from huggingface#40065 (head 2b9e9cf).
What does this PR do?
This PR implements the optimization discussed in #38452, originally proposed by @harshit2997.
Thanks for the original suggestion and discussion.
Moves the float32 upcast in `ForCausalLMLoss` to after filtering out `ignore_index` labels. Fixes #38452
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Memory saving by upcasting logits for only non-ignored positions #38452
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.