1 change: 0 additions & 1 deletion colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -29,7 +29,6 @@ def is_ampere_or_better_gpu():
     HAS_FLASH_ATTN = False

 if HAS_FLASH_ATTN:
-    pass

     from .utils import SeqLenInfo

1 change: 1 addition & 0 deletions colossalai/kernel/cuda_native/mha/mha.py
@@ -44,6 +44,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
3 changes: 3 additions & 0 deletions colossalai/kernel/npu/__init__.py
@@ -0,0 +1,3 @@
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]
3 changes: 3 additions & 0 deletions colossalai/kernel/npu/mha/__init__.py
@@ -0,0 +1,3 @@
from .mha import NPUColoAttention

__all__ = ["NPUColoAttention"]
80 changes: 80 additions & 0 deletions colossalai/kernel/npu/mha/mha.py
@@ -0,0 +1,80 @@
import math
from typing import Optional

import torch

from .sdpa_attn import npu_sdpa_attention
from .triangle_attn import HAS_NPU_TRIANGLE_ATTENTION


class NPUColoAttention(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale: float = None):
        super().__init__()

        try:
            import torch_npu  # noqa
        except ImportError:
            raise Exception("torch_npu is not installed.")

        assert (
            embed_dim % num_heads == 0
        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
        if scale is not None:
            self.scale = scale
        else:
            self.scale = 1 / math.sqrt(embed_dim // num_heads)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        origin_attn_mask: Optional[torch.Tensor] = None,
        attn_mask_type: int = None,
        bias: Optional[torch.Tensor] = None,
    ):
"""
Implement the scaled dot product attention with softmax.

Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
        assert (
            len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4
        ), f"query, key, value should be 4D tensors, but got {query.shape}, {key.shape}, {value.shape}"
        assert (
            query.device.type == "npu" and key.device.type == "npu" and value.device.type == "npu"
        ), f"query, key, value should be on npu device, but got {query.device}, {key.device}, {value.device}"
        assert bias is None, "bias is not supported in npu colo attention"

        causal = attn_mask_type is not None and attn_mask_type.value > 1

        if HAS_NPU_TRIANGLE_ATTENTION:
            from .triangle_attn import npu_triangle_attention

            attn_fn = npu_triangle_attention
        else:
            attn_fn = npu_sdpa_attention

        out = attn_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            origin_attn_mask=origin_attn_mask,
            dropout_p=self.dropout,
            scale=self.scale,
            is_causal=causal,
        )
        return out
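For orientation, a minimal usage sketch of the new module (not part of the diff; it assumes torch_npu is installed, an Ascend device is visible, and the sizes are illustrative):

import torch

from colossalai.kernel.npu import NPUColoAttention

batch, seq_len, num_heads, head_dim = 2, 1024, 32, 128  # hypothetical sizes
attn = NPUColoAttention(embed_dim=num_heads * head_dim, num_heads=num_heads)

# Inputs must be 4D (batch, seqlen, nheads, headdim) tensors on an NPU device,
# or the asserts in forward() will fire.
q = torch.randn(batch, seq_len, num_heads, head_dim, device="npu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# origin_attn_mask is the model's own (unprocessed) attention mask; the llama
# change below forwards the HF attention_mask here. With None, the SDPA
# fallback runs causally; the triangle path expects a real mask.
out = attn(q, k, v, origin_attn_mask=None)
print(out.shape)  # torch.Size([2, 1024, 4096])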
41 changes: 41 additions & 0 deletions colossalai/kernel/npu/mha/sdpa_attn.py
@@ -0,0 +1,41 @@
import torch
from einops import rearrange


def npu_sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor = None,
    origin_attn_mask: torch.Tensor = None,
    scale: float = 1.0,
    dropout_p: float = 0.0,
    is_causal: bool = True,
):
"""
The scaled dot product attention.

Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
    q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
    output = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=origin_attn_mask,
        dropout_p=dropout_p,
        is_causal=origin_attn_mask is None,
        scale=scale,
    )
    output = rearrange(output, "b h s d -> b s (h d)")
    return output
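This path is essentially a layout shim around PyTorch's fused kernel: inputs arrive as (b, s, h, d), are transposed to (b, h, s, d) for scaled_dot_product_attention, and the heads are folded back into the hidden dimension on the way out. The same data flow as a CPU-runnable illustration (not part of the diff):

import torch
from einops import rearrange

b, s, h, d = 2, 16, 4, 8
q, k, v = (torch.randn(b, s, h, d) for _ in range(3))

# (b, s, h, d) -> (b, h, s, d): the layout expected by scaled_dot_product_attention
q_, k_, v_ = (rearrange(x, "b s h d -> b h s d") for x in (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_, is_causal=True)

# (b, h, s, d) -> (b, s, h * d): heads folded back into the hidden size
out = rearrange(out, "b h s d -> b s (h d)")
assert out.shape == (b, s, h * d)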
115 changes: 115 additions & 0 deletions colossalai/kernel/npu/mha/triangle_attn.py
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import torch
from einops import rearrange

HAS_NPU_TRIANGLE_ATTENTION = False
try:
    from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax

    HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
    logging.warning("Import torch_npu Error.")


if HAS_NPU_TRIANGLE_ATTENTION:

    def npu_triangle_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor = None,
        origin_attn_mask: torch.Tensor = None,
        scale: float = 1.0,
        dropout_p: float = 0.0,
        is_causal: bool = True,
        block_size=512,
    ):
"""
The triangle attention reduces the attention calculation of the mask
part by dividing the q, k, and v matrices into blocks

Arguments:
block_size: The size of the inverted triangle block, the default is 512,
the smaller the block_size, the more calculations will be reduced,
but the number of small operators will be increased
masked_softmax_func: mask function to be applied.
dropout_func: dropout function to be applied.
"""

        def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
            # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
            cur_sim = torch.matmul(q_layer, k_layer)
            attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
            # attention dropout
            if dropout_p > 0:
                attention_probs = torch.nn.functional.dropout(
                    attention_probs, p=dropout_p, training=attention_probs.requires_grad
                )
            # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
            context_layer_tmp = torch.matmul(attention_probs, v_layer)
            return context_layer_tmp

        q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
        origin_attn_mask = origin_attn_mask.to(torch.bool)
        # input shape: [b, hn, sq, hd]
        bsz, head_num, sequence_len, head_dim = k.shape
        sparse_groups = sequence_len // block_size
        # determine whether sequence_len is evenly divisible by block_size
        divisible_flag = sequence_len == block_size * sparse_groups
        k = k.transpose(2, 3).contiguous()
        if divisible_flag:
            q_tmp_layers = torch.chunk(q, sparse_groups, 2)
            k_tmp_layers = torch.chunk(k, sparse_groups, 3)
            v_tmp_layers = torch.chunk(v, sparse_groups, 2)
        else:
            seq_tmp = block_size * sparse_groups
            q_last = q[:, :, seq_tmp:, :].contiguous()
            mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
            q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
            k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
            v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
        context_list_tmp, k_tmp, v_tmp = [], (), ()
        for i in range(sparse_groups):
            # compute the slice bounds of q, k and v for this block
            q_begin, q_end = i * block_size, (i + 1) * block_size
            kv_begin, kv_end = 0, (i + 1) * block_size
            q_tmp = q_tmp_layers[i]
            # slice k and v
            if i == 0:
                k_tmp = k_tmp_layers[i].contiguous()
                v_tmp = v_tmp_layers[i].contiguous()
            else:
                k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
                v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()

            mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
            context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
            context_list_tmp.append(context_layer_tmp)

        if not divisible_flag:
            # handle the tail that block_size does not evenly divide
            context_layer_tmp = compute_attn(q_last, k, v, mask_last)
            context_list_tmp.append(context_layer_tmp)
        context_layer = torch.cat(context_list_tmp, 2)
        new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
        context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
        # =========================
        # Context layer. [b, sq, hp]
        # =========================
        return context_layer
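The idea behind the kernel above: with a causal mask, query block i only needs the keys and values up to and including block i, so the fully masked upper-triangular blocks of QK^T are never computed; only the diagonal block still needs masking. A plain-PyTorch sketch of that blocking scheme (CPU-runnable, illustrative only, not the NPU kernel; the helper name is made up):

import torch


def blockwise_causal_attention(q, k, v, block_size):
    # q, k, v: (b, h, s, d); s is assumed divisible by block_size for brevity
    b, h, s, d = q.shape
    groups = s // block_size
    outputs = []
    for i in range(groups):
        q_blk = q[:, :, i * block_size : (i + 1) * block_size]  # current query block
        k_pre = k[:, :, : (i + 1) * block_size]  # key/value prefix seen so far
        v_pre = v[:, :, : (i + 1) * block_size]
        scores = q_blk @ k_pre.transpose(-1, -2) / d**0.5
        # only the diagonal block needs a causal mask; earlier blocks are fully visible
        mask = torch.ones(block_size, (i + 1) * block_size, dtype=torch.bool)
        mask[:, i * block_size :] = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
        outputs.append(scores.softmax(-1) @ v_pre)
    return torch.cat(outputs, dim=2)


# sanity check against full causal attention
b, h, s, d = 1, 2, 64, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(ref, blockwise_causal_attention(q, k, v, block_size=16), atol=1e-5)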
18 changes: 18 additions & 0 deletions colossalai/shardformer/layer/utils.py
@@ -280,3 +280,21 @@ def create_randomizer_with_offset(
     Randomizer.increment_index()

     return Randomizer(seed=base_seed)
+
+
+def get_attention_kernel():
+    """
+    Get the attention kernel based on the device type.
+    """
+    from colossalai.kernel.cuda_native import AttnMaskType
+
+    if torch.cuda.is_available():
+        from colossalai.kernel.cuda_native import ColoAttention as AttentionKernel
+    else:
+        try:
+            assert torch.npu.is_available()
+            from colossalai.kernel.npu import NPUColoAttention as AttentionKernel
+        except Exception:
+            raise Exception("No available device for attention kernel!")

+    return AttnMaskType, AttentionKernel
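The helper keeps model code device-agnostic: callers get the mask-type enum together with whichever attention class matches the current device. A sketch of the intended call pattern, mirroring the llama change below (sizes are hypothetical):

from colossalai.shardformer.layer.utils import get_attention_kernel

AttnMaskType, AttentionKernel = get_attention_kernel()
# CUDA -> colossalai.kernel.cuda_native.ColoAttention
# NPU  -> colossalai.kernel.npu.NPUColoAttention
attention = AttentionKernel(embed_dim=4096, num_heads=32)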
5 changes: 3 additions & 2 deletions colossalai/shardformer/modeling/llama.py
@@ -12,6 +12,7 @@
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ def llama_for_sequence_classification_forward(
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()

     llama_version = 2
     try:
@@ -468,7 +469,7 @@ def forward(

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
         )

         attn_output = self.o_proj(attn_output)