diff --git a/applications/Chat/coati/models/bloom/triton_attention_forward.py b/applications/Chat/coati/models/bloom/triton_attention_forward.py
index 74c7279e00c2..abb54c7cb503 100644
--- a/applications/Chat/coati/models/bloom/triton_attention_forward.py
+++ b/applications/Chat/coati/models/bloom/triton_attention_forward.py
@@ -79,7 +79,7 @@ def forward(
             layer_past=layer_past,
             use_cache=True,
         )
-        
+
         if use_cache:
             present = (key_layer, value_layer)
         else:
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index af7ff67861eb..444206405900 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,8 +2,6 @@ transformers>=4.20.1
 tqdm
 datasets
 loralib
-colossalai>=0.2.4
-torch<2.0.0, >=1.12.1
 langchain
 tokenizers
 fastapi
diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py
index ea6f5cb52e59..e8d60dc14f63 100644
--- a/colossalai/kernel/triton/ops.py
+++ b/colossalai/kernel/triton/ops.py
@@ -2,6 +2,7 @@
 
 import torch
 from torch import nn
+from torch.nn import functional as F
 
 try:
     import triton
@@ -156,7 +157,7 @@ def compute_attention_for_bloom(q: torch.Tensor,
     Return:
         output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
     """
-    
+
     assert len(q.shape) == len(k.shape), "the dimensions must be matched"
     assert len(q.shape) == len(v.shape), "the dimensions must be matched"
     assert len(q.shape) == 4, "the length of input q must be 4, which is (batch, seq_len, num_heads, head_dim)"
@@ -181,14 +182,14 @@ def compute_attention_for_bloom(q: torch.Tensor,
         triton.cdiv(N, meta["BLOCK_SIZE_N"]),
     )
 
-    qkv_gemm_4d_kernel_alibi[grid](
-        q, k, alibi, score_output,
+    qkv_gemm_4d_kernel[grid](
+        q, k,
+        score_output,
         M, N, K,
         q.stride(0), q.stride(1), q.stride(2), q.stride(3),
         k.stride(0), k.stride(1), k.stride(2), k.stride(3),
         score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
         scale=scale,
-        beta=beta,
         # currently manually setting, later on we can use auto-tune config to match best setting
         BLOCK_SIZE_M=64,
         BLOCK_SIZE_N=32,
@@ -197,6 +198,8 @@ def compute_attention_for_bloom(q: torch.Tensor,
         num_stages=4,
     )
 
+    score_output += beta * alibi
+
     # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
     input_dtype = score_output.dtype
     # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
@@ -209,7 +212,6 @@ def compute_attention_for_bloom(q: torch.Tensor,
     softmax_leading_size = batches * H * M
 
     if softmax_leading_size <= 350000:
-        score_output = score_output.to(input_dtype)
         softmax_output = torch.empty(
             score_output.shape, device=score_output.device, dtype=score_output.dtype)
 
@@ -241,7 +243,7 @@ def compute_attention_for_bloom(q: torch.Tensor,
         softmax_output = F.softmax(score_output, dim=-1, dtype=torch.float32).to(input_dtype)
 
     if drop_out > 0 and drop_out < 1:
-        softmax_output = F.dropout(softmax_output, drop_out, False, True).to(input_dtype)
+        softmax_output = F.dropout(softmax_output, drop_out, False, False).to(input_dtype)
 
     if head_mask is not None:
         softmax_output = softmax_output * head_mask
diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py
index 191397d9bc8d..a7a7322f7276 100644
--- a/colossalai/kernel/triton/qkv_matmul_kernel.py
+++ b/colossalai/kernel/triton/qkv_matmul_kernel.py
@@ -210,8 +210,8 @@ def qkv_gemm_4d_kernel_alibi(
     alibi_ptrs = (alibi_ptr + batch * stride_cb + head * stride_ch
                   + stride_cm * offs_accumu_m[:, None] + stride_cn * offs_accumu_n[None, :])
     alibi_vals = tl.load(alibi_ptrs, mask=accumulator_mask, other=0.)
-    accumulator += (alibi_vals * beta).to(c_ptr.dtype.element_ty)
+    accumulator += (alibi_vals * beta.to(c_ptr.dtype.element_ty))
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
 
     tl.store(c_ptrs, accumulator, mask=accumulator_mask)
diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold
index 05681304651b..eba496808a91 160000
--- a/examples/tutorial/fastfold/FastFold
+++ b/examples/tutorial/fastfold/FastFold
@@ -1 +1 @@
-Subproject commit 05681304651b1b29d7d887db169045ea3dd28fce
+Subproject commit eba496808a91bbcd9661cf832349a418b197015f