SlidingWindowCache: reduce differences to other Cache classes #30970

ArthurZucker merged 4 commits into huggingface:main
Conversation
`sliding_window` is the config attribute name, not `sliding_window_size`
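For reference, a quick way to check this on a sliding-window model (a sketch; the checkpoint is just an example, but `sliding_window` is the actual Mistral config field):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
print(config.sliding_window)                    # 4096
print(hasattr(config, "sliding_window_size"))   # False -- no such attribute
```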
(unrelated to the sliding window cache) this was incorrect, we need a new cache object with a different batch size
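A small sketch of the point above: static-shape caches are preallocated for a fixed batch size, so a different batch size needs a fresh cache object. The `StaticCache` kwargs here match transformers around the time of this PR; treat the exact signature as an assumption.

```python
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# cache tensors are preallocated for batch size 1 -- they cannot be
# reused for a batch of 4, so a new cache object must be built
cache_bs1 = StaticCache(config=config, max_batch_size=1, max_cache_len=1024,
                        device="cpu", dtype=torch.float16)
cache_bs4 = StaticCache(config=config, max_batch_size=4, max_cache_len=1024,
                        device="cpu", dtype=torch.float16)
```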
that's a nice catch!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
Overall good for me, but wondering where these graph breaks come from?
```python
# assume this will be called only in the first generation step
# `cache_position` will be used in other cases
return 0
```

```python
# `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
```
where are these graph breaks from? (did this not work before?) because it's equivalent but slower, no?
there is an extra zero involved here to make cudagraphs happy. I believe we should not change the address of the tensor during compilation, and direct assignment violates that. In `StaticCache` there is no problem, because `k_out[:, :, cache_position] = key_states` does not change the address of `k_out`; but if we want a 4D instead of a 5D cache, direct assignment would substitute the original tensor in the layers list, causing an address change.
Ahhh yeah, which is why you did not have this
If there is no tradeoff to using this (make bench + test on A100 as well), fine; otherwise not fine, but add a comment to say why
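To make the discussion above concrete, here is a minimal standalone sketch of the two update styles (tensor names mirror the thread; the shapes are made up):

```python
import torch

key_cache = [torch.zeros(1, 4, 8, 16)]   # per-layer list of 4D cache tensors
key_states = torch.randn(1, 4, 8, 16)    # new states covering the full window

# direct assignment rebinds the list slot to a brand-new tensor, changing the
# storage address -- this is what upsets cudagraphs under torch.compile:
#   key_cache[0] = key_states

# in-place zero followed by += writes into the existing storage,
# so the address stays stable:
key_cache[0].zero_()
key_cache[0] += key_states
assert torch.equal(key_cache[0], key_states)
```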
@ArthurZucker @zhenglongjiepheonix the implementation from this PR is also faster 🙌 Setup:
```python
from transformers import AutoTokenizer, MistralForCausalLM
import torch
import time

prompts = ["My favourite condiment is " * 100]

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
print(inputs.input_ids.shape)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
for i in range(5):
    start = time.time()
    generated_ids = model.generate(
        **inputs, max_new_tokens=128, do_sample=False, cache_implementation="sliding_window"
    )
    assert generated_ids.shape[1] == 128 + inputs.input_ids.shape[1]
    print(f"Time: {time.time() - start:.2f}s")
```

👉 static cache: 76.2 tok/s

Could it be because there are fewer slicing OPs? (before, we had to slice the 5D cache into a 4D tensor at every layer)
Yes, slicing can be time-consuming. I have tested on my side, and in your setting your implementation indeed saves about 1 ms per token. I think it's good if we don't have to slice every time.
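A rough sketch of the difference being discussed (illustrative shapes and names only, not the actual transformers internals):

```python
import torch

num_layers, batch, heads, window, dim = 2, 1, 4, 8, 16

# before: one 5D tensor for the whole cache, sliced into a 4D view at every
# layer on every decoding step
cache_5d = torch.zeros(num_layers, batch, heads, window, dim)
k_out = cache_5d[0]   # slicing op repeated per layer, per token

# after: a plain list of 4D tensors, so each layer just indexes the list
cache_4d = [torch.zeros(batch, heads, window, dim) for _ in range(num_layers)]
k_out = cache_4d[0]   # plain list indexing, no tensor slicing needed
```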
ArthurZucker left a comment
LGTM! Let's merge, @gante is not here
What does this PR do?
Follow-up to #30642: this PR aims to reduce the differences between `SlidingWindowCache` and `StaticCache`, such that long-term maintenance becomes easier. Fewer attributes/functions = less cognitive overload and fewer bugs 🤗

More specifically:
👉 no need for attributes regarding the sliding window (it is a form of maximum cache size, for which there was an attribute)
👉 list of 4D tensors holding the cache, as opposed to 5D tensors (to keep the same data format as in other caches)
👉 inherits from `StaticCache`, as most of the `__init__` and other boilerplate functions are identical (see the sketch below)

Slow Mistral tests were run locally, all green ✅
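A simplified structural sketch of the relationship described above (hypothetical class bodies, not the actual transformers source):

```python
import torch

class StaticCacheSketch:
    def __init__(self, num_layers: int, batch: int, heads: int, max_cache_len: int, dim: int):
        # one 4D tensor per layer, preallocated up to max_cache_len
        self.key_cache = [torch.zeros(batch, heads, max_cache_len, dim) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(batch, heads, max_cache_len, dim) for _ in range(num_layers)]

class SlidingWindowCacheSketch(StaticCacheSketch):
    # identical storage and boilerplate; the sliding window simply caps
    # max_cache_len, so no dedicated sliding-window attributes are needed
    pass
```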
cc @zhenglongjiepheonix I meant to request these changes in the PR linked above, but I was slow to review 😛