
Delay float32 upcast in ForCausalLMLoss after filtering ignore_index #40065

Open

starcatmeow wants to merge 1 commit into huggingface:main from starcatmeow:for-casual-lm-loss-optim

Conversation

@starcatmeow
Contributor

What does this PR do?

This PR implements the optimization discussed in #38452, originally proposed by @harshit2997.
Thanks for the original suggestion and discussion.

  • Moves the float32 upcast in ForCausalLMLoss to after filtering out ignore_index labels.
  • Ensures only relevant logits are upcast, reducing VRAM usage without affecting correctness (see the sketch below).
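
A minimal sketch of the filter-then-upcast ordering (a hypothetical standalone helper for illustration, not the exact code in transformers or in this PR's diff):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits, labels, vocab_size, ignore_index=-100):
    # Shift so that tokens < n predict token n.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Filter out ignore_index positions BEFORE the float32 upcast, so only
    # rows that actually contribute to the loss are materialized in fp32.
    keep = shift_labels != ignore_index
    kept_logits = shift_logits[keep].float()
    kept_labels = shift_labels[keep]

    return F.cross_entropy(kept_logits, kept_labels)
```

The baseline instead calls `logits.float()` up front, allocating a float32 copy of the full `(batch, seq, vocab)` tensor even for positions whose label is ignore_index.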

Fixes #38452

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

…uggingface#38452)

This avoids upcasting logits corresponding to ignore_index positions,
reducing unnecessary memory usage during loss computation.
Particularly useful when fine-tuning causal LMs with prompt tokens
set to ignore_index (e.g., -100).
@Rocketknight1 force-pushed the for-casual-lm-loss-optim branch from 2235288 to 2b9e9cf on August 13, 2025 at 13:47
@Rocketknight1
Member

Hi @starcatmeow, I just took a look and I'm not sure we can accept this! Although I thought it was a good optimization in the issue at #38452, that was because I misread the code - I thought it was already selecting masked labels, just doing so after casting to float32 instead of before. In that case, doing the select first would save memory with no issues.

However, adding a mask-select step where we didn't use one before introduces some problems - in particular, it makes the sizes of the output tensors data-dependent, which can force recompilations or a fallback to less efficient dynamic shapes. I'm not sure it's worth it for the memory saving here! cc core maintainers @ArthurZucker @Cyrilvallez for their opinion
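
To make the shape concern concrete, a toy example (an illustration, not code from this PR): the number of rows surviving a boolean-mask select depends on the values in the labels, not just their shapes, which tracing compilers such as torch.compile see as a data-dependent size.

```python
import torch

x = torch.randn(4, 8)
labels = torch.tensor([-100, 3, -100, 7])

# How many rows survive depends on the *values* in labels, not its shape,
# so a tracing compiler sees a data-dependent output size here and may
# recompile or fall back to dynamic shapes.
kept = x[labels != -100]
print(kept.shape)  # torch.Size([2, 8])
```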

@Cyrilvallez
Member

Cyrilvallez commented Aug 14, 2025

Interesting idea. I'm not sure how often we expect to see ignore_index in the labels - cc @ArthurZucker, do you have an idea? It could actually save a LOT of memory, as vocabularies tend to be large now (200k+)

Mathematically, it's fully equivalent
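
A back-of-envelope estimate of what is at stake, using assumed illustrative numbers (batch 4, sequence length 4096, 256k vocabulary - not measurements from this PR):

```python
# Assumed illustrative numbers, not measurements from this PR.
batch, seq, vocab = 4, 4096, 256_000
full_fp32 = batch * seq * vocab * 4   # bytes for a full float32 upcast
print(full_fp32 / 2**30)              # ~15.6 GiB
# If half the positions are ignore_index (e.g. prompt tokens), the
# filtered upcast materializes roughly half of that, ~7.8 GiB.
```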

@harshit2997

Thanks @starcatmeow for picking this change up. @Cyrilvallez, my main motivation for proposing the change was to account for fine-tuning cases where one doesn't want to fine-tune on prompt tokens. In addition, won't this also help with getting rid of padding tokens before the upcast, which is a common scenario?


Development

Successfully merging this pull request may close these issues.

Memory saving by upcasting logits for only non-ignored positions
