Fix logic error in prepare_inputs_for_generation cache slicing condition #41764
Conversation
I had the same question, so I think it's better to ping @gante, who has a better understanding of the situation. IIRC it had something to do with the Reformer model
Hi @albertvillanova @zucchini-nlp 👋 Have a look at this slack thread, I think it contains the full discussion related to this bug (assuming it is the same issue, I think it is) and how it interacts with TRL: https://huggingface.slack.com/archives/C01N44FJDHT/p1760725474636279 TL;DR I think there is a fix we can implement in TRL right now (see thread) to work around it, the transformers-side fix is trickier |
zucchini-nlp
left a comment
Actually, I just looked at the special RecurrentGemma model, which has no `past_key_values`, and it should not hit the edge case if we check the value of `use_cache`. Other special models like Mamba have their own input preparation method.
So I don't think we need to check `past_key_values is None`, unless the edge case mentioned by Joao isn't tested. So let's merge it
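A toy sketch of that point, assuming a standalone predicate (`fixed_gate` is a hypothetical name, not the transformers source): for a RecurrentGemma-style model that never passes `past_key_values`, the corrected condition reduces to the `use_cache` flag.

```python
def fixed_gate(past_key_values, use_cache):
    # Corrected condition from this PR: run cache-dependent preparation
    # when a cache exists, or when caching is enabled for stateful models.
    return past_key_values is not None or use_cache

# A RecurrentGemma-style stateful model has no past_key_values object,
# so only the use_cache flag decides:
print(fixed_gate(None, True))   # True  -> preparation runs
print(fixed_gate(None, False))  # False -> skipped, no crash
```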
I will just rebase main to check no new test failures :)
…ition (huggingface#41764) Fix logic error in cache slicing condition Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Fix logic error in `prepare_inputs_for_generation` cache slicing condition

### Background

I think the PR "`generate` delegates default cache initialization to the model" (#41505) introduced a logic error where the condition for calling `_cache_dependant_input_preparation` uses `is None` instead of `is not None`, causing crashes when `prepare_inputs_for_generation` is called with `past_key_values=None` and `use_cache=False`.

### Bug
PR #41505 introduced this condition:
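A minimal standalone sketch of the buggy gate and the failure it leads to (`buggy_gate` is a hypothetical name; the real check lives inside `prepare_inputs_for_generation`):

```python
def buggy_gate(past_key_values, use_cache):
    # From #41505: `is None` where `is not None` was intended.
    return past_key_values is None or use_cache

# The invalid combination still passes the gate...
assert buggy_gate(None, False) is True

# ...and cache-dependent preparation then indexes a cache_position that
# was never created, reproducing the downstream TypeError:
cache_position = None  # nothing initialized it, since use_cache=False
try:
    cache_position[-1]
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable
```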
The condition `past_key_values is None or use_cache` triggers `_cache_dependant_input_preparation` even when `past_key_values=None` and `use_cache=False`. That combination is invalid for cache-dependent preparation and causes a crash when accessing `cache_position[-1]` (line 456).

Note that normal generation works fine because `use_cache=True`, making the buggy `past_key_values is None` part irrelevant.

### Fix
This PR changes the condition to:
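A standalone comparison of the old and new conditions, enumerating all four combinations (hypothetical helper names, not the transformers source):

```python
def old_gate(past_key_values, use_cache):
    return past_key_values is None or use_cache      # before this PR (buggy)

def new_gate(past_key_values, use_cache):
    return past_key_values is not None or use_cache  # this PR

cache = object()  # stands in for a real past_key_values cache
for pkv in (cache, None):
    for use_cache in (True, False):
        print(pkv is not None, use_cache,
              old_gate(pkv, use_cache), new_gate(pkv, use_cache))

# The gates differ exactly when use_cache=False: with a cache present the
# new gate still runs cache-dependent preparation (True), and with no cache
# it correctly skips it (False) instead of crashing.
```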
The condition `past_key_values is not None or use_cache` means cache-dependent preparation runs only when a cache already exists or caching is enabled. This is semantically correct and matches the intent described in the PR #41505 comment: #41505 (comment)

The `use_cache` part handles stateful models, while `past_key_values is not None` handles normal cached models.

### Testing
This PR fixes the downstream failing test in TRL; see the associated issue trl#4272.
### Related

This PR addresses a logic error introduced by:
- "`generate` delegates default cache initialization to the model" #41505

This PR will fix: "CI fails with dev dependencies: TypeError: 'NoneType' object is not subscriptable" trl#4272
CC:
"`generate` delegates default cache initialization to the model" #41505; see 🚨 [v5] #41505 (comment)