Llama: allow custom 4d masks #29618
Conversation
  hid_0 = self.model.model.embed_tokens(input_0)
- outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+ outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
position_ids is now a "mandatory" input to the attention layer forward
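For context, a minimal sketch of how such position ids can be built before calling the attention layer directly. The variable names mirror the test above; the construction itself (contiguous positions for a single, un-padded sequence) is an assumption, not taken from the PR:

```python
import torch

# hid_0: hidden states of shape (batch, seq_len, hidden_size) from embed_tokens
seq_len = hid_0.shape[1]
# contiguous positions 0..seq_len-1 for a single, un-padded sequence
position_ids_0 = torch.arange(seq_len, device=hid_0.device).unsqueeze(0)
outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
```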
  hid_1 = self.model.model.embed_tokens(input_1)
  outs_1 = self.model.model.layers[0].self_attn.forward(
-     hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+     hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
the attention layer forward now expects numerical 4D causal masks (as opposed to 2D boolean masks)
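A minimal sketch of how a 2D boolean mask (True = attend) can be converted into the additive float 4D mask the layer now expects. The variable names follow the test above; the conversion itself is an assumption, not code from the PR:

```python
import torch

# mask_1: (q_len, kv_len) boolean, True where a query may attend to a key
dtype = hid_1.dtype
min_dtype = torch.finfo(dtype).min
# additive mask: 0.0 where attention is allowed, a large negative value where it is not
causal_mask_1 = torch.zeros(mask_1.shape, dtype=dtype, device=hid_1.device)
causal_mask_1 = causal_mask_1.masked_fill(~mask_1.bool(), min_dtype)
# add batch and head dimensions so it broadcasts to (batch, num_heads, q_len, kv_len)
causal_mask_1 = causal_mask_1[None, None, :, :]
```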
  outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
  assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)
  def test_inner_model(self):
This test was a copy of the test below 🤔
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
amyeroberts left a comment
Thanks for reenabling this!
Only question before merge: how come this is only needed for the gemma and llama models?
@amyeroberts They are the only models that have received the static cache treatment. The static cache transition did not foresee this case in the original diff :) We are finalizing support on the
What does this PR do?
Fixes #29525
Reintroduces the ability to pass custom 4D attention masks, which was removed in the static cache transition. The following tests are now passing
cc @ArthurZucker: after you come back from holidays, have a look at this PR :)
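For illustration, a minimal sketch of what this PR makes possible again at the model level: building a numerical 4D causal mask and passing it straight to forward together with explicit position ids. The checkpoint name, prompt, and tensor shapes are assumptions, not taken from the PR's tests:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; any Llama model should do
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Custom 4D masks are back", return_tensors="pt").input_ids
seq_len = input_ids.shape[1]
min_dtype = torch.finfo(model.dtype).min

# (batch, 1, q_len, kv_len) additive mask: 0 on/below the diagonal, large negative above it
mask_4d = torch.full((1, 1, seq_len, seq_len), min_dtype, dtype=model.dtype).triu(diagonal=1)
position_ids = torch.arange(seq_len).unsqueeze(0)

with torch.no_grad():
    out = model(input_ids, attention_mask=mask_4d, position_ids=position_ids)
```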