-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[shardformer] llama support flash attention #4185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
FrankLeeeee
merged 10 commits into
hpcaitech:feature/flash-attention-shardformer
from
flybird11111:flash-llama
Jul 6, 2023
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
cb9d34b
[shardformer] opt support flash attention
flybird11111 ac179b5
[shardformer] opt support flash attention
flybird11111 00c4a82
[shardformer] opt support flash attention
flybird11111 27526d9
[shardformer] opt support flash attention
flybird11111 70535bd
[shardformer] move to modeling
flybird11111 794dd86
[shardformer] move to modeling
flybird11111 60f4a36
Merge branch 'hpcaitech:feature/flash-attention-shardformer' into fea…
flybird11111 50dd1db
[shardformer] llama support flash attention
flybird11111 526d2e1
[shardformer] llama support flash attention
flybird11111 e455479
[shardformer] Move the import statement for xformer outside the forwa…
flybird11111 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import torch | ||
|
|
||
| __all__ = ['get_llama_forward'] | ||
|
|
||
|
|
||
| def rotate_half(x): | ||
| """Rotates half the hidden dims of the input.""" | ||
| x1 = x[..., :x.shape[-1] // 2] | ||
| x2 = x[..., x.shape[-1] // 2:] | ||
| return torch.cat((-x2, x1), dim=-1) | ||
|
|
||
|
|
||
| def apply_rotary_pos_emb(q, k, cos, sin, position_ids): | ||
| # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. | ||
| cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] | ||
| sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] | ||
| cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] | ||
| sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] | ||
| q_embed = (q * cos) + (rotate_half(q) * sin) | ||
| k_embed = (k * cos) + (rotate_half(k) * sin) | ||
| return q_embed, k_embed | ||
|
|
||
|
|
||
| def get_llama_forward(): | ||
|
|
||
| try: | ||
| from xformers.ops import memory_efficient_attention as me_attention | ||
| except: | ||
| raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") | ||
|
|
||
| def llama_flash_attention_forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||
| output_attentions: bool = False, | ||
| use_cache: bool = False, | ||
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
| bsz, q_len, _ = hidden_states.size() | ||
|
|
||
| query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||
| key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||
| value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) | ||
|
|
||
| kv_seq_len = key_states.shape[-2] | ||
| if past_key_value is not None: | ||
| kv_seq_len += past_key_value[0].shape[-2] | ||
|
|
||
| cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||
| query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) | ||
|
|
||
| if past_key_value is not None: | ||
| # reuse k, v, self_attention | ||
| key_states = torch.cat([past_key_value[0], key_states], dim=2) | ||
| value_states = torch.cat([past_key_value[1], value_states], dim=2) | ||
|
|
||
| past_key_value = (key_states, value_states) if use_cache else None | ||
|
|
||
| me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) | ||
| query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) | ||
| key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) | ||
| value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) | ||
|
|
||
| if attention_mask != None: | ||
| if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): | ||
| raise ValueError( | ||
| f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") | ||
| attention_mask = attention_mask.expand(bsz, self.num_heads, q_len, kv_seq_len).contiguous() | ||
|
|
||
| attn_output = me_attention(query_states, key_states, value_states, attn_bias=attention_mask) | ||
| if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): | ||
| raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" | ||
| f" {attn_output.size()}") | ||
| attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) | ||
| attn_output = self.o_proj(attn_output) | ||
|
|
||
| return attn_output, None, past_key_value | ||
|
|
||
| return llama_flash_attention_forward | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.