Fix causal mask in llama for long seq_length #29263
YLGH wants to merge 1 commit into huggingface:main from
Conversation
Will this PR fix #29252? (Yi has the same architecture as Llama2.)
Yeah, I think so (it uses LlamaForCausalLM).
I'm not sure why the unit tests are failing; they seem unrelated to my change...
ArthurZucker
left a comment
Thanks! I don't think a while loop is the best way to achieve this. However, we do need a proper fix!
```python
while seq_length > causal_mask.shape[-1]:
    causal_mask = torch.full((2 * causal_mask.shape[-1], 2 * causal_mask.shape[-1]), fill_value=1)
```
Suggested change:

```diff
-while seq_length > causal_mask.shape[-1]:
-    causal_mask = torch.full((2 * causal_mask.shape[-1], 2 * causal_mask.shape[-1]), fill_value=1)
+if seq_length > causal_mask.shape[-1]:
+    new_max_positions = round(seq_length / causal_mask.shape[-1]) * causal_mask.shape[-1]
+    causal_mask = torch.full((new_max_positions, new_max_positions), fill_value=1)
```
This should offer a tradeoff between allocating too large a mask up front and growing it based on the length of the input.
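For illustration, here is a minimal, standalone sketch of that sizing idea. It is not the PR's actual code: the `grow_causal_mask` helper is hypothetical, and it uses `math.ceil` instead of `round` so the grown mask always covers `seq_length`.

```python
import math

import torch


def grow_causal_mask(causal_mask: torch.Tensor, seq_length: int) -> torch.Tensor:
    """Grow the cached mask to the next multiple of its current size that covers seq_length."""
    if seq_length > causal_mask.shape[-1]:
        new_max_positions = math.ceil(seq_length / causal_mask.shape[-1]) * causal_mask.shape[-1]
        causal_mask = torch.full((new_max_positions, new_max_positions), fill_value=1)
    return causal_mask


# With a 4096-position cached mask, longer inputs only round the size up as far as needed.
mask = torch.full((4096, 4096), fill_value=1)
print(grow_causal_mask(mask, 5000).shape)   # torch.Size([8192, 8192])
print(grow_causal_mask(mask, 20000).shape)  # torch.Size([20480, 20480])
```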
It would also allow one to run
Fixed by #29753!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Fixes a bug in Llama causal mask creation for when the input length exceeds causal_mask.shape[-1] (which comes from max_position_embeddings).
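As a rough reproduction sketch (assuming a tiny, randomly initialized config is enough to hit the mask-size path; the exact failure depends on the transformers version), the bug shows up when the input is longer than max_position_embeddings:

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny config so the model is cheap to build; the key detail is the small
# max_position_embeddings, which bounds the precomputed causal mask.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=8,
    vocab_size=128,
)
model = LlamaForCausalLM(config)

# Input longer than max_position_embeddings: on affected versions the forward
# pass errors because the cached causal mask is never grown past 8 positions.
input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)
```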
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada