Add a static cache that offloads to the CPU or other device #32161
ArthurZucker merged 2 commits into huggingface:main from
Conversation
Force-pushed 87541c5 to 8b862ab
Force-pushed 386f231 to 3d413cb
Already looks great! IMO would be nice to add a snippet of how to use it in the doc of the class + once #32150 is merged also add this as compatible with torch.compile (while the non static version won't be)
Really like the findings you posted on the issue, will be useful for everyone I think!
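For reference, a doc snippet along the lines being requested might look like the sketch below (the model id is illustrative, and the `cache_implementation="offloaded_static"` option assumes this cache is wired into generate() under that name; double-check against the merged docs):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")

inputs = tokenizer("The key to a good cache is", return_tensors="pt").to(model.device)

# Generate with the offloaded static cache; inactive layers' KV tensors stay on the CPU.
out = model.generate(**inputs, max_new_tokens=64, cache_implementation="offloaded_static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```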
Also would be interesting to test this / showcase the potential for huge beam search!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
# For backwards compatibility.
self._seen_tokens = 0
```
this one is unused throughout the code
See comment below.
let's use the fn from the static cache (since we want to remove self._seen_tokens)
That would be a performance degradation compared to the current integer update of self._seen_tokens. Is there a plan to remove get_seq_length? In that case, I would remove it then. Otherwise, this will be a lot slower since it will have to do (synced) CPU operations on the offloaded cache. Let me know what you think.
Yeah, get_seq_length is deprecated in favor of cache_positions, which should not need CPU operations as you only use them on device.
Okay, but shouldn't we just remove the get_seq_length methods, including all necessary variables (_seen_tokens), once it gets removed from the API? There are still quite a few usages throughout the codebase.
What I meant was that get_seq_length will be much less performant in the meantime if I switch it over to the StaticCache implementation. If you want that in the meantime, I'm fine with that as well.
Ah sorry, no need to use the one from static. And actually, yeah, we should probably just prevent users from using it -> no offloading for them?
Let's go with keeping _seen_tokens for now, and add a comment saying # TODO @gante remove this.
> we should probably just prevent users from using it -> no offloading for them?
The method get_seq_length works fine, but it's still used internally. Hence my hesitation to remove it or revert it to the slower StaticCache version.
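For context, a rough sketch of the two implementations being weighed (the StaticCache-style body follows the upstream pattern of scanning the cache tensor; treat both as illustrative):

```python
import torch

# OffloadedStaticCache keeps a plain Python counter on the host: an O(1) read
# with no device synchronization.
def get_seq_length(self, layer_idx: int = 0) -> int:
    return self._seen_tokens

# The StaticCache approach derives the length from the cache tensor itself. On
# an offloaded cache that tensor may live on the CPU, so reading it forces a
# synchronized transfer, which is the slowdown discussed above.
def get_seq_length_static(self, layer_idx: int = 0) -> torch.Tensor:
    return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
```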
Sounds good!
Arthur and I will handle the deprecation of both on all cache types + internal usage afterwards 🤗
@ArthurZucker Please review and merge when happy.
Force-pushed 3d413cb to b470b63
@ArthurZucker @gante Thanks for reviewing! Made some fixes and added unit tests in a new commit. Also added the performance testing to the PR description.
Force-pushed fd41a90 to 890db71
gante left a comment:
LGTM 🤗
Thank you for the cool feature and for iterating with us!
Suggested change:

```diff
+### Offloaded Static Cache
 Like [`~OffloadedCache`] exists for offloading a "DynamicCache", there is also an offloaded static cache. Just
```
I think it deserves a subsection of its own 🤗
Will do. Also, I will add it to the overview table.
Done! Although I do wonder: what does "Initialization Recommended" mean in the overview table?
Whether it should be initialized outside generate and passed to generate or not; cc @zucchini-nlp on this!
Some cache classes are recommended to be initialized outside generation; e.g. StaticCache with compilation had some issues when we initialized the cache while compiling.
Also, some cache types are not handled automatically by our API, e.g. SinkCache, so the user has no option but to initialize it and pass past_key_values.
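A minimal sketch of the "initialize outside generate and pass it in" pattern described here, using SinkCache (the model id and cache parameters are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

inputs = tokenizer("Tell me a story:", return_tensors="pt").to(model.device)

# SinkCache is not built automatically by generate(), so the user constructs it
# up front and hands it over via `past_key_values`.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
out = model.generate(**inputs, max_new_tokens=64, past_key_values=past_key_values)
```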
Force-pushed 890db71 to dc8d226
ArthurZucker left a comment:
Sorry for the late review, very very nice! Let's make sure the slow tests pass and it's good to go. Can you try to run them locally? 🤗
```python
self.dtype = dtype if dtype is not None else torch.float32

# Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
```
Suggested change:

```diff
-head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
```
@gerbenvv I can also merge like this and commit later if you are busy!
Yeah, trying to run the tests now. I'll try running just the file.
Hmm, the tests have literally crashed the whole machine. Are they supposed to use all the GPUs on the machine? I am struggling a bit to run this.
Then I tried to fix those errors by passing the token & installing; that ran successfully, and those were the ones that I had changed.
Okay, making progress ;-) With the new output, no tests regarding the static, dynamic, or offloading caches are failing. I think this should be good enough to get this merged, right?
Yep, let's go! 🔥
Add a static cache that offloads to the CPU or other device (huggingface#32161)
* Add a static cache that offloads to the CPU or other device
* Fix PR comments, add unit-tests
What does this PR do?
This PR adds a static cache that offloads to the CPU or another device.
Fixes #32179
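A hedged sketch of using the new class directly (the constructor arguments below follow the signature added in this PR as far as the diff shows, but verify against the merged code; the model id is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, OffloadedStaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")

inputs = tokenizer("Offloading lets you", return_tensors="pt").to(model.device)

past_key_values = OffloadedStaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=4096,    # pre-allocated maximum sequence length
    device=model.device,   # where the active layer's cache lives
    dtype=model.dtype,
    offload_device="cpu",  # where inactive layers are parked
)
out = model.generate(**inputs, max_new_tokens=64, past_key_values=past_key_values)
```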
Who can review?
@ArthurZucker @gante @n17s
Performance tests
Performance tested it with torch.compile(model) and I am getting a throughput of about 535 tokens/s (the non-offloaded static cache OOMs).
A second configuration gets 10.6 tokens/s (12.8 tokens/s static).
A third does 98.8 tokens/s (106.8 tokens/s static).
A fourth does 939.5 tokens/s (995.6 tokens/s static).
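A hedged sketch of how such a throughput number can be measured (model and generation settings are illustrative, not the exact configurations benchmarked above):

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; not necessarily the benchmarked model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")

inputs = tokenizer("Benchmark prompt", return_tensors="pt").to(model.device)

max_new_tokens = 512
torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=max_new_tokens, cache_implementation="offloaded_static")
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```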