diff --git a/README.md b/README.md
index 21670e1e59fb..44e4f97f1f4e 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
## Latest News
+* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
@@ -33,7 +34,6 @@
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
@@ -463,6 +463,7 @@ To cite this project, you can use the following BibTeX citation.
}
```
-Colossal-AI has been accepted as official tutorial by top conferences [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+Colossal-AI has been accepted as official tutorials by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 1d5a6ce495bd..4910717b5723 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,5 +1,8 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
+__all__ = [
+ 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+]
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
deleted file mode 100644
index 3db7374509a0..000000000000
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ /dev/null
@@ -1,635 +0,0 @@
-"""
-A general attention module using the flash attention kernels from xformers:
-https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
-"""
-
-import math
-import os
-import subprocess
-
-import torch
-
-try:
- from xformers.ops.fmha import memory_efficient_attention
- HAS_MEM_EFF_ATTN = True
-except ImportError:
- HAS_MEM_EFF_ATTN = False
- print('please install xformers from https://github.com/facebookresearch/xformers')
-
-if HAS_MEM_EFF_ATTN:
-
- from typing import Optional
-
- from einops import rearrange
- from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
- from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
-
- from .scaled_softmax import AttnMaskType
-
- allow_alibi = True
- for op in MemoryEfficientAttentionCutlassOp:
- allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
-
- class Unpad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
- ctx.save_for_backward(indices)
- # [b, s, ...]
- assert tensor.ndim >= 3
- ctx.bsz = tensor.shape[0]
- out = rearrange(tensor, 'b s ... -> (b s) ...')
- ctx.shape = out.shape
- # [1, ntokens, ...]
- return out[indices].unsqueeze(0)
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
- grad[indices] = grad_output.squeeze(0)
- grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
- # [b, s, ...]
- return grad, None
-
- class Repad(torch.autograd.Function):
- """
- Adapted from
- https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
- """
-
- @staticmethod
- def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
- ctx.save_for_backward(indices)
- # [ntokens, ...]
- tensor = tensor.squeeze(0)
- out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
- # [b*s, ...]
- out[indices] = tensor
- # [b, s, ...]
- out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
- return out
-
- @staticmethod
- def backward(ctx, grad_output):
- indices, = ctx.saved_tensors
- # [b*s, ...]
- grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
- grad = grad_output[indices]
- # [1, ntokens, ...]
- return grad.unsqueeze(0), None, None, None
-
- class ColoAttention(torch.nn.Module):
-
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
- super().__init__()
- assert embed_dim % num_heads == 0, \
- f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
- self.scale = 1 / math.sqrt(embed_dim // num_heads)
- self.dropout = dropout
-
- @staticmethod
- def get_seq_info_from_mask(attn_mask: torch.Tensor):
- indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
- seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
- return indices, seqlens
-
- @staticmethod
- def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
- return Unpad.apply(tensor, indices)
-
- @staticmethod
- def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
- return Repad.apply(tensor, indices, batch_size, seq_len)
-
- def forward(self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- attn_mask_type: Optional[AttnMaskType] = None,
- bias: Optional[torch.Tensor] = None):
- batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
- attn_bias = None
- if attn_mask_type == AttnMaskType.padding: # bert style
- assert attn_mask is not None, \
- f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
- assert attn_mask.dim() == 2, \
- "attention mask is supposed to have shape (batch_size, seq_len), " + \
- f"but got {attn_mask.dim()} dimensions."
- if tgt_len == src_len:
- q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
- kv_seqlen = None
- if batch_size > 1:
- query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
- else:
- q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
- q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device)
- kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
- if batch_size > 1:
- query = rearrange(query, "b s ... -> c (b s) ...", c=1)
- key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
- attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
- elif attn_mask_type == AttnMaskType.causal: # gpt style
- attn_bias = LowerTriangularMask()
-
- if bias is not None: # alibi / relative position embedding
- assert allow_alibi, "flash attention with bias is not supported in this system."
- assert attn_mask_type == AttnMaskType.causal, \
- "attention with bias is only supported for causal attention so far."
- attn_bias = attn_bias.add_bias(bias)
-
- out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
-
- if attn_mask_type == AttnMaskType.padding and batch_size > 1:
- out = self.repad(out, q_indices, batch_size, tgt_len)
-
- out = rearrange(out, 'b s h d -> b s (h d)')
- return out
-
-
-##########################################################################
-# the flash attention functions below that are copied
-# from the OpenAI/triton repository will be deprecated
-# You can find the repository in Triton https://github.com/openai/triton
-# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
-# Reference:
-# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
-# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
-
-
-def triton_cuda_check():
- cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
- cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
- cuda_version = cuda_version.split('release ')[1]
- cuda_version = cuda_version.split(',')[0]
- cuda_version = cuda_version.split('.')
- if len(cuda_version) == 2 and \
- (int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \
- int(cuda_version[0]) > 11:
- return True
- return False
-
-
-try:
- import triton
- import triton.language as tl
- if triton_cuda_check():
- HAS_TRITON = True
- else:
- print("triton requires cuda >= 11.4")
- HAS_TRITON = False
-except ImportError:
- print('please install triton from https://github.com/openai/triton')
- HAS_TRITON = False
-try:
- from flash_attn.flash_attention import FlashAttention
- from flash_attn.flash_attn_interface import (
- flash_attn_unpadded_func,
- flash_attn_unpadded_kvpacked_func,
- flash_attn_unpadded_qkvpacked_func,
- )
- HAS_FLASH_ATTN = True
-except ImportError:
- HAS_FLASH_ATTN = False
- print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
-if HAS_TRITON:
- # the following functions are adapted from the OpenAI Triton tutorial
- # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
- @triton.jit
- def _fwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- TMP,
- L,
- M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
- Out,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- stride_oz,
- stride_oh,
- stride_om,
- stride_on,
- Z,
- H,
- N_CTX,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- start_m = tl.program_id(0)
- off_hz = tl.program_id(1)
- # initialize offsets
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_DMODEL)
- off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
- off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
- off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
- # Initialize pointers to Q, K, V
- q_ptrs = Q + off_q
- k_ptrs = K + off_k
- v_ptrs = V + off_v
- # initialize pointer to m and l
- t_ptrs = TMP + off_hz * N_CTX + offs_m
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # load q: it will stay in SRAM throughout
- q = tl.load(q_ptrs)
- # loop over k, v and update accumulator
- for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- # -- compute qk ----
- k = tl.load(k_ptrs + start_n * stride_kn)
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
- qk *= sm_scale
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
- # -- compute m_ij, p, l_ij
- m_ij = tl.max(qk, 1)
- p = tl.exp(qk - m_ij[:, None])
- l_ij = tl.sum(p, 1)
- # -- update m_i and l_i
- m_i_new = tl.maximum(m_i, m_ij)
- alpha = tl.exp(m_i - m_i_new)
- beta = tl.exp(m_ij - m_i_new)
- l_i_new = alpha * l_i + beta * l_ij
- # -- update output accumulator --
- # scale p
- p_scale = beta / l_i_new
- p = p * p_scale[:, None]
- # scale acc
- acc_scale = l_i / l_i_new * alpha
- tl.store(t_ptrs, acc_scale)
- acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
- acc = acc * acc_scale[:, None]
- # update acc
- v = tl.load(v_ptrs + start_n * stride_vk)
- p = p.to(tl.float16)
- acc += tl.dot(p, v)
- # update m_i and l_i
- l_i = l_i_new
- m_i = m_i_new
- # rematerialize offsets to save registers
- start_m = tl.program_id(0)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- # write back l and m
- l_ptrs = L + off_hz * N_CTX + offs_m
- m_ptrs = M + off_hz * N_CTX + offs_m
- tl.store(l_ptrs, l_i)
- tl.store(m_ptrs, m_i)
- # initialize pointers to output
- offs_n = tl.arange(0, BLOCK_DMODEL)
- off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
- out_ptrs = Out + off_o
- tl.store(out_ptrs, acc)
-
- @triton.jit
- def _bwd_preprocess(
- Out,
- DO,
- L,
- NewDO,
- Delta,
- BLOCK_M: tl.constexpr,
- D_HEAD: tl.constexpr,
- ):
- off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
- off_n = tl.arange(0, D_HEAD)
- # load
- o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
- denom = tl.load(L + off_m).to(tl.float32)
- # compute
- do = do / denom[:, None]
- delta = tl.sum(o * do, axis=1)
- # write-back
- tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
- tl.store(Delta + off_m, delta)
-
- @triton.jit
- def _bwd_kernel(
- Q,
- K,
- V,
- sm_scale,
- Out,
- DO,
- DQ,
- DK,
- DV,
- L,
- M,
- D,
- stride_qz,
- stride_qh,
- stride_qm,
- stride_qk,
- stride_kz,
- stride_kh,
- stride_kn,
- stride_kk,
- stride_vz,
- stride_vh,
- stride_vk,
- stride_vn,
- Z,
- H,
- N_CTX,
- num_block,
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- off_hz = tl.program_id(0)
- off_z = off_hz // H
- off_h = off_hz % H
- # offset pointers for batch/head
- Q += off_z * stride_qz + off_h * stride_qh
- K += off_z * stride_qz + off_h * stride_qh
- V += off_z * stride_qz + off_h * stride_qh
- DO += off_z * stride_qz + off_h * stride_qh
- DQ += off_z * stride_qz + off_h * stride_qh
- DK += off_z * stride_qz + off_h * stride_qh
- DV += off_z * stride_qz + off_h * stride_qh
- for start_n in range(0, num_block):
- lo = start_n * BLOCK_M
- # initialize row/col offsets
- offs_qm = lo + tl.arange(0, BLOCK_M)
- offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
- offs_m = tl.arange(0, BLOCK_N)
- offs_k = tl.arange(0, BLOCK_DMODEL)
- # initialize pointers to value-like data
- q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- # pointer to row-wise quantities in value-like data
- D_ptrs = D + off_hz * N_CTX
- m_ptrs = M + off_hz * N_CTX
- # initialize dv amd dk
- dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
- # k and v stay in SRAM throughout
- k = tl.load(k_ptrs)
- v = tl.load(v_ptrs)
- # loop over rows
- for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
- offs_m_curr = start_m + offs_m
- # load q, k, v, do on-chip
- q = tl.load(q_ptrs)
- # recompute p = softmax(qk, dim=-1).T
- # NOTE: `do` is pre-divided by `l`; no normalization here
- qk = tl.dot(q, k, trans_b=True)
- qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
- m = tl.load(m_ptrs + offs_m_curr)
- p = tl.exp(qk * sm_scale - m[:, None])
- # compute dv
- do = tl.load(do_ptrs)
- dv += tl.dot(p.to(tl.float16), do, trans_a=True)
- # compute dp = dot(v, do)
- Di = tl.load(D_ptrs + offs_m_curr)
- dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
- dp += tl.dot(do, v, trans_b=True)
- # compute ds = p * (dp - delta[:, None])
- ds = p * dp * sm_scale
- # compute dk = dot(ds.T, q)
- dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
- # # compute dq
- dq = tl.load(dq_ptrs, eviction_policy="evict_last")
- dq += tl.dot(ds.to(tl.float16), k)
- tl.store(dq_ptrs, dq, eviction_policy="evict_last")
- # # increment pointers
- dq_ptrs += BLOCK_M * stride_qm
- q_ptrs += BLOCK_M * stride_qm
- do_ptrs += BLOCK_M * stride_qm
- # write-back
- dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
- dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
- tl.store(dv_ptrs, dv)
- tl.store(dk_ptrs, dk)
-
- class _TritonFlashAttention(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, q, k, v, sm_scale):
- BLOCK = 128
- # shape constraints
- Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
- assert Lq == Lk and Lk == Lv
- assert Lk in {16, 32, 64, 128}
- o = torch.empty_like(q)
- grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
- tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
- num_warps = 4 if Lk <= 64 else 8
-
- _fwd_kernel[grid](
- q,
- k,
- v,
- sm_scale,
- tmp,
- L,
- m,
- o,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- o.stride(0),
- o.stride(1),
- o.stride(2),
- o.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- BLOCK_M=BLOCK,
- BLOCK_N=BLOCK,
- BLOCK_DMODEL=Lk,
- num_warps=num_warps,
- num_stages=1,
- )
- ctx.save_for_backward(q, k, v, o, L, m)
- ctx.BLOCK = BLOCK
- ctx.grid = grid
- ctx.sm_scale = sm_scale
- ctx.BLOCK_DMODEL = Lk
- return o
-
- @staticmethod
- def backward(ctx, do):
- q, k, v, o, l, m = ctx.saved_tensors
- do = do.contiguous()
- dq = torch.zeros_like(q, dtype=torch.float32)
- dk = torch.empty_like(k)
- dv = torch.empty_like(v)
- do_scaled = torch.empty_like(do)
- delta = torch.empty_like(l)
- _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
- o,
- do,
- l,
- do_scaled,
- delta,
- BLOCK_M=ctx.BLOCK,
- D_HEAD=ctx.BLOCK_DMODEL,
- )
-
- # NOTE: kernel currently buggy for other values of `num_warps`
- num_warps = 8
- _bwd_kernel[(ctx.grid[1],)](
- q,
- k,
- v,
- ctx.sm_scale,
- o,
- do_scaled,
- dq,
- dk,
- dv,
- l,
- m,
- delta,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- q.stride(3),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- k.stride(3),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- v.stride(3),
- q.shape[0],
- q.shape[1],
- q.shape[2],
- ctx.grid[0],
- BLOCK_M=ctx.BLOCK,
- BLOCK_N=ctx.BLOCK,
- BLOCK_DMODEL=ctx.BLOCK_DMODEL,
- num_warps=num_warps,
- num_stages=1,
- )
- return dq, dk, dv, None
-
- def triton_flash_attention(q, k, v, sm_scale):
- """
- Arguments:
- q: (batch, nheads, seq, headdim)
- k: (batch, nheads, seq, headdim)
- v: (batch, nheads, seq, headdim)
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (batch, nheads, seq, headdim)
- """
- if HAS_TRITON:
- return _TritonFlashAttention.apply(q, k, v, sm_scale)
- else:
- raise RuntimeError("Triton kernel requires CUDA 11.4+!")
-
-
-if HAS_FLASH_ATTN:
-
- def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
- """
- Arguments:
- qkv: (batch * seqlen, 3, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- max_s = seq_len
- cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device)
- out = flash_attn_unpadded_qkvpacked_func(qkv,
- cu_seqlens,
- max_s,
- dropout_p,
- softmax_scale=sm_scale,
- causal=causal)
- return out
-
- def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- kv: (batch * kv_seqlen, 2, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- dropout_p: float.
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=kv.device)
- out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, dropout_p,
- sm_scale, causal)
- return out
-
- def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dropout_p=0., causal=False):
- """
- Arguments:
- q: (batch * q_seqlen, nheads, headdim)
- k: (batch * kv_seqlen, nheads, headdim)
- v: (batch * kv_seqlen, nheads, headdim)
- batch_size: int.
- seq_len: int.
- dropout_p: float. Dropout probability.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Default to 1 / sqrt(headdim).
- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_kv = torch.arange(0, (batch_size + 1) * kv_seqlen,
- step=kv_seqlen,
- dtype=torch.int32,
- device=k.device)
- return flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, q_seqlen, kv_seqlen, dropout_p, sm_scale,
- causal)
-
-
-##########################################################################
diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py
new file mode 100644
index 000000000000..21fddd512957
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/__init__.py
@@ -0,0 +1,3 @@
+from .mha import ColoAttention
+
+__all__ = ['ColoAttention']
diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
new file mode 100644
index 000000000000..6a8d74f70c1d
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py
@@ -0,0 +1,68 @@
+import warnings
+from typing import Optional
+
+import torch
+
+
+def is_ampere_or_better_gpu():
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ properties = torch.cuda.get_device_properties(device)
+ if properties.major >= 8: # Ampere GPUs or newer
+ return True
+ return False
+
+
+# "Check Ampere GPUs or newer"
+HAS_FLASH_ATTN = False
+if is_ampere_or_better_gpu():
+ HAS_FLASH_ATTN = True
+else:
+ warnings.warn('FlashAttention only supports Ampere GPUs or newer.')
+ HAS_FLASH_ATTN = False
+try:
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+ HAS_FLASH_ATTN = True
+except ImportError:
+ warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention')
+ HAS_FLASH_ATTN = False
+
+if HAS_FLASH_ATTN:
+ from einops import rearrange
+
+ from .utils import SeqLenInfo
+
+ def flash_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
+ """
+ Arguments:
+ q: (batch, q_seqlen, nheads, headdim)
+ k: (batch, kv_seqlen, nheads, headdim)
+ v: (batch, kv_seqlen, nheads, headdim)
+ batch_size: int.
+ seq_len: int.
+ dropout_p: float. Dropout probability.
+ sm_scale: float. The scaling of QK^T before applying softmax.
+ Default to 1 / sqrt(headdim).
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+ Return:
+ attn_out: (batch, q_seqlen, nheads, headdim).
+ """
+ if padded:
+            if seq_len_info_kv is None:
+ seq_len_info_kv = seq_len_info_q
+
+ attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens,
+ seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale,
+ causal)
+ else:
+ attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
+ return attn_out
diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
new file mode 100644
index 000000000000..e83beb8b2429
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
@@ -0,0 +1,70 @@
+import warnings
+
+HAS_MEM_EFF_ATTN = False
+try:
+ from xformers.ops.fmha import memory_efficient_attention
+ HAS_MEM_EFF_ATTN = True
+except ImportError:
+ warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
+ HAS_MEM_EFF_ATTN = False
+
+if HAS_MEM_EFF_ATTN:
+ """
+ A general attention module using the flash attention kernels from xformers:
+ https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
+ """
+ from typing import Optional
+
+ import torch
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+ )
+
+ from .utils import SeqLenInfo
+
+ allow_alibi = True
+ for op in MemoryEfficientAttentionCutlassOp:
+ allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
+
+ def mem_eff_attention(q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ seq_len_info_q: SeqLenInfo,
+ seq_len_info_kv: SeqLenInfo,
+ bias: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.,
+ scale: float = None,
+ causal: bool = False,
+ padded: bool = False):
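+        """
+        A thin wrapper around xformers' memory_efficient_attention: it builds the appropriate
+        attention bias (block-diagonal for padded sequences, lower-triangular for causal,
+        plus an optional additive bias such as ALiBi) before calling the kernel.
+        """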
+
+ attn_bias = None
+ if padded: # bert style
+ if not causal:
+ attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ else:
+ attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
+ elif causal: # gpt style
+ attn_bias = LowerTriangularMask()
+
+ if bias is not None: # alibi / relative position embedding
+ assert allow_alibi, "flash attention with bias is not supported in this system."
+ assert causal, \
+ "attention with bias is only supported for causal attention so far."
+ attn_bias = attn_bias.add_bias(bias)
+
+ if padded:
+ q = q.unsqueeze(0)
+ k = k.unsqueeze(0)
+ v = v.unsqueeze(0)
+
+ out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
+
+ # shape: (b*s, n, d)
+ if padded:
+ out = out.squeeze(0)
+
+ return out
diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py
new file mode 100644
index 000000000000..8f449a138c51
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/mha.py
@@ -0,0 +1,107 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from ..scaled_softmax import AttnMaskType
+from .flash_attn_2 import HAS_FLASH_ATTN
+from .mem_eff_attn import HAS_MEM_EFF_ATTN
+from .utils import Repad, SeqLenInfo, Unpad
+
+if HAS_FLASH_ATTN:
+ from .flash_attn_2 import flash_attention
+if HAS_MEM_EFF_ATTN:
+ from .mem_eff_attn import mem_eff_attention
+
+
+class ColoAttention(torch.nn.Module):
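+    """
+    A unified attention module that dispatches to FlashAttention-2 when it is available and
+    the inputs are fp16/bf16 with no additive bias, and otherwise falls back to xformers'
+    memory-efficient attention. Inputs q/k/v are expected in (batch, seq_len, num_heads,
+    head_dim) layout; the output is returned as (batch, seq_len, num_heads * head_dim).
+
+    Minimal usage sketch (causal, GPT-style attention):
+        attn = ColoAttention(embed_dim=64, num_heads=4)
+        out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+    """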
+
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
+ super().__init__()
+ assert embed_dim % num_heads == 0, \
+ f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
+ if scale is not None:
+ self.scale = scale
+ else:
+ self.scale = 1 / math.sqrt(embed_dim // num_heads)
+ self.dropout = dropout
+
+ if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
+            raise Exception("flash attention is not available: please install flash_attn or xformers.")
+
+ @staticmethod
+ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ return Unpad.apply(tensor, indices)
+
+ @staticmethod
+ def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+ return Repad.apply(tensor, indices, batch_size, seq_len)
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_mask_type: Optional[AttnMaskType] = None,
+ bias: Optional[torch.Tensor] = None):
+
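+        # backend selection: FlashAttention-2 only handles fp16/bf16 inputs without an
+        # additive bias; otherwise fall back to xformers' memory-efficient attention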
+ attn = None
+        if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias is None:
+ attn = flash_attention
+ else:
+ attn = mem_eff_attention
+
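+        # AttnMaskType values: padding = 1, causal = 2, paddedcausal = 3, so the low bit
+        # marks a padding mask and any value above 1 implies causal masking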
+ padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
+ causal = attn_mask_type is not None and attn_mask_type.value > 1
+
+ batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
+ # unpad
+ seq_len_info_q = None
+ seq_len_info_kv = None
+ if padded:
+ # bert style, unpad process
+ assert attn_mask is not None, \
+ f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
+ assert attn_mask.dim() == 2, \
+ "attention mask is supposed to have shape (batch_size, seq_len), " + \
+ f"but got {attn_mask.dim()} dimensions."
+
+ # bert style
+ if tgt_len == src_len:
+ seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+ query, key, value = self.unpad(torch.stack([query, key, value], dim=2),
+ seq_len_info_q.indices).unbind(dim=1)
+ else:
+ query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
+ seq_len_info_kv = seq_len_info_q
+ else:
+ seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
+ seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
+ if batch_size > 1:
+ query = rearrange(query, "b s ... -> c (b s) ...", c=1)
+ key, value = self.unpad(torch.stack([query, key, value], dim=2),
+ seq_len_info_kv.indices).unbind(dim=1)
+ else:
+ query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
+
+ out = attn(query,
+ key,
+ value,
+ seq_len_info_q,
+ seq_len_info_kv,
+ dropout_p=self.dropout,
+ scale=self.scale,
+ causal=causal,
+ padded=padded)
+
+ # repad
+ if padded:
+ if batch_size > 1:
+ out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
+ out = rearrange(out, '(b s) h d -> b s h d', b=batch_size)
+
+ out = rearrange(out, 'b s h d -> b s (h d)')
+ return out
diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py
new file mode 100644
index 000000000000..e3e431fa7e99
--- /dev/null
+++ b/colossalai/kernel/cuda_native/mha/utils.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import Iterable, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+from colossalai.utils.cuda import get_current_device
+
+
+class Unpad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+ ctx.save_for_backward(indices)
+ # [b, s, ...]
+ assert tensor.ndim >= 3
+ ctx.bsz = tensor.shape[0]
+ out = rearrange(tensor, 'b s ... -> (b s) ...')
+ ctx.shape = out.shape
+ # [ntokens, ...]
+ return out[indices]
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [ntokens, ...]
+ grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+ grad[indices] = grad_output
+ grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
+ # [b, s, ...]
+ return grad, None
+
+
+class Repad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+ ctx.save_for_backward(indices)
+        # tensor: [ntokens, ...]
+ out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+ # [b*s, ...]
+ out[indices] = tensor
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [b*s, ...]
+ grad = grad_output[indices]
+ # [ntokens, ...]
+ return grad, None, None, None
+
+
+@dataclass
+class SeqLenInfo:
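+    """
+    Sequence-length metadata for variable-length (unpadded) attention: seqlens holds the
+    per-sample sequence lengths, indices the positions of the non-padded tokens in the
+    flattened (batch * seq_len) layout, max_seqlen the longest sequence, and cu_seqlens
+    the cumulative sequence lengths (prefixed with 0) expected by varlen attention kernels.
+    """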
+ seqlens: Iterable[int] = None
+ indices: torch.Tensor = None
+ max_seqlen: int = None
+ cu_seqlens: torch.Tensor = None
+
+ @staticmethod
+ def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+ if attn_mask is not None:
+ indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
+ seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
+ else:
+ batch_size, tgt_len = size[0], size[1]
+ indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
+            seqlens = torch.tensor([tgt_len] * batch_size, dtype=torch.int32, device=device)
+ max_seqlen = max(seqlens)
+ cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
+ return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 24e458bb3ea5..41cd4b20faa1 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
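+    # applies both a padding mask and causal masking (padding | causal)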
+ paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
- if self.attn_mask_type == AttnMaskType.causal:
+ if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index 09395d08b93e..d839753d6c44 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -3,7 +3,8 @@
## Introduction
Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI),
-which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+which has been accepted as official tutorials by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index e229c65d890c..945ca4080413 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
@@ -32,8 +33,6 @@
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
-
## 目录
@@ -444,6 +443,7 @@ Colossal-AI项目受一些相关的项目启发而成立,一些项目是我们
}
```
-Colossal-AI 已被 [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/)等顶级会议录取为官方教程。
+Colossal-AI 已被 [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) 等顶级会议录取为官方教程。
(返回顶端)
diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md
index 201e3bc2b643..7bc4eb47bcd7 100644
--- a/docs/source/en/features/gradient_accumulation_with_booster.md
+++ b/docs/source/en/features/gradient_accumulation_with_booster.md
@@ -103,10 +103,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
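+            # scale the loss so that the gradients accumulated over GRADIENT_ACCUMULATION
+            # micro-batches match the gradient of the averaged loss over the effective batch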
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
index a8422060f0ea..d121b161b9ff 100644
--- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
+++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md
@@ -106,10 +106,12 @@ for idx, (img, label) in enumerate(train_dataloader):
with sync_context:
output = model(img)
train_loss = criterion(output, label)
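+            # scale the loss so that the gradients accumulated over GRADIENT_ACCUMULATION
+            # micro-batches match the gradient of the averaged loss over the effective batch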
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
else:
output = model(img)
train_loss = criterion(output, label)
+ train_loss = train_loss / GRADIENT_ACCUMULATION
booster.backward(train_loss, optimizer)
optimizer.step()
optimizer.zero_grad()
diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md
index 0664d41fd359..7b5668612818 100644
--- a/examples/tutorial/README.md
+++ b/examples/tutorial/README.md
@@ -4,7 +4,8 @@
## Introduction
-Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
diff --git a/pytest.ini b/pytest.ini
index e99fe3f086c6..e8a60c85336b 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,4 +4,4 @@ markers =
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
-addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index b34dc2e223ae..f6be6a624c70 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,3 +10,4 @@ contexttimer
ninja
torch>=1.11
safetensors
+einops
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 7a28b0157384..fbcc452650cf 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -4,11 +4,15 @@
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
-if HAS_MEM_EFF_ATTN:
- from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
+ from colossalai.kernel.cuda_native import ColoAttention
+ from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+
+DTYPE = [torch.float16, torch.bfloat16, torch.float32]
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -22,10 +26,13 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash_attn are not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(1, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_gpt(proj_shape, dtype):
+ # TODO check output value
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -35,7 +42,11 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
- y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+ mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+ mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+ y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]
@@ -43,10 +54,12 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash_attn are not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_bert(proj_shape, dtype):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -67,10 +80,12 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash_attn are not available")
@clear_cache_before_run()
-@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
-def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_attention_no_mask(proj_shape, dtype):
+ (B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD
c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
@@ -87,10 +102,12 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
y.backward(dy)
-@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers and flash_attn are not available")
@clear_cache_before_run()
-@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
-def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
+@parameterize('dtype', DTYPE)
+def test_cross_attention(proj_shape, dtype):
+ (B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD
q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")