[BC 4.37 -> 4.38] for Llama family, memory and speed #29753
ArthurZucker merged 22 commits into main from fix-causal-mask-dispatch
gante left a comment:
In general looks good to me, although I'm not 100% sure about the `causal_mask *= torch.arange(target_length, device=device) > cache_position[0]` line + assisted generation -- going to have a look
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
torchscript can be fixed by @fxmarty in a follow-up PR + patch IMO!
@ArthurZucker that would be great if torchscript/fx tests pass & does not break
younesbelkada left a comment:
Thanks for the offline explanation! This shouldn't affect FA2, as we always return `attention_mask` without processing it for FA2 modules in `_update_causal_mask`. As you explained offline, `causal_mask *= torch.arange(target_length, device=device) > cache_position[0]` is used to mask out the cached hidden states.
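For context, a minimal sketch of the dispatch being described — paraphrased from memory rather than copied from the modeling file, so treat the exact conditions as assumptions:

```python
import torch

def update_causal_mask_sketch(attn_implementation: str, attention_mask):
    # Paraphrase of the early exit in `_update_causal_mask`: when FA2 is the
    # attention backend, the 2D padding mask is returned untouched (or dropped
    # if there is no padding at all), so the dense 4D mask -- and the
    # `causal_mask *= torch.arange(...) > cache_position[0]` update -- is
    # never built for FA2 modules.
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    ...  # the sdpa/eager paths build and return the dense 4D causal mask here
```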
@ArthurZucker the line [code omitted]: in the previous version, in a situation with 17 cached tokens and 4 new assistant tokens, the causal mask would become [mask rendering omitted], i.e. not upper triangular. After the suggested change, it becomes [mask rendering omitted].
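A self-contained sketch of that scenario (the shapes and the 0 = attend / min_dtype = masked convention follow the snippet under discussion; the helper below is illustrative, not code from this PR):

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
num_cached, num_new = 17, 4
target_length = num_cached + num_new                      # 21 key/value slots
cache_position = torch.arange(num_cached, target_length)  # [17, 18, 19, 20]

def build_mask(per_row: bool) -> torch.Tensor:
    mask = torch.full((num_new, target_length), fill_value=min_dtype, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    if per_row:
        # Suggested change: each query row is cut at its *own* position, so
        # assistant token i can see the cache plus assistant tokens 0..i.
        mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    else:
        # Previous version: every row is cut at the *first* new position, so
        # the 4 assistant tokens cannot attend to one another (or themselves).
        mask *= torch.arange(target_length) > cache_position[0]
    return mask

print((build_mask(per_row=False) == 0).int())  # rectangular visibility (wrong)
print((build_mask(per_row=True) == 0).int())   # shifted staircase (causal)
```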
Any idea why our tests don't complain?
I think we don't have hard correctness checks for assisted generation 🙈 only API checks
* attempt to fix
* the actual fix that works with compilation!
* this?
* temporary update
* nit?
* dispatch to memory efficient?
* update both models that have static cache support
* fix copies fix compile
* make sure fix
* fix cohere and gemma
* fix beams?
* nit
* slipped through the cracks
* nit
* nits
* update
* fix-copies
* skip failing tests
* nits
Congratulations with this PR fixing many important things! @gante and myself introduced this test recently in #29731 - please make it a part of the test suite for all things related to attention_masks. Here is what happens (numbers based on [output omitted]). I hesitate to offer a PR because it may break other things that you try to do with this part of the code. But, please, make the test work. BTW, it may be OK to change the test, for instance passing the whole bigger mask including cached items by editing this line to [edit omitted].

Special note on StaticCache: I like this feature and I want to use custom 4D masks with it. So far this is not tested. I'd be glad to contribute such a test once this issue is fixed. It will look like [sketch omitted].
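For readers following along, a hedged sketch of the kind of check that test performs: the model is run once normally and once with an explicitly materialized 4D causal mask, and the logits must match. The checkpoint name and the 1 = attend / 0 = masked convention are assumptions here, not code from #29731 or this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: any small Llama-family checkpoint works for this comparison.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tok = AutoTokenizer.from_pretrained(model_id)

inputs = tok("one two three four", return_tensors="pt")
seq_len = inputs.input_ids.shape[1]

# Custom 4D mask of shape (batch, 1, query_len, kv_len); here it simply
# reproduces the standard causal pattern, but any visibility pattern
# (e.g. packed sequences) could be encoded the same way.
mask_4d = torch.tril(torch.ones(1, 1, seq_len, seq_len))

with torch.no_grad():
    ref = model(**inputs).logits
    out = model(input_ids=inputs.input_ids, attention_mask=mask_4d).logits

torch.testing.assert_close(ref, out)  # must hold once the regression is fixed
```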
I can try to fix it, and yes, I thought it would be tested automatically. It should be part of the [...]
^ this PR should fix it 🤗
What does this PR do?
Fixes the BC issues between the two versions in terms of memory consumption.
This fix is made a lot easier by all the tests, so thanks a lot @gante!
Fixes #29412, fixes #29484, fixes #29644, fixes #29651
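A quick way to check the memory side of this on your own setup (a sketch; the checkpoint and generation length are arbitrary assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any Llama-family checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tok("Hello", return_tensors="pt").to("cuda")
torch.cuda.reset_peak_memory_stats()
model.generate(**inputs, max_new_tokens=256)

# Compare this number on v4.37 vs v4.38 vs this branch.
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```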