
Fix causal mask in llama for long seq_length #29263

Closed
YLGH wants to merge 1 commit into huggingface:main from YLGH:fix_llama_mask

Conversation

@YLGH commented Feb 23, 2024

What does this PR do?

Fixes a bug in Llama causal mask creation for the case where the input sequence length exceeds causal_mask.shape[-1] (which is derived from max_position_embeddings).
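
For context, a toy illustration of the failure mode (this is not the actual modeling_llama code, just a sketch with made-up sizes): the cached mask is sized by max_position_embeddings, and slicing it with a longer seq_length silently returns a mask that is too small.

    import torch

    # Toy sketch of the bug (hypothetical sizes; not the real modeling code).
    max_position_embeddings = 8
    causal_mask = torch.tril(
        torch.ones(max_position_embeddings, max_position_embeddings, dtype=torch.bool)
    )

    seq_length = 12  # longer than the cached mask
    sliced = causal_mask[:seq_length, :seq_length]
    print(sliced.shape)  # torch.Size([8, 8]) -- silently truncated, not [12, 12]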

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @younesbelkada

@Mihaiii commented Feb 24, 2024

Will this PR fix #29252? (Yi has the same architecture as Llama 2.)

@YLGH (Author) commented Feb 24, 2024

Will this PR fix #29252? (Yi has the same architecture as Llama 2.)

Yeah, I think so (it uses LlamaForCausalLM).

@YLGH (Author) commented Feb 24, 2024

I'm not sure why the unit tests are failing; they seem unrelated to my change.

@ArthurZucker (Collaborator) left a comment


Thanks! I don't think a while loop is the best way to achieve this. However, we do need a proper fix!

Comment on lines +1065 to +1066

    while seq_length > causal_mask.shape[-1]:
        causal_mask = torch.full((2 * causal_mask.shape[-1], 2 * causal_mask.shape[-1]), fill_value=1)
@ArthurZucker (Collaborator) commented Feb 27, 2024

Suggested change

    -while seq_length > causal_mask.shape[-1]:
    -    causal_mask = torch.full((2 * causal_mask.shape[-1], 2 * causal_mask.shape[-1]), fill_value=1)
    +if seq_length > causal_mask.shape[-1]:
    +    new_max_positions = round(seq_length / causal_mask.shape[-1]) * causal_mask.shape[-1]
    +    causal_mask = torch.full((new_max_positions, new_max_positions), fill_value=1)


This should offer a tradeoff between loading too big a mask up front and growing it based on the length of the input.
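
A minimal sketch of the growth strategy being described (an assumption on my part: math.ceil is used here so the grown mask is always large enough, whereas the suggestion above uses round):

    import math
    import torch

    # Minimal sketch: grow the cached mask only when the input exceeds it,
    # rounding up to a multiple of the current cache size.
    def grow_causal_mask(causal_mask: torch.Tensor, seq_length: int) -> torch.Tensor:
        if seq_length > causal_mask.shape[-1]:
            new_max_positions = math.ceil(seq_length / causal_mask.shape[-1]) * causal_mask.shape[-1]
            causal_mask = torch.full((new_max_positions, new_max_positions), fill_value=1)
        return causal_mask

    mask = torch.full((4096, 4096), fill_value=1)
    print(grow_causal_mask(mask, 9000).shape)  # torch.Size([12288, 12288])
    print(grow_causal_mask(mask, 3000).shape)  # unchanged: torch.Size([4096, 4096])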

@BlackSamorez (Contributor) commented

It would also allow one to run miqu-1-70b on an RTX 3090 in 2 bits. The model itself is ~19 GB, but those masks are too large at the 32k max context length (around a GB each, and sometimes copied), which prevents one from running the model.
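
As a back-of-envelope check of that size claim (assuming one byte per mask entry, e.g. a bool/uint8 mask; a float16 mask would be twice this, float32 four times):

    # Back-of-envelope mask size at 32k context (assumes 1 byte per entry).
    max_position_embeddings = 32_768
    bytes_per_entry = 1
    size_gib = max_position_embeddings ** 2 * bytes_per_entry / 1024 ** 3
    print(f"{size_gib:.1f} GiB")  # 1.0 GiB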

@ArthurZucker (Collaborator) commented

Fixed by #29753!

@github-actions (Contributor) commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot closed this Apr 30, 2024