-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[npu] support triangle attention for llama #5130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
5379071
update fused attn
oahzxl 5f4dfea
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into f…
oahzxl 995c9d8
update spda
oahzxl 4f5a080
tri attn
oahzxl 95d1cc4
update triangle
oahzxl 3b07b59
import
oahzxl 83f5d80
fix
oahzxl 34b83f1
fix
oahzxl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .mha import NPUColoAttention | ||
|
|
||
| __all__ = ["NPUColoAttention"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .mha import NPUColoAttention | ||
|
|
||
| __all__ = ["NPUColoAttention"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| import math | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from .sdpa_attn import npu_sdpa_attention | ||
| from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION | ||
|
|
||
|
|
||
| class NPUColoAttention(torch.nn.Module): | ||
| def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None): | ||
| super().__init__() | ||
|
|
||
| try: | ||
| import torch_npu # noqa | ||
| except ImportError: | ||
| raise Exception("torch_npu is not installed.") | ||
|
|
||
| assert ( | ||
| embed_dim % num_heads == 0 | ||
| ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." | ||
| if scale is not None: | ||
| self.scale = scale | ||
| else: | ||
| self.scale = 1 / math.sqrt(embed_dim // num_heads) | ||
| self.dropout = dropout | ||
|
|
||
| def forward( | ||
| self, | ||
| query: torch.Tensor, | ||
| key: torch.Tensor, | ||
| value: torch.Tensor, | ||
| attn_mask: Optional[torch.Tensor] = None, | ||
| origin_attn_mask: Optional[torch.Tensor] = None, | ||
| attn_mask_type: int = None, | ||
| bias: Optional[torch.Tensor] = None, | ||
| ): | ||
| """ | ||
| Implement the scaled dot product attention with softmax. | ||
|
|
||
| Arguments: | ||
| q: (batch, q_seqlen, nheads, headdim) | ||
| k: (batch, kv_seqlen, nheads, headdim) | ||
| v: (batch, kv_seqlen, nheads, headdim) | ||
| batch_size: int. | ||
| seq_len: int. | ||
| dropout_p: float. Dropout probability. | ||
| scale: float. The scaling of QK^T before applying softmax. | ||
| Default to 1. | ||
| Return: | ||
| attn_out: (batch, q_seqlen, nheads, headdim). | ||
| """ | ||
| assert ( | ||
| len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4 | ||
| ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}" | ||
| assert ( | ||
| query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu" | ||
| ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}" | ||
| assert bias is None, "bias is not supported in npu colo attention" | ||
|
|
||
| causal = attn_mask_type is not None and attn_mask_type.value > 1 | ||
|
|
||
| if HAS_NPU_TRIANGLE_ATTENTION: | ||
| from .triangle_attn import npu_triangle_attention | ||
|
|
||
| attn_fn = npu_triangle_attention | ||
| else: | ||
| attn_fn = npu_sdpa_attention | ||
|
|
||
| out = attn_fn( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask=attn_mask, | ||
| origin_attn_mask=origin_attn_mask, | ||
| dropout_p=self.dropout, | ||
| scale=self.scale, | ||
| is_causal=causal, | ||
| ) | ||
| return out | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| import torch | ||
| from einops import rearrange | ||
|
|
||
|
|
||
| def npu_sdpa_attention( | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| attn_mask: torch.Tensor = None, | ||
| origin_attn_mask: torch.Tensor = None, | ||
| scale: float = 1.0, | ||
| dropout_p: float = 0.0, | ||
| is_causal: bool = True, | ||
| ): | ||
| """ | ||
| The scaled dot product attention. | ||
|
|
||
| Arguments: | ||
| q: (batch, q_seqlen, nheads, headdim) | ||
| k: (batch, kv_seqlen, nheads, headdim) | ||
| v: (batch, kv_seqlen, nheads, headdim) | ||
| batch_size: int. | ||
| seq_len: int. | ||
| dropout_p: float. Dropout probability. | ||
| scale: float. The scaling of QK^T before applying softmax. | ||
| Default to 1. | ||
| Return: | ||
| attn_out: (batch, q_seqlen, nheads, headdim). | ||
| """ | ||
| q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] | ||
| output = torch.nn.functional.scaled_dot_product_attention( | ||
| q, | ||
| k, | ||
| v, | ||
| attn_mask=origin_attn_mask, | ||
| dropout_p=dropout_p, | ||
| is_causal=origin_attn_mask is None, | ||
| scale=scale, | ||
| ) | ||
| output = rearrange(output, "b h s d -> b s (h d)") | ||
| return output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # coding=utf-8 | ||
| # Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
|
|
||
| import torch | ||
| from einops import rearrange | ||
|
|
||
| HAS_NPU_TRIANGLE_ATTENTION = False | ||
| try: | ||
| from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax | ||
|
|
||
| HAS_NPU_TRIANGLE_ATTENTION = True | ||
| except ImportError: | ||
| logging.warning("Import torch_npu Error.") | ||
|
|
||
|
|
||
| if HAS_NPU_TRIANGLE_ATTENTION: | ||
|
|
||
| def npu_triangle_attention( | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| attn_mask: torch.Tensor = None, | ||
| origin_attn_mask: torch.Tensor = None, | ||
| scale: float = 1.0, | ||
| dropout_p: float = 0.0, | ||
| is_causal: bool = True, | ||
| block_size=512, | ||
| ): | ||
| """ | ||
| The triangle attention reduces the attention calculation of the mask | ||
| part by dividing the q, k, and v matrices into blocks | ||
|
|
||
| Arguments: | ||
| block_size: The size of the inverted triangle block, the default is 512, | ||
| the smaller the block_size, the more calculations will be reduced, | ||
| but the number of small operators will be increased | ||
| masked_softmax_func: mask function to be applied. | ||
| dropout_func: dropout function to be applied. | ||
| """ | ||
|
|
||
| def compute_attn(q_layer, k_layer, v_layer, mask_tmp): | ||
| # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size] | ||
| cur_sim = torch.matmul(q_layer, k_layer) | ||
| attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp) | ||
| # attention dropout | ||
| if dropout_p > 0: | ||
| attention_probs = torch.nn.functional.dropout( | ||
| attention_probs, p=dropout_p, training=attention_probs.require_grad | ||
| ) | ||
| # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd] | ||
| context_layer_tmp = torch.matmul(attention_probs, v_layer) | ||
| return context_layer_tmp | ||
|
|
||
| q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)] | ||
| origin_attn_mask = origin_attn_mask.to(torch.bool) | ||
| # input shape: [b, hn, sq, hd] | ||
| bsz, head_num, sequence_len, head_dim = k.shape | ||
| sparse_groups = sequence_len // block_size | ||
| # Determine whether blocks size can be divided by sequence_length | ||
| divisible_flag = sequence_len == block_size * sparse_groups | ||
| k = k.transpose(2, 3).contiguous() | ||
| if divisible_flag: | ||
| q_tmp_layers = torch.chunk(q, sparse_groups, 2) | ||
| k_tmp_layers = torch.chunk(k, sparse_groups, 3) | ||
| v_tmp_layers = torch.chunk(v, sparse_groups, 2) | ||
| else: | ||
| seq_tmp = block_size * sparse_groups | ||
| q_last = q[:, :, seq_tmp:, :].contiguous() | ||
| mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous() | ||
| q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2) | ||
| k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3) | ||
| v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2) | ||
| context_list_tmp, k_tmp, v_tmp = [], (), () | ||
| for i in range(sparse_groups): | ||
| # compute slice shape of q k v for each loop | ||
| q_begin, q_end = i * block_size, (i + 1) * block_size | ||
| kv_begin, kv_end = 0, (i + 1) * block_size | ||
| q_tmp = q_tmp_layers[i] | ||
| # slice k and v | ||
| if i == 0: | ||
| k_tmp = k_tmp_layers[i].contiguous() | ||
| v_tmp = v_tmp_layers[i].contiguous() | ||
| else: | ||
| k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous() | ||
| v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous() | ||
|
|
||
| mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous() | ||
| context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp) | ||
| context_list_tmp.append(context_layer_tmp) | ||
|
|
||
| if not divisible_flag: | ||
| # circumstances that cannot be divisible | ||
| context_layer_tmp = compute_attn(q_last, k, v, mask_last) | ||
| context_list_tmp.append(context_layer_tmp) | ||
| context_layer = torch.cat(context_list_tmp, 2) | ||
| new_context_layer_shape = (bsz, sequence_len, head_num * head_dim) | ||
| context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True) | ||
| # ========================= | ||
| # Context layer. [b, sq, hp] | ||
| # ========================= | ||
| return context_layer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.