Conversation
@yuhuixu1993 here is a proposal for the implementation of ThinK.
Please confirm this press is what you want. I explicitly mentioned you in the docstring as a reviewer.
@SimJeg, it looks awesome!!! Thanks for your hard work!! Sorry for the late response, we have a time difference (I live in Singapore). I notice that in this implementation channel pruning is applied to all tokens in the key cache; we would prefer to keep the most recent tokens (e.g. 32) unchanged. I have not tested how the current evaluations perform if we prune all the tokens.

```python
keys = torch.cat(
    [
        keys[:, :, : q_len - self.window_size, :].masked_fill(mask_k, 0),
        keys[:, :, q_len - self.window_size :, :],
    ],
    dim=-2,
)
```
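For context, a minimal self-contained sketch of that masking (the function name and the `mask_k` shape are assumptions, not the PR's actual code):

```python
import torch


def mask_pruned_key_channels(keys: torch.Tensor, mask_k: torch.Tensor, window_size: int = 32) -> torch.Tensor:
    """Zero out pruned key channels everywhere except the most recent `window_size` tokens.

    keys:   (batch, num_heads, q_len, head_dim)
    mask_k: boolean mask of pruned channels, broadcastable to keys[:, :, : q_len - window_size, :]
    """
    q_len = keys.shape[-2]
    pruned = keys[:, :, : q_len - window_size, :].masked_fill(mask_k, 0)
    recent = keys[:, :, q_len - window_size :, :]  # most recent tokens kept intact
    return torch.cat([pruned, recent], dim=-2)
```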
I will update. Are you using 32 for the window size too? (I've been using 64, the default in SnapKV.)
Thanks. Yes, I use 32 in the paper, but 64 is definitely OK.
maxjeblick left a comment
Thanks a lot for this PR.
From the technical side, @yuhuixu1993 already gave some great feedback.
I left a comment regarding the RoPE computation where I'm not sure about the dimensions.
Apart from that, it would be great to add a test that covers the inner_press functionality as well; a possible sketch follows.
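Such a test might look like the following (a hedged sketch: the press signatures, import paths, and the `unit_test_model` fixture are assumptions about this repo's test setup, not verified API):

```python
import torch

from kvpress import SnapKVPress, ThinKPress  # import path assumed


def test_think_press_with_inner_press(unit_test_model):  # fixture name assumed from the existing suite
    # Wrap an inner press and check that a forward pass still runs and returns a cache.
    press = ThinKPress(key_channel_compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))
    input_ids = torch.randint(0, 1024, (1, 256))
    with press(unit_test_model):
        cache = unit_test_model(input_ids).past_key_values
    assert cache is not None
```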
@yuhuixu1993 I tried with and without zeroing the channels for the last 32 tokens and did not see any difference on the prompt I tried. May I keep the current version? I'm asking because if other similar presses come along, it would be nice to have a uniform API rather than very custom changes. Also, to implement what you ask, we would need to slightly update the compression ratio to account for the 32 * n_pruned_channels elements that are not removed (see the sketch below).
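For what it's worth, that adjustment could look like this (a hypothetical helper; the name and the per-head element counting are assumptions, not the PR's code):

```python
def effective_key_compression_ratio(
    q_len: int, head_dim: int, n_pruned_channels: int, window_size: int = 32
) -> float:
    """Fraction of key-cache elements actually removed (per head) when the most
    recent `window_size` tokens keep all their channels."""
    total_elements = q_len * head_dim
    removed_elements = max(q_len - window_size, 0) * n_pruned_channels
    return removed_elements / total_elements
```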
@SimJeg I think the current version is OK, as the performance is fine. Many thanks for the experiments!!
maxjeblick left a comment
Thanks a lot, LGTM!
I inspected whether module.rotary_emb(query_states, q_len) is needed, and it seems that
- mixtral
- gpt neox
- open_llama
- idefics (cross attention)
are using it. As transformers is converging on module.rotary_emb(query_states, position_ids), I don't think we need to support these models. One way to check which signature a model expects is sketched below.
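For reference, a heuristic sketch of that signature check (`supports_position_ids_rope` is a hypothetical helper, not part of this PR or of transformers):

```python
import inspect


def supports_position_ids_rope(module) -> bool:
    """Heuristic: newer transformers rotary embeddings take position_ids,
    while older ones (mixtral, gpt neox, open_llama, idefics cross
    attention) take a sequence length instead."""
    params = inspect.signature(module.rotary_emb.forward).parameters
    return "position_ids" in params
```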
Add ThinKPress (NVIDIA#20)
Implementation of ThinK following #18