-
Notifications
You must be signed in to change notification settings - Fork 136
Add KeyRerotationPress #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
4390868
add test for rerotation
maxjeblick 42c30d7
move to inv rope implementation
maxjeblick 83d8813
more tests
maxjeblick a34a941
format
maxjeblick 48c5d01
format
maxjeblick 1609240
resue rope
maxjeblick 467e52c
remove compression ratios n tests
maxjeblick 696da76
fix test on gpu
maxjeblick 4275595
better readme
maxjeblick 5ed6208
fix style
maxjeblick ef03b4a
Merge branch 'main' into max/rerotate_keys_2
maxjeblick 1079a03
fix merge conflicts
maxjeblick 78519f7
Merge branch 'main' into max/rerotate_keys_2
maxjeblick 4c9eb72
update to 0.1.0
maxjeblick 1ad53cb
add fp16 test
maxjeblick 037b5c0
refactor tests
maxjeblick dc96e97
fix broken test
maxjeblick 9e51013
update readme
maxjeblick d99f287
address pr feedback
maxjeblick aa355d5
address pr feedback
maxjeblick 575c032
address pr feedback
maxjeblick d64ad7c
Update README
SimJeg f9b5347
address pr feedback
maxjeblick ab514ce
Add PR template
SimJeg a20df83
update PR template
maxjeblick File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| ## PR description | ||
|
|
||
| Description of your PR. Fixes # (issue) (if applicable) | ||
|
|
||
| ## New press checklist (if applicable) | ||
|
|
||
| - [ ] I added `mypress_press.py` in the `presses` directory | ||
| - [ ] I added `MyPress` in `__init__.py` | ||
| - [ ] I updated the `README.md` with a 1 liner about my new press in the Available presses section | ||
| - [ ] I added my press in the `default_presses` list in `tests/default_presses.py` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| import inspect | ||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
| from torch import nn | ||
| from transformers.models.llama.modeling_llama import rotate_half | ||
|
|
||
| from kvpress.presses.base_press import BasePress | ||
| from kvpress.presses.scorer_press import ScorerPress | ||
|
|
||
|
|
||
| @dataclass | ||
| class KeyRerotationPress(BasePress): | ||
| """ | ||
| Rerotate keys to have a uniform RoPE representation of keys after pruning. | ||
| This method is used in several key-value cache compression methods, such as | ||
| - SinkCache implementation in Hugging Face's transformers library | ||
| - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models | ||
| Parameters | ||
| ---------- | ||
| press : ScorerPress | ||
| The press object to apply per-layer compression to. | ||
| """ | ||
|
|
||
| press: ScorerPress | ||
|
|
||
| def compress( | ||
| self, | ||
| module: nn.Module, | ||
| hidden_states: torch.Tensor, | ||
| keys: torch.Tensor, | ||
| values: torch.Tensor, | ||
| attentions: torch.Tensor, | ||
| kwargs: dict, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if self.press.compression_ratio == 0: | ||
| return keys, values | ||
|
|
||
| # Compute scores from base press | ||
| scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) | ||
|
|
||
| # Get indices of KV pairs with the lowest scores | ||
| q_len = hidden_states.shape[1] | ||
| n_kept = int(q_len * (1 - self.press.compression_ratio)) | ||
| indices = scores.topk(n_kept, dim=-1).indices | ||
| indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) | ||
|
|
||
| cos, sin = get_rope_embeddings(module, keys) | ||
| # Rerotate as follows | ||
| # 1. keys = RoPE(W_k * hidden_states) | ||
| # 2. keys_unrotated = RoPE^-1(keys) | ||
| # 3. keys_pruned = prune(keys_unrotated) | ||
| # 4. keys = RoPE(keys_pruned) | ||
|
SimJeg marked this conversation as resolved.
|
||
|
|
||
| # 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below | ||
| keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1))) | ||
|
SimJeg marked this conversation as resolved.
|
||
| # 3. Prune keys | ||
| keys = keys.gather(2, indices).contiguous() | ||
| # 4. Apply RoPE | ||
| cos, sin = get_rope_embeddings(module, keys) | ||
|
SimJeg marked this conversation as resolved.
|
||
| keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) | ||
|
|
||
| values = values.gather(2, indices).contiguous() | ||
| return keys, values | ||
|
|
||
|
|
||
| def get_rope_embeddings(module, x): | ||
| length = x.shape[2] | ||
| # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape | ||
| if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: | ||
| position_ids = torch.arange(length).unsqueeze(0).to(x.device) | ||
| cos, sin = module.rotary_emb(x, position_ids) | ||
| else: | ||
| cos, sin = module.rotary_emb(x, length) | ||
| return cos, sin | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from kvpress import ( | ||
| ExpectedAttentionPress, | ||
| KnormPress, | ||
| RandomPress, | ||
| SimLayerKVPress, | ||
| SnapKVPress, | ||
| StreamingLLMPress, | ||
| ThinKPress, | ||
| TOVAPress, | ||
| ) | ||
|
|
||
| # contains all presses to be tested | ||
| # kwargs should be ordered easy to hard compression | ||
| default_presses = [ | ||
| {"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, | ||
| {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, | ||
| {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, | ||
| {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, | ||
| { | ||
| "cls": SnapKVPress, | ||
| "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}], | ||
| }, | ||
| {"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, | ||
| { | ||
| "cls": ThinKPress, | ||
| "kwargs": [ | ||
| {"key_channel_compression_ratio": 0.2, "window_size": 2}, | ||
| {"key_channel_compression_ratio": 0.8, "window_size": 2}, | ||
| ], | ||
| }, | ||
| { | ||
| "cls": SimLayerKVPress, | ||
| "kwargs": [ | ||
| {"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1}, | ||
| {"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1}, | ||
| ], | ||
| }, | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.