Generate: handle cache_position update in generate #29467
gante merged 6 commits into huggingface:main
f5c91b9 force-pushed to 572ca8e
ArthurZucker left a comment:
Alright, I think Llama is already testing this. Moving fast here
alright, we are deprecating this anyways
my single worry here is a potential stride issue; adding a .contiguous() might be needed
I've double-checked, it's always (1,) 🤗 (which makes sense, since it's a 1D tensor)
Its shape will indeed be different, at least between prefill and subsequent generation
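To illustrate the stride concern, here is a small NumPy sketch (purely illustrative — the real cache_position is a 1-D torch tensor, and NumPy strides are in bytes): a strided view of a 1-D int64 array stops being contiguous until it is copied, which is what torch's .contiguous() would fix.

```python
import numpy as np

# a contiguous 1-D int64 array: stride is 8 bytes (one element)
cache_position = np.arange(6, dtype=np.int64)
print(cache_position.strides)  # (8,)

# a strided view (every other element) is no longer contiguous
view = cache_position[::2]
print(view.strides, view.flags["C_CONTIGUOUS"])  # (16,) False

# copying (NumPy's analogue of torch's .contiguous()) restores it
fixed = np.ascontiguousarray(view)
print(fixed.strides, fixed.flags["C_CONTIGUOUS"])  # (8,) True
```

Since cache_position is always built as a fresh 1-D tensor, its stride stays (1,) in practice, which is why the .contiguous() call turned out to be unnecessary.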
We should also set the dtype of the cache positions to int32 wdyt?
Our integer inputs (input_ids, attention_mask, ...) are all int64; I think we should keep a consistent type :p
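A minimal sketch of the consistency argument, using NumPy with illustrative values (in transformers these are torch tensors): building cache positions with the same explicit int64 dtype keeps all integer inputs uniform instead of mixing in int32.

```python
import numpy as np

# integer model inputs are conventionally int64 ("long")
input_ids = np.array([[101, 2003, 102]], dtype=np.int64)
attention_mask = np.ones_like(input_ids)  # inherits int64

# build cache positions with the same explicit dtype
cache_position = np.arange(input_ids.shape[1], dtype=np.int64)

assert cache_position.dtype == input_ids.dtype == attention_mask.dtype
```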
we have correct long typing here!
(see int64 comment above)
58660e2 force-pushed to 10360b3
(rebasing and rerunning tests, just in case 🙃)
To resolve the error `TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'` introduced by huggingface/transformers#29467
What does this PR do?
Updates cache_position in generate, and makes it the primary source for the input position in the models that support it, llama and gemma (as opposed to relying on past_key_values.seen_tokens). The PR also adds the following related changes:
- StaticCache now supports get_seq_length(). This was drawn from "Static Cache: no mandatory cache_positions input" (#29221), and is needed for prepare_inputs_for_generation() retrocompatibility;
- the seen_tokens attribute enters a deprecation cycle, as it is redundant with cache_positions (and doesn't work with compilation).

This PR is drawn from the diff in #29374, i.e. it is a requirement for generate compilation with fullgraph=True 🙌

👉 Llama, Gemma, and Cache slow tests ran, no new failures
👉 FWD compilation benchmarks ran, no throughput change
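The update rule described above can be sketched as follows. This is a pure-Python/NumPy illustration with hypothetical helper names (the actual logic lives inside generate and uses torch tensors): prefill builds one position per prompt token, and each decoding step advances to a single position one past the last one used, instead of reading past_key_values.seen_tokens.

```python
import numpy as np

def prefill_cache_position(input_length):
    # prefill step: one position per prompt token
    return np.arange(input_length, dtype=np.int64)

def next_cache_position(cache_position):
    # decoding step: a single position, one past the last position used
    # (this replaces relying on past_key_values.seen_tokens)
    return cache_position[-1:] + 1

pos = prefill_cache_position(4)  # array([0, 1, 2, 3])
pos = next_cache_position(pos)   # array([4])
pos = next_cache_position(pos)   # array([5])
```

Keeping the decoding-step position as a length-1 array (rather than a scalar) matches the prefill shape convention, which keeps the indexing code uniform across both phases.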