🚨 Llama: update rope scaling to match static cache changes #29143
gante merged 2 commits into huggingface:main
Conversation
```python
pass

@parameterized.expand([("linear",), ("dynamic",)])
@unittest.skip("TODO @gante fix this for Llama")
```
gante:
This test was fixed as a result of the changes in this PR :)
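For context, here is a minimal sketch of what a parameterized rope-scaling test of this kind can look like once the skip is dropped. The tiny config values and the test name are illustrative assumptions, not the actual test from the repository:

```python
import unittest

import torch
from parameterized import parameterized
from transformers import LlamaConfig, LlamaModel


class RopeScalingTest(unittest.TestCase):
    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling(self, scaling_type):
        # A tiny Llama with rope scaling enabled should run a forward pass end to end.
        config = LlamaConfig(
            vocab_size=100,
            hidden_size=32,
            num_hidden_layers=2,
            num_attention_heads=4,
            intermediate_size=37,
            rope_scaling={"type": scaling_type, "factor": 10.0},
        )
        model = LlamaModel(config).eval()
        input_ids = torch.randint(0, config.vocab_size, (1, 10))
        with torch.no_grad():
            hidden = model(input_ids).last_hidden_state
        self.assertEqual(hidden.shape, (1, 10, config.hidden_size))


if __name__ == "__main__":
    unittest.main()
```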
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment:
🧼 nice cleanup!
Main concern: BC. Let's keep cos_cached and sin_cached for 1 release, and then we can directly open a PR on main to remove them!
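A minimal sketch of that one-release deprecation path, using a hypothetical module and plain `warnings.warn` instead of transformers' internal logger (the real PR's mechanism and wording may differ):

```python
import warnings

import torch
from torch import nn


class RotaryEmbeddingSketch(nn.Module):
    """Hypothetical module illustrating the deprecation pattern under discussion."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None  # refreshed lazily by forward()
        self._sin_cached = None

    @property
    def cos_cached(self):
        # Keep the old attribute readable for one release, but warn on access.
        warnings.warn(
            "cos_cached is deprecated and will be removed in the next release; "
            "get cos/sin from forward() instead.",
            FutureWarning,
        )
        return self._cos_cached

    @property
    def sin_cached(self):
        warnings.warn(
            "sin_cached is deprecated and will be removed in the next release; "
            "get cos/sin from forward() instead.",
            FutureWarning,
        )
        return self._sin_cached

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Frequencies only for the requested positions: [batch, seq_len, dim // 2]
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos().to(x.dtype), emb.sin().to(x.dtype)
        # Still compute and store them, so downstream readers keep working.
        self._cos_cached, self._sin_cached = cos, sin
        return cos, sin
```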
| self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) | ||
| self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) | ||
| def forward(self, x, position_ids, seq_len=None): |
ArthurZucker:
I am alright with this, but it is breaking for any libs that rely on sin_cached and cos_cached. Same for the static cache PR!
Let's just add a mention that it will be removed next release, and still compute cos and sin!
gante:
This is the cool part -- it calls super's forward, which in turn caches sin/cos (see here). BC is preserved 🙌
ArthurZucker:
Yes, but we need a warning to deprecate! A follow-up is fine.
gante:
I'm not sure I follow -- the warning is here. Or were you thinking of some other warning?
ArthurZucker:
Perfect! Had not seen this when I checked the diff
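To make the thread concrete, here is a minimal sketch of why delegating to `super().forward()` preserves the cached attributes. `BaseRope` and `LinearScalingRope` are simplified stand-ins, not the actual transformers classes:

```python
import torch
from torch import nn


class BaseRope(nn.Module):
    """Simplified stand-in for the parent rotary embedding class."""

    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        # [batch, seq_len, dim // 2] frequencies for the requested positions
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        # The parent caches sin/cos here, so any subclass that delegates to
        # super().forward() keeps the old attributes populated.
        self.cos_cached = emb.cos().to(x.dtype)
        self.sin_cached = emb.sin().to(x.dtype)
        return self.cos_cached, self.sin_cached


class LinearScalingRope(BaseRope):
    """Linear rope scaling: divide the positions by a factor, then delegate."""

    def __init__(self, dim, scaling_factor=1.0, base=10000.0):
        super().__init__(dim, base=base)
        self.scaling_factor = scaling_factor

    def forward(self, x, position_ids):
        position_ids = position_ids.float() / self.scaling_factor
        # Delegation recomputes *and* caches sin/cos in the parent.
        return super().forward(x, position_ids)


x = torch.zeros(1, 3, 8)  # only x.dtype matters in this sketch
rope = LinearScalingRope(dim=8, scaling_factor=2.0)
cos, sin = rope(x, torch.tensor([[0, 1, 2]]))
print(cos.shape, rope.cos_cached.shape)  # torch.Size([1, 3, 8]) twice
```

After the call, `rope.cos_cached` and `rope.sin_cached` still resolve, which is the backward-compatibility argument made above.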
```python
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

cos, sin = super().forward(x, position_ids, seq_len)
```
ArthurZucker left a comment:
@younesbelkada also pointed out that the shape of the output of the RoPE layer is different from before, so this is a bit breaking. Let's add a big 🔴 on the PR to make sure we know that there are breaking changes!
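For reference, a self-contained sketch of the shape change being flagged (illustrative values, not the library code):

```python
import torch

dim, max_len = 8, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))

# Before: the layer exposed caches covering *all* positions up to max_len.
t = torch.arange(max_len).float()
freqs_old = torch.outer(t, inv_freq)
emb_old = torch.cat((freqs_old, freqs_old), dim=-1)
print(emb_old.cos().shape)  # torch.Size([16, 8])

# After: sin/cos are computed only for the requested position_ids.
position_ids = torch.tensor([[3, 4, 5]])  # [batch, seq_len]
freqs_new = position_ids[..., None].float() * inv_freq
emb_new = torch.cat((freqs_new, freqs_new), dim=-1)
print(emb_new.cos().shape)  # torch.Size([1, 3, 8])
```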
younesbelkada left a comment:
Tests all pass on PEFT's end! Thanks for the notice 💪
What does this PR do?
(see title :))
What's breaking? The shape of the returned sin/cos caches is changed (sin/cos for all positions -> sin/cos for the positions in position_ids). Note that this breaking change was also present in the static cache PR, for the main RoPE class (#27931).
Review suggestion: