
Add ThinKPress #20

Merged
SimJeg merged 4 commits into main from simon/think-press on Dec 3, 2024
Conversation

@SimJeg SimJeg commented Nov 28, 2024

Implementation of ThinKPress following #18

@SimJeg SimJeg mentioned this pull request Nov 28, 2024
@SimJeg
Copy link
Copy Markdown
Collaborator Author

SimJeg commented Nov 28, 2024

@yuhuixu1993 here is a proposal for the implementation of ThinKPress based on our previous discussions in #18:

  • As in your proposal, I'm zeroing the pruned dimensions
  • I added support to optionally combine it with any press, and in my first experiments it worked great! As you requested, the inner press is applied before ThinK
  • I also ensured support for quantization

Please confirm this press is what you want. I explicitly mentioned you in the docstring as a reviewer.

ThinKPress is the first press in this repo that compresses the channel dimension, hence the code is a bit more complex. If other similar presses are proposed, we will refactor the code to make the implementation easier (e.g. we can imagine a SequenceBasePress and a DimensionBasePress, with options to compose them one way or the other).
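The channel-pruning idea behind the press can be sketched as follows. This is a minimal illustration, not the kvpress implementation: here channel importance is approximated by the L2 norm of each key channel over the sequence axis, whereas ThinK itself uses a query-aware score, and the function name and shapes are assumptions.

```python
import torch


def prune_key_channels(keys: torch.Tensor, compression_ratio: float) -> torch.Tensor:
    """Zero the least important channels of the key cache.

    keys: (batch, num_heads, seq_len, head_dim)
    Channel importance is approximated here by the L2 norm over the
    sequence axis; ThinK itself uses a query-aware score.
    """
    head_dim = keys.shape[-1]
    n_pruned = int(head_dim * compression_ratio)
    # Per-channel importance score for each head: (batch, num_heads, head_dim)
    scores = keys.norm(dim=-2)
    # Indices of the n_pruned lowest-scoring channels
    pruned = scores.topk(n_pruned, dim=-1, largest=False).indices
    mask = torch.zeros_like(scores, dtype=torch.bool).scatter(-1, pruned, True)
    # Broadcast the channel mask over the sequence axis and zero those dims
    return keys.masked_fill(mask.unsqueeze(-2), 0.0)


keys = torch.randn(1, 2, 8, 16)
pruned_keys = prune_key_channels(keys, compression_ratio=0.5)
```

Zeroing (rather than physically removing) the pruned channels keeps the cache layout unchanged, which is what makes composition with an inner press and with quantization straightforward.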

@SimJeg SimJeg requested a review from maxjeblick November 28, 2024 17:42
yuhuixu1993 commented Nov 29, 2024

@SimJeg, it looks awesome! Thanks for your hard work! Sorry for the late response; there is a time difference (I live in Singapore). However, I notice that in this implementation channel pruning is applied to all tokens in the key cache, whereas we prefer to keep the most recent tokens (e.g. 32) unchanged. I have not tested the performance of the current evaluations when all tokens are pruned.
In my previous PR:

```python
keys = torch.cat(
    [
        keys[:, :, : q_len - self.window_size, :].masked_fill(mask_k, 0),
        keys[:, :, q_len - self.window_size :, :],
    ],
    dim=-2,
)
```
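The snippet above can be read as the following standalone sketch: channels are zeroed only for tokens outside the most recent window. The function name and tensor shapes are assumptions for illustration; `mask_k` and `window_size` follow the snippet.

```python
import torch


def zero_pruned_channels_outside_window(
    keys: torch.Tensor, mask_k: torch.Tensor, window_size: int
) -> torch.Tensor:
    """Zero pruned channels for all but the last `window_size` tokens.

    keys:   (batch, num_heads, q_len, head_dim)
    mask_k: boolean mask, True where a channel should be zeroed,
            broadcastable to (batch, num_heads, q_len - window_size, head_dim)
    """
    q_len = keys.shape[-2]
    pruned = keys[:, :, : q_len - window_size, :].masked_fill(mask_k, 0)
    recent = keys[:, :, q_len - window_size :, :]  # left untouched
    return torch.cat([pruned, recent], dim=-2)


keys = torch.ones(1, 1, 10, 4)
mask_k = torch.zeros(1, 1, 8, 4, dtype=torch.bool)
mask_k[..., :2] = True  # prune the first two channels
out = zero_pruned_channels_outside_window(keys, mask_k, window_size=2)
```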


SimJeg commented Dec 2, 2024

I will update. Are you using 32 for the window size too? (I've been using 64, the SnapKV default.)

@yuhuixu1993

Thanks. Yes, I use 32 in the paper, but 64 is definitely OK.

Review thread on kvpress/presses/think_press.py (outdated)
Review thread on kvpress/presses/think_press.py

@maxjeblick maxjeblick left a comment


Thanks a lot for this PR.
From the technical side, @yuhuixu1993 already gave some great feedback.

I left a comment regarding RoPE computation where I'm not sure about the dimensions.

Apart from that, it would be great to also add a test covering the inner_press functionality.


SimJeg commented Dec 3, 2024

@yuhuixu1993 I tried with and without zeroing the channels for the last 32 tokens and did not see any difference on the prompt I tried. May I keep the current version? I'm asking because if other similar presses come, it would be nice to have a uniform API rather than very custom changes. Also, to implement what you ask, we would need to slightly update the compression ratio to account for the 32 * n_pruned_channels elements that are not removed.
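The compression-ratio bookkeeping mentioned above amounts to the following. This is a hypothetical helper for illustration, not the kvpress implementation: when the last `window_size` tokens are exempt from channel pruning, the fraction of key-cache elements actually zeroed shrinks accordingly.

```python
def effective_key_compression(
    head_dim: int, q_len: int, window_size: int, channel_ratio: float
) -> float:
    """Fraction of key-cache elements actually zeroed when the last
    `window_size` tokens are exempt from channel pruning.

    Hypothetical helper: names and the exact bookkeeping are illustrative.
    """
    n_pruned_channels = int(head_dim * channel_ratio)
    # Elements zeroed: pruned channels, but only outside the recent window
    zeroed = n_pruned_channels * (q_len - window_size)
    total = head_dim * q_len
    return zeroed / total


# With a 32-token window, the nominal 0.5 channel ratio drops slightly
ratio = effective_key_compression(head_dim=128, q_len=1024, window_size=32, channel_ratio=0.5)
```

For long prompts the correction is small (here about 0.484 instead of 0.5), which is consistent with the experiments showing little difference either way.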

@yuhuixu1993

@SimJeg I think the current version is OK, as the performance is fine. Many thanks for the experiments!


@maxjeblick maxjeblick left a comment


Thanks a lot, LGTM!

I inspected whether module.rotary_emb(query_states, q_len) is needed, and it seems that

  • mixtral
  • gpt neox
  • open_llama
  • idefics (cross attention)

are using it. As transformers converges to module.rotary_emb(query_states, position_ids), I don't think we need to support these models.
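The position_ids-based signature mentioned above can be sketched as a minimal standalone RoPE cache. This is a simplified illustration of the interface shape, not the Hugging Face implementation; the function name and shapes are assumptions.

```python
import torch


def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Minimal RoPE cache computed from explicit position_ids, mirroring the
    rotary_emb(x, position_ids) calling convention (simplified sketch).

    position_ids: (batch, seq_len) integer positions
    returns cos, sin of shape (batch, seq_len, head_dim)
    """
    # Inverse frequencies for each pair of channels: (head_dim // 2,)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Angles per position and frequency: (batch, seq_len, head_dim // 2)
    freqs = position_ids[..., None].float() * inv_freq
    # Duplicate to cover the full head dimension
    emb = torch.cat([freqs, freqs], dim=-1)
    return emb.cos(), emb.sin()


position_ids = torch.arange(6).unsqueeze(0)  # (1, 6)
cos, sin = rotary_cos_sin(position_ids, head_dim=8)
```

Passing explicit positions (rather than a sequence length) is what lets a press recompute RoPE for a pruned or reordered cache, which is why converging on this signature simplifies support.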

@SimJeg SimJeg merged commit ac2445e into main Dec 3, 2024
@SimJeg SimJeg deleted the simon/think-press branch December 3, 2024 15:29
FFY0 added a commit to FFY0/AdaKV-in-NVIDIA-kvpress that referenced this pull request Dec 6, 2024