From 53790716894d320ac5caf01e3d50eee51653849d Mon Sep 17 00:00:00 2001
From: zxl <43881818+oahzxl@users.noreply.github.com>
Date: Thu, 23 Nov 2023 18:36:51 +0800
Subject: [PATCH 1/7] update fused attn

---
 colossalai/kernel/npu/__init__.py       |  3 ++
 colossalai/kernel/npu/mha/__init__.py   |  3 ++
 colossalai/kernel/npu/mha/fused_attn.py | 65 +++++++++++++++++++++++++
 colossalai/kernel/npu/mha/mha.py        | 51 +++++++++++++++++++
 4 files changed, 122 insertions(+)
 create mode 100644 colossalai/kernel/npu/__init__.py
 create mode 100644 colossalai/kernel/npu/mha/__init__.py
 create mode 100644 colossalai/kernel/npu/mha/fused_attn.py
 create mode 100644 colossalai/kernel/npu/mha/mha.py

diff --git a/colossalai/kernel/npu/__init__.py b/colossalai/kernel/npu/__init__.py
new file mode 100644
index 000000000000..6a02c705559a
--- /dev/null
+++ b/colossalai/kernel/npu/__init__.py
@@ -0,0 +1,3 @@
+from .mha import NPUColoAttention
+
+__all__ = ["NPUColoAttention"]
diff --git a/colossalai/kernel/npu/mha/__init__.py b/colossalai/kernel/npu/mha/__init__.py
new file mode 100644
index 000000000000..6a02c705559a
--- /dev/null
+++ b/colossalai/kernel/npu/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import NPUColoAttention
+
+__all__ = ["NPUColoAttention"]
diff --git a/colossalai/kernel/npu/mha/fused_attn.py b/colossalai/kernel/npu/mha/fused_attn.py
new file mode 100644
index 000000000000..471578e30a7a
--- /dev/null
+++ b/colossalai/kernel/npu/mha/fused_attn.py
@@ -0,0 +1,65 @@
+import torch
+import warnings
+
+
+HAS_NPU_FUSED_ATTN = False
+try:
+    from torch_npu import npu_fusion_attention
+
+    HAS_NPU_FUSED_ATTN = True
+except ImportError:
+    warnings.warn("please install torch_npu with npu_fusion_attention")
+
+
+if HAS_NPU_FUSED_ATTN:
+
+    def npu_fused_attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_mask: torch.Tensor,
+        scale: float = 1.0,
+        dropout_p: float = 0.0,
+    ):
+        """
+        Implement the scaled dot product attention with softmax.
+
+        Arguments:
+            q: (batch, q_seqlen, nheads, headdim)
+            k: (batch, kv_seqlen, nheads, headdim)
+            v: (batch, kv_seqlen, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            dropout_p: float. Dropout probability.
+            scale: float. The scaling of QK^T before applying softmax.
+                Default to 1.
+        Return:
+            attn_out: (batch, q_seqlen, nheads, headdim).
+        """
+        batch, q_len, num_heads = q.shape[:3]
+        kv_len = k.shape[1]
+        matmul_result = torch.empty(
+            (batch, num_heads, q_len, kv_len), dtype=q.dtype, device=q.device
+        )
+        output = npu_fusion_attention(
+            query=q,
+            key=k,
+            value=v,
+            head_num=num_heads,
+            input_layout="BSH",
+            pse=matmul_result,
+            padding_mask=None,
+            atten_mask=attention_mask,
+            scale=scale,
+            pre_tockens=kv_len,
+            next_tockens=0,
+            keep_prob=1 - dropout_p,
+        )[0]
+        return output
+
+
+    if __name__ == "__main__":
+        b, s, h, d = 4, 32, 16, 64
+        q, k, v = [torch.rand(b, s, h * d).npu() for _ in range(3)]
+        context_layer = npu_fused_attention(q, k, v, None)
+        print(context_layer)

From 995c9d855687063fab016a7eda046037f445e502 Mon Sep 17 00:00:00 2001
From: zxl <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 24 Nov 2023 17:07:49 +0800
Subject: [PATCH 2/7] update spda

---
 .../kernel/cuda_native/mha/flash_attn_2.py |  1 -
 colossalai/kernel/cuda_native/mha/mha.py   |  1 +
 colossalai/kernel/npu/mha/fused_attn.py    | 65 -------------------
 colossalai/kernel/npu/mha/mha.py           | 54 ++++++++++-----
 colossalai/kernel/npu/mha/spda_attn.py     | 42 ++++++++++++
 colossalai/shardformer/layer/utils.py      | 18 +++++
 colossalai/shardformer/modeling/llama.py   |  5 +-
 7 files changed, 100 insertions(+), 86 deletions(-)
 delete mode 100644 colossalai/kernel/npu/mha/fused_attn.py
 create mode 100644 colossalai/kernel/npu/mha/spda_attn.py

diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
index 9ee83915b1b4..de2ccaa4947f 100644
--- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py
+++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -29,7 +29,6 @@ def is_ampere_or_better_gpu():
     HAS_FLASH_ATTN = False
 
 if HAS_FLASH_ATTN:
-    pass
     from .utils import SeqLenInfo
 
 
diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py
index 1c778439d33f..b56d37cf026e 100644
--- a/colossalai/kernel/cuda_native/mha/mha.py
+++ b/colossalai/kernel/cuda_native/mha/mha.py
@@ -44,6 +44,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
diff --git a/colossalai/kernel/npu/mha/fused_attn.py b/colossalai/kernel/npu/mha/fused_attn.py
deleted file mode 100644
index 471578e30a7a..000000000000
--- a/colossalai/kernel/npu/mha/fused_attn.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-import warnings
-
-
-HAS_NPU_FUSED_ATTN = False
-try:
-    from torch_npu import npu_fusion_attention
-
-    HAS_NPU_FUSED_ATTN = True
-except ImportError:
-    warnings.warn("please install torch_npu with npu_fusion_attention")
-
-
-if HAS_NPU_FUSED_ATTN:
-
-    def npu_fused_attention(
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        attention_mask: torch.Tensor,
-        scale: float = 1.0,
-        dropout_p: float = 0.0,
-    ):
-        """
-        Implement the scaled dot product attention with softmax.
-
-        Arguments:
-            q: (batch, q_seqlen, nheads, headdim)
-            k: (batch, kv_seqlen, nheads, headdim)
-            v: (batch, kv_seqlen, nheads, headdim)
-            batch_size: int.
-            seq_len: int.
-            dropout_p: float. Dropout probability.
-            scale: float. The scaling of QK^T before applying softmax.
-                Default to 1.
-        Return:
-            attn_out: (batch, q_seqlen, nheads, headdim).
-        """
-        batch, q_len, num_heads = q.shape[:3]
-        kv_len = k.shape[1]
-        matmul_result = torch.empty(
-            (batch, num_heads, q_len, kv_len), dtype=q.dtype, device=q.device
-        )
-        output = npu_fusion_attention(
-            query=q,
-            key=k,
-            value=v,
-            head_num=num_heads,
-            input_layout="BSH",
-            pse=matmul_result,
-            padding_mask=None,
-            atten_mask=attention_mask,
-            scale=scale,
-            pre_tockens=kv_len,
-            next_tockens=0,
-            keep_prob=1 - dropout_p,
-        )[0]
-        return output
-
-
-    if __name__ == "__main__":
-        b, s, h, d = 4, 32, 16, 64
-        q, k, v = [torch.rand(b, s, h * d).npu() for _ in range(3)]
-        context_layer = npu_fused_attention(q, k, v, None)
-        print(context_layer)
diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py
index f72aae2bd210..75dc5100abeb 100644
--- a/colossalai/kernel/npu/mha/mha.py
+++ b/colossalai/kernel/npu/mha/mha.py
@@ -1,17 +1,13 @@
 import math
 from typing import Optional
-
+from .spda_attn import npu_sdpa_attention
 import torch
-from einops import rearrange
-
-from .fused_attn import HAS_NPU_FUSED_ATTN
-
-if HAS_NPU_FUSED_ATTN:
-    from .fused_attn import npu_fused_attention
 
 
 class NPUColoAttention(torch.nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
+    def __init__(
+        self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None
+    ):
         super().__init__()
         assert (
             embed_dim % num_heads == 0
@@ -22,30 +18,52 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=N
             self.scale = 1 / math.sqrt(embed_dim // num_heads)
         self.dropout = dropout
 
-        if not HAS_NPU_FUSED_ATTN:
-            raise Exception("npu attention kernel can not support!")
-
     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: int = None,
         bias: Optional[torch.Tensor] = None,
     ):
-        if HAS_NPU_FUSED_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
-            attn = npu_fused_attention
-        else:
-            raise Exception("npu attention kernel can not support!")
+        """
+        Implement the scaled dot product attention with softmax.
+
+        Arguments:
+            q: (batch, q_seqlen, nheads, headdim)
+            k: (batch, kv_seqlen, nheads, headdim)
+            v: (batch, kv_seqlen, nheads, headdim)
+            batch_size: int.
+            seq_len: int.
+            dropout_p: float. Dropout probability.
+            scale: float. The scaling of QK^T before applying softmax.
+                Default to 1.
+        Return:
+            attn_out: (batch, q_seqlen, nheads, headdim).
+        """
+        assert (
+            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
+        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
+        assert (
+            query.device.type == "npu"
+            and key.device.type == "npu"
+            and value.device.type == "npu"
+        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
+        assert bias is None, "bias is not supported in npu colo attention"
+
+        causal = attn_mask_type is not None and attn_mask_type.value > 1
+        attn_fn = npu_sdpa_attention
 
-        out = attn(
+        out = attn_fn(
             query,
             key,
             value,
-            attention_mask=attn_mask,
+            attn_mask=attn_mask,
+            origin_attn_mask=origin_attn_mask,
             dropout_p=self.dropout,
             scale=self.scale,
+            is_causal=causal,
         )
-        out = rearrange(out, "b s h d -> b s (h d)")
         return out
diff --git a/colossalai/kernel/npu/mha/spda_attn.py b/colossalai/kernel/npu/mha/spda_attn.py
new file mode 100644
index 000000000000..32ebb3f1567c
--- /dev/null
+++ b/colossalai/kernel/npu/mha/spda_attn.py
@@ -0,0 +1,42 @@
+import torch
+import torch_npu # noqa
+from einops import rearrange
+
+
+def npu_sdpa_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    attn_mask: torch.Tensor = None,
+    origin_attn_mask: torch.Tensor = None,
+    scale: float = 1.0,
+    dropout_p: float = 0.0,
+    is_causal: bool = True,
+):
+    """
+    The scaled dot product attention.
+
+    Arguments:
+        q: (batch, q_seqlen, nheads, headdim)
+        k: (batch, kv_seqlen, nheads, headdim)
+        v: (batch, kv_seqlen, nheads, headdim)
+        batch_size: int.
+        seq_len: int.
+        dropout_p: float. Dropout probability.
+        scale: float. The scaling of QK^T before applying softmax.
+            Default to 1.
+    Return:
+        attn_out: (batch, q_seqlen, nheads, headdim).
+    """
+    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
+    output = torch.nn.functional.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=origin_attn_mask,
+        dropout_p=dropout_p,
+        is_causal=origin_attn_mask is None,
+        scale=scale,
+    )
+    output = rearrange(output, "b h s d -> b s (h d)")
+    return output
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 4b6343adcd3b..55683b227be9 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
     Randomizer.increment_index()
 
     return Randomizer(seed=base_seed)
+
+
+def get_attention_kernel():
+    """
+    Get the attention kernel based on the device type.
+    """
+    from colossalai.kernel.cuda_native import AttnMaskType
+
+    if torch.cuda.is_available():
+        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
+    else:
+        try:
+            torch.npu.is_available()
+            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
+        except:
+            raise Exception("No available device for attention kernel!")
+
+    return AttnMaskType, AttentionKernel
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 616c9220f4ab..c3de197c4354 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -12,6 +12,7 @@
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ def llama_for_sequence_classification_forward(
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()
 
     llama_version = 2
     try:
@@ -468,7 +469,7 @@ def forward(
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
 
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
         )
 
         attn_output = self.o_proj(attn_output)

From 4f5a0801a381e0daab7d95205277d34a9d72cf4a Mon Sep 17 00:00:00 2001
From: zxl <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 24 Nov 2023 17:33:39 +0800
Subject: [PATCH 3/7] tri attn

---
 colossalai/kernel/npu/mha/triangle_attn.py | 133 +++++++++++++++++++++
 1 file changed, 133 insertions(+)
 create mode 100644 colossalai/kernel/npu/mha/triangle_attn.py

diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py
new file mode 100644
index 000000000000..f7fa6fcea49e
--- /dev/null
+++ b/colossalai/kernel/npu/mha/triangle_attn.py
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import torch
+import torch.nn as nn
+
+HAS_NPU_TRIANGLE_ATTENTION = False
+try:
+    from torch_npu import npu_scaled_masked_softmax
+    from torch_npu import npu_confusion_transpose
+    HAS_NPU_TRIANGLE_ATTENTION = True
+except ImportError:
+    logging.warning("Import torch_npu Error.")
+
+
+class TriangleAttention(nn.Module):
+    """
+    The triangle attention reduces the attention calculation of the mask
+    part by dividing the q, k, and v matrices into blocks
+
+    Arguments:
+        block_size: The size of the inverted triangle block, the default is 512,
+            the smaller the block_size, the more calculations will be reduced,
+            but the number of small operators will be increased
+        masked_softmax_func: mask function to be applied.
+        dropout_func: dropout function to be applied.
+    """
+
+    def __init__(self, block_size=512, masked_softmax_func=None, dropout_func=None):
+        super(TriangleAttention, self).__init__()
+        self.block_size = block_size
+        self.mask_tmp_initialed = False
+        self.mask_tmp_groups = []
+        if masked_softmax_func is not None:
+            self.scaled_masked_softmax = masked_softmax_func
+        else:
+            self.scaled_masked_softmax = npu_scaled_masked_softmax
+        if dropout_func:
+            self.dropout = True
+            self.attn_dropout = dropout_func
+        else:
+            self.dropout = False
+
+    def compute_attn(self, q_layer, k_layer, v_layer, mask_tmp):
+        # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
+        cur_sim = torch.matmul(q_layer, k_layer)
+
+        attention_probs = self.scaled_masked_softmax(cur_sim, mask_tmp)
+
+        # attention dropout
+        if self.dropout:
+            attention_probs = self.attn_dropout(attention_probs)
+
+        # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
+        context_layer_tmp = torch.matmul(attention_probs, v_layer)
+        return context_layer_tmp
+
+    def forward(self, query_layer, key_layer, value_layer, attention_mask):
+        # input shape: [b, hn, sq, hd]
+        bsz, head_num, sequence_len, head_dim = key_layer.shape
+        sparse_groups = sequence_len // self.block_size
+        # Determine whether blocks size can be divided by sequence_length
+        flag = sequence_len == self.block_size * sparse_groups
+        key_layer = key_layer.transpose(2, 3).contiguous()
+        if flag:
+            q_tmp_layers = torch.chunk(query_layer, sparse_groups, 2)
+            k_tmp_layers = torch.chunk(key_layer, sparse_groups, 3)
+            v_tmp_layers = torch.chunk(value_layer, sparse_groups, 2)
+        else:
+            seq_tmp = self.block_size * sparse_groups
+            q_last = query_layer[:, :, seq_tmp:, :].contiguous()
+            mask_last = attention_mask[:, :, seq_tmp:, :].contiguous()
+            q_tmp_layers = torch.chunk(query_layer[:, :, :seq_tmp, :], sparse_groups, 2)
+            k_tmp_layers = torch.chunk(key_layer[:, :, :, :seq_tmp], sparse_groups, 3)
+            v_tmp_layers = torch.chunk(value_layer[:, :, :seq_tmp, :], sparse_groups, 2)
+        context_list_tmp, k_tmp, v_tmp = [], (), ()
+        for i in range(sparse_groups):
+            # compute slice shape of q k v for each loop
+            q_begin, q_end = i * self.block_size, (i + 1) * self.block_size
+            kv_begin, kv_end = 0, (i + 1) * self.block_size
+            q_tmp = q_tmp_layers[i]
+            # slice k and v
+            if i == 0:
+                k_tmp = k_tmp_layers[i].contiguous()
+                v_tmp = v_tmp_layers[i].contiguous()
+            else:
+                k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
+                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
+
+            if not self.mask_tmp_initialed:
+                mask_tmp = attention_mask[:, :, q_begin:q_end, kv_begin:kv_end]
+                self.mask_tmp_groups.append(mask_tmp.contiguous())
+            else:
+                mask_tmp = self.mask_tmp_groups[i]
+
+            context_layer_tmp = self.compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
+            context_list_tmp.append(context_layer_tmp)
+
+        if not flag:
+            # circumstances that cannot be divisible
+            context_layer_tmp = self.compute_attn(q_last, key_layer, value_layer, mask_last)
+            context_list_tmp.append(context_layer_tmp)
+        context_layer = torch.cat(context_list_tmp, 2)
+        self.mask_tmp_initialed = True
+        new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
+        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
+        # =========================
+        # Context layer. [b, sq, hp]
+        # =========================
+        return context_layer
+
+
+if __name__ == "__main__":
+    attn = TriangleAttention()
+    q, k, v = [torch.randn((2, 12, 1024, 64), requires_grad=True).npu().half() for _ in range(3)]
+    mask = torch.ones(2, 12, 1024, 1024).npu().bool()
+    out = attn(q, k, v, mask)
+    loss = out.sum()
+    loss.backward()
+

From 95d1cc4e45c947168bc8f37d2364166f3c863b7d Mon Sep 17 00:00:00 2001
From: zxl <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 24 Nov 2023 18:07:29 +0800
Subject: [PATCH 4/7] update triangle

---
 colossalai/kernel/npu/mha/mha.py           |  14 +-
 colossalai/kernel/npu/mha/spda_attn.py     |   2 +-
 colossalai/kernel/npu/mha/triangle_attn.py | 142 ++++++++++-----------
 3 files changed, 84 insertions(+), 74 deletions(-)

diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py
index 75dc5100abeb..6db6c9aaaedb 100644
--- a/colossalai/kernel/npu/mha/mha.py
+++ b/colossalai/kernel/npu/mha/mha.py
@@ -1,5 +1,6 @@
 import math
 from typing import Optional
+from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
 from .spda_attn import npu_sdpa_attention
 import torch
 
@@ -9,6 +10,12 @@ def __init__(
         self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None
     ):
         super().__init__()
+
+        try:
+            import torch_npu
+        except ImportError:
+            raise Exception("torch_npu is not installed.")
+
         assert (
             embed_dim % num_heads == 0
         ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
@@ -54,7 +61,12 @@ def forward(
         assert bias is None, "bias is not supported in npu colo attention"
 
         causal = attn_mask_type is not None and attn_mask_type.value > 1
-        attn_fn = npu_sdpa_attention
+
+        if HAS_NPU_TRIANGLE_ATTENTION:
+            from .triangle_attn import npu_triangle_attention
+            attn_fn = npu_triangle_attention
+        else:
+            attn_fn = npu_sdpa_attention
 
         out = attn_fn(
             query,
diff --git a/colossalai/kernel/npu/mha/spda_attn.py b/colossalai/kernel/npu/mha/spda_attn.py
index 32ebb3f1567c..b9297e9aa277 100644
--- a/colossalai/kernel/npu/mha/spda_attn.py
+++ b/colossalai/kernel/npu/mha/spda_attn.py
@@ -1,5 +1,5 @@
 import torch
-import torch_npu # noqa
+import torch_npu  # noqa
 from einops import rearrange
 
 
diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py
index f7fa6fcea49e..434d885610c5 100644
--- a/colossalai/kernel/npu/mha/triangle_attn.py
+++ b/colossalai/kernel/npu/mha/triangle_attn.py
@@ -16,81 +16,80 @@
 import logging
 import torch
 import torch.nn as nn
+from einops import rearrange
 
 HAS_NPU_TRIANGLE_ATTENTION = False
 try:
     from torch_npu import npu_scaled_masked_softmax
     from torch_npu import npu_confusion_transpose
+
     HAS_NPU_TRIANGLE_ATTENTION = True
 except ImportError:
     logging.warning("Import torch_npu Error.")
 
 
-class TriangleAttention(nn.Module):
-    """
-    The triangle attention reduces the attention calculation of the mask
-    part by dividing the q, k, and v matrices into blocks
-
-    Arguments:
-        block_size: The size of the inverted triangle block, the default is 512,
-            the smaller the block_size, the more calculations will be reduced,
-            but the number of small operators will be increased
-        masked_softmax_func: mask function to be applied.
-        dropout_func: dropout function to be applied.
-    """
-
-    def __init__(self, block_size=512, masked_softmax_func=None, dropout_func=None):
-        super(TriangleAttention, self).__init__()
-        self.block_size = block_size
-        self.mask_tmp_initialed = False
-        self.mask_tmp_groups = []
-        if masked_softmax_func is not None:
-            self.scaled_masked_softmax = masked_softmax_func
-        else:
-            self.scaled_masked_softmax = npu_scaled_masked_softmax
-        if dropout_func:
-            self.dropout = True
-            self.attn_dropout = dropout_func
-        else:
-            self.dropout = False
-
-    def compute_attn(self, q_layer, k_layer, v_layer, mask_tmp):
-        # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
-        cur_sim = torch.matmul(q_layer, k_layer)
-
-        attention_probs = self.scaled_masked_softmax(cur_sim, mask_tmp)
-
-        # attention dropout
-        if self.dropout:
-            attention_probs = self.attn_dropout(attention_probs)
-
-        # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
-        context_layer_tmp = torch.matmul(attention_probs, v_layer)
-        return context_layer_tmp
-
-    def forward(self, query_layer, key_layer, value_layer, attention_mask):
+if HAS_NPU_TRIANGLE_ATTENTION:
+
+    def npu_triangle_attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: torch.Tensor = None,
+        origin_attn_mask: torch.Tensor = None,
+        scale: float = 1.0,
+        dropout_p: float = 0.0,
+        is_causal: bool = True,
+        block_size=512,
+    ):
+        """
+        The triangle attention reduces the attention calculation of the mask
+        part by dividing the q, k, and v matrices into blocks
+
+        Arguments:
+            block_size: The size of the inverted triangle block, the default is 512,
+                the smaller the block_size, the more calculations will be reduced,
+                but the number of small operators will be increased
+            masked_softmax_func: mask function to be applied.
+            dropout_func: dropout function to be applied.
+        """
+
+        def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
+            # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
+            cur_sim = torch.matmul(q_layer, k_layer)
+            attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
+            # attention dropout
+            if dropout_p > 0:
+                attention_probs = torch.nn.functional.dropout(
+                    attention_probs, p=dropout_p, training=attention_probs.require_grad
+                )
+            # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
+            context_layer_tmp = torch.matmul(attention_probs, v_layer)
+            return context_layer_tmp
+
+        q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
+        origin_attn_mask = origin_attn_mask.to(torch.bool)
         # input shape: [b, hn, sq, hd]
-        bsz, head_num, sequence_len, head_dim = key_layer.shape
-        sparse_groups = sequence_len // self.block_size
+        bsz, head_num, sequence_len, head_dim = k.shape
+        sparse_groups = sequence_len // block_size
         # Determine whether blocks size can be divided by sequence_length
-        flag = sequence_len == self.block_size * sparse_groups
-        key_layer = key_layer.transpose(2, 3).contiguous()
+        flag = sequence_len == block_size * sparse_groups
+        k = k.transpose(2, 3).contiguous()
         if flag:
-            q_tmp_layers = torch.chunk(query_layer, sparse_groups, 2)
-            k_tmp_layers = torch.chunk(key_layer, sparse_groups, 3)
-            v_tmp_layers = torch.chunk(value_layer, sparse_groups, 2)
+            q_tmp_layers = torch.chunk(q, sparse_groups, 2)
+            k_tmp_layers = torch.chunk(k, sparse_groups, 3)
+            v_tmp_layers = torch.chunk(v, sparse_groups, 2)
         else:
-            seq_tmp = self.block_size * sparse_groups
-            q_last = query_layer[:, :, seq_tmp:, :].contiguous()
-            mask_last = attention_mask[:, :, seq_tmp:, :].contiguous()
-            q_tmp_layers = torch.chunk(query_layer[:, :, :seq_tmp, :], sparse_groups, 2)
-            k_tmp_layers = torch.chunk(key_layer[:, :, :, :seq_tmp], sparse_groups, 3)
-            v_tmp_layers = torch.chunk(value_layer[:, :, :seq_tmp, :], sparse_groups, 2)
+            seq_tmp = block_size * sparse_groups
+            q_last = q[:, :, seq_tmp:, :].contiguous()
+            mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
+            q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
+            k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
+            v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
         context_list_tmp, k_tmp, v_tmp = [], (), ()
         for i in range(sparse_groups):
             # compute slice shape of q k v for each loop
-            q_begin, q_end = i * self.block_size, (i + 1) * self.block_size
-            kv_begin, kv_end = 0, (i + 1) * self.block_size
+            q_begin, q_end = i * block_size, (i + 1) * block_size
+            kv_begin, kv_end = 0, (i + 1) * block_size
             q_tmp = q_tmp_layers[i]
             # slice k and v
             if i == 0:
@@ -100,23 +99,21 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
                 k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
                 v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
 
-            if not self.mask_tmp_initialed:
-                mask_tmp = attention_mask[:, :, q_begin:q_end, kv_begin:kv_end]
-                self.mask_tmp_groups.append(mask_tmp.contiguous())
-            else:
-                mask_tmp = self.mask_tmp_groups[i]
-
-            context_layer_tmp = self.compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
+            mask_tmp = origin_attn_mask[
+                :, :, q_begin:q_end, kv_begin:kv_end
+            ].contiguous()
+            context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
             context_list_tmp.append(context_layer_tmp)
 
         if not flag:
             # circumstances that cannot be divisible
-            context_layer_tmp = self.compute_attn(q_last, key_layer, value_layer, mask_last)
+            context_layer_tmp = compute_attn(q_last, k, v, mask_last)
             context_list_tmp.append(context_layer_tmp)
         context_layer = torch.cat(context_list_tmp, 2)
-        self.mask_tmp_initialed = True
         new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
-        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
+        context_layer = npu_confusion_transpose(
+            context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True
+        )
         # =========================
         # Context layer. [b, sq, hp]
         # =========================
@@ -124,10 +121,11 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask):
 
 
 if __name__ == "__main__":
-    attn = TriangleAttention()
-    q, k, v = [torch.randn((2, 12, 1024, 64), requires_grad=True).npu().half() for _ in range(3)]
+    q, k, v = [
+        torch.randn((2, 12, 1024, 64), requires_grad=True).npu().half()
+        for _ in range(3)
+    ]
     mask = torch.ones(2, 12, 1024, 1024).npu().bool()
-    out = attn(q, k, v, mask)
+    out = npu_triangle_attention(q, k, v, origin_attn_mask=mask)
     loss = out.sum()
     loss.backward()
-

From 3b07b59c3f909ea13d428f486cae8b92961d90f9 Mon Sep 17 00:00:00 2001
From: zxl <43881818+oahzxl@users.noreply.github.com>
Date: Fri, 24 Nov 2023 18:08:04 +0800
Subject: [PATCH 5/7] import

---
 colossalai/kernel/npu/mha/triangle_attn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py
index 434d885610c5..fc91114b1d2e 100644
--- a/colossalai/kernel/npu/mha/triangle_attn.py
+++ b/colossalai/kernel/npu/mha/triangle_attn.py
@@ -15,7 +15,6 @@
 
 import logging
 import torch
-import torch.nn as nn
 from einops import rearrange
 
 HAS_NPU_TRIANGLE_ATTENTION = False

From 83f5d80b84c5e0204c37fd5a62f5c8de9f6003d3 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Thu, 30 Nov 2023 13:02:35 +0800
Subject: [PATCH 6/7] fix

---
 colossalai/kernel/npu/mha/mha.py           | 17 +++++------
 .../npu/mha/{spda_attn.py => sdpa_attn.py} |  3 +-
 colossalai/kernel/npu/mha/triangle_attn.py | 29 +++++--------------
 3 files changed, 17 insertions(+), 32 deletions(-)
 rename colossalai/kernel/npu/mha/{spda_attn.py => sdpa_attn.py} (97%)

diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py
index 6db6c9aaaedb..01bd1c811ddb 100644
--- a/colossalai/kernel/npu/mha/mha.py
+++ b/colossalai/kernel/npu/mha/mha.py
@@ -1,18 +1,18 @@
 import math
 from typing import Optional
-from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
-from .spda_attn import npu_sdpa_attention
+
 import torch
 
+from .sdpa_attn import npu_sdpa_attention
+from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION
+
 
 class NPUColoAttention(torch.nn.Module):
-    def __init__(
-        self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None
-    ):
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
         super().__init__()
 
         try:
-            import torch_npu
+            pass
         except ImportError:
             raise Exception("torch_npu is not installed.")
 
@@ -54,9 +54,7 @@ def forward(
             len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
         ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
         assert (
-            query.device.type == "npu"
-            and key.device.type == "npu"
-            and value.device.type == "npu"
+            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
         ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
         assert bias is None, "bias is not supported in npu colo attention"
@@ -64,6 +62,7 @@ def forward(
 
         if HAS_NPU_TRIANGLE_ATTENTION:
             from .triangle_attn import npu_triangle_attention
+
             attn_fn = npu_triangle_attention
         else:
             attn_fn = npu_sdpa_attention
diff --git a/colossalai/kernel/npu/mha/spda_attn.py b/colossalai/kernel/npu/mha/sdpa_attn.py
similarity index 97%
rename from colossalai/kernel/npu/mha/spda_attn.py
rename to colossalai/kernel/npu/mha/sdpa_attn.py
index b9297e9aa277..349e4b4da4d8 100644
--- a/colossalai/kernel/npu/mha/spda_attn.py
+++ b/colossalai/kernel/npu/mha/sdpa_attn.py
@@ -1,7 +1,8 @@
 import torch
-import torch_npu  # noqa
 from einops import rearrange
 
+import torch_npu  # noqa
+
 
 def npu_sdpa_attention(
     q: torch.Tensor,
diff --git a/colossalai/kernel/npu/mha/triangle_attn.py b/colossalai/kernel/npu/mha/triangle_attn.py
index fc91114b1d2e..619076d5f888 100644
--- a/colossalai/kernel/npu/mha/triangle_attn.py
+++ b/colossalai/kernel/npu/mha/triangle_attn.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 import logging
+
 import torch
 from einops import rearrange
 
 HAS_NPU_TRIANGLE_ATTENTION = False
 try:
-    from torch_npu import npu_scaled_masked_softmax
-    from torch_npu import npu_confusion_transpose
+    from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
 
     HAS_NPU_TRIANGLE_ATTENTION = True
 except ImportError:
@@ -71,9 +71,9 @@ def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
         bsz, head_num, sequence_len, head_dim = k.shape
         sparse_groups = sequence_len // block_size
         # Determine whether blocks size can be divided by sequence_length
-        flag = sequence_len == block_size * sparse_groups
+        divisible_flag = sequence_len == block_size * sparse_groups
         k = k.transpose(2, 3).contiguous()
-        if flag:
+        if divisible_flag:
             q_tmp_layers = torch.chunk(q, sparse_groups, 2)
             k_tmp_layers = torch.chunk(k, sparse_groups, 3)
             v_tmp_layers = torch.chunk(v, sparse_groups, 2)
@@ -98,33 +98,18 @@ def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
                 k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
                 v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
 
-            mask_tmp = origin_attn_mask[
-                :, :, q_begin:q_end, kv_begin:kv_end
-            ].contiguous()
+            mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
             context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
             context_list_tmp.append(context_layer_tmp)
 
-        if not flag:
+        if not divisible_flag:
             # circumstances that cannot be divisible
             context_layer_tmp = compute_attn(q_last, k, v, mask_last)
             context_list_tmp.append(context_layer_tmp)
         context_layer = torch.cat(context_list_tmp, 2)
         new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
-        context_layer = npu_confusion_transpose(
-            context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True
-        )
+        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
         # =========================
         # Context layer. [b, sq, hp]
         # =========================
         return context_layer
-
-
-if __name__ == "__main__":
-    q, k, v = [
-        torch.randn((2, 12, 1024, 64), requires_grad=True).npu().half()
-        for _ in range(3)
-    ]
-    mask = torch.ones(2, 12, 1024, 1024).npu().bool()
-    out = npu_triangle_attention(q, k, v, origin_attn_mask=mask)
-    loss = out.sum()
-    loss.backward()

From 34b83f1b42e46710d4736637625b5cdf38ab3a36 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Thu, 30 Nov 2023 13:27:13 +0800
Subject: [PATCH 7/7] fix

---
 colossalai/kernel/npu/mha/mha.py       | 2 +-
 colossalai/kernel/npu/mha/sdpa_attn.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/colossalai/kernel/npu/mha/mha.py b/colossalai/kernel/npu/mha/mha.py
index 01bd1c811ddb..ac982384e518 100644
--- a/colossalai/kernel/npu/mha/mha.py
+++ b/colossalai/kernel/npu/mha/mha.py
@@ -12,7 +12,7 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale:
         super().__init__()
 
         try:
-            pass
+            import torch_npu  # noqa
         except ImportError:
             raise Exception("torch_npu is not installed.")
 
diff --git a/colossalai/kernel/npu/mha/sdpa_attn.py b/colossalai/kernel/npu/mha/sdpa_attn.py
index 349e4b4da4d8..2af1dbae2e67 100644
--- a/colossalai/kernel/npu/mha/sdpa_attn.py
+++ b/colossalai/kernel/npu/mha/sdpa_attn.py
@@ -1,8 +1,6 @@
 import torch
 from einops import rearrange
 
-import torch_npu  # noqa
-
 
 def npu_sdpa_attention(
     q: torch.Tensor,
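
After the final patch, the NPU path mirrors the CUDA ColoAttention path: get_attention_kernel() hands back NPUColoAttention when no CUDA device is available but an NPU is, and NPUColoAttention.forward dispatches to npu_triangle_attention when torch_npu provides npu_scaled_masked_softmax and npu_confusion_transpose, falling back to npu_sdpa_attention otherwise. The snippet below is a minimal usage sketch, not part of the patch series: it assumes torch_npu is installed and an Ascend NPU is visible, fp16 inputs in the (batch, seq_len, num_heads, head_dim) layout described in the forward docstring, and a full attention mask shaped like the one built in the triangle_attn.py test code.

    import torch
    import torch_npu  # noqa: F401  # assumed available; the NPU kernels require it
    from colossalai.kernel.npu import NPUColoAttention

    batch, seq_len, num_heads, head_dim = 2, 1024, 12, 64

    # query/key/value in (batch, seq_len, num_heads, head_dim) layout on the NPU
    q, k, v = [torch.randn(batch, seq_len, num_heads, head_dim).npu().half() for _ in range(3)]
    # full (batch, num_heads, seq_len, seq_len) mask passed through origin_attn_mask,
    # matching the mask used in the triangle_attn.py __main__ example
    mask = torch.ones(batch, num_heads, seq_len, seq_len).npu().bool()

    attn = NPUColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads, dropout=0.0)
    out = attn(q, k, v, attn_mask=None, origin_attn_mask=mask)
    print(out.shape)  # (batch, seq_len, num_heads * head_dim)

With seq_len a multiple of the 512 block size, the triangle kernel splits the query into blocks and only attends to the lower-triangular portion of the mask; for sequences that do not divide evenly, the remainder falls through to the single extra compute_attn call shown in the diff.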