@@ -79,7 +79,7 @@ def forward(
layer_past=layer_past,
use_cache=True,
)

if use_cache:
present = (key_layer, value_layer)
else:
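The lines visible in this hunk follow the usual decoder KV-cache pattern: the previous step's cache comes in as `layer_past`, and when `use_cache` is set the current `(key_layer, value_layer)` pair is handed back as `present`. A minimal sketch of that pattern under those assumptions (the rest of the module is not shown in the diff, and the dim-1 sequence axis is an assumption):

```python
import torch

def cached_kv(key_layer: torch.Tensor,
              value_layer: torch.Tensor,
              layer_past=None,
              use_cache: bool = True):
    # append this step's keys/values to the cache from earlier decoding steps
    # (sequence axis assumed to be dim 1)
    if layer_past is not None:
        past_key, past_value = layer_past
        key_layer = torch.cat([past_key, key_layer], dim=1)
        value_layer = torch.cat([past_value, value_layer], dim=1)

    # return the updated cache so the caller can pass it back as layer_past
    present = (key_layer, value_layer) if use_cache else None
    return key_layer, value_layer, present
```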
2 changes: 0 additions & 2 deletions applications/Chat/requirements.txt
@@ -2,8 +2,6 @@ transformers>=4.20.1
tqdm
datasets
loralib
colossalai>=0.2.4
torch<2.0.0, >=1.12.1
langchain
tokenizers
fastapi
14 changes: 8 additions & 6 deletions colossalai/kernel/triton/ops.py
@@ -2,6 +2,7 @@

import torch
from torch import nn
from torch.nn import functional as F

try:
import triton
@@ -156,7 +157,7 @@ def compute_attention_for_bloom(q: torch.Tensor,
Return:
output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
"""

assert len(q.shape) == len(k.shape), "the dimensions must be matched"
assert len(q.shape) == len(v.shape), "the dimensions must be matched"
assert len(q.shape) == 4, "the length of input q must be 4, which is (batch, seq_len, num_heads, head_dim)"
@@ -181,14 +182,14 @@ def compute_attention_for_bloom(q: torch.Tensor,
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

qkv_gemm_4d_kernel_alibi[grid](
q, k, alibi, score_output,
qkv_gemm_4d_kernel[grid](
q, k,
score_output,
M, N, K,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
scale=scale,
beta=beta,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
@@ -197,6 +198,8 @@ def compute_attention_for_bloom(q: torch.Tensor,
num_stages=4,
)

score_output += beta * alibi

# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = score_output.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
@@ -209,7 +212,6 @@ def compute_attention_for_bloom(q: torch.Tensor,
softmax_leading_size = batches * H * M

if softmax_leading_size <= 350000:

score_output = score_output.to(input_dtype)
softmax_output = torch.empty(
score_output.shape, device=score_output.device, dtype=score_output.dtype)
@@ -241,7 +243,7 @@ def compute_attention_for_bloom(q: torch.Tensor,
softmax_output = F.softmax(score_output, dim=-1, dtype=torch.float32).to(input_dtype)

if drop_out > 0 and drop_out < 1:
softmax_output = F.dropout(softmax_output, drop_out, False, True).to(input_dtype)
softmax_output = F.dropout(softmax_output, drop_out, False, False).to(input_dtype)

if head_mask is not None:
softmax_output = softmax_output * head_mask
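Taken together, the ops.py changes replace the fused `qkv_gemm_4d_kernel_alibi` call with the plain `qkv_gemm_4d_kernel` and apply the ALiBi bias afterwards via `score_output += beta * alibi`. A minimal PyTorch sketch of that flow, with `einsum` standing in for the Triton GEMM and all names assumed from the diff rather than the full file:

```python
import torch
from torch.nn import functional as F

def bloom_attention_probs(q, k, alibi, scale, beta, drop_out=0.0, head_mask=None):
    # q, k: (batch, seq_len, num_heads, head_dim); alibi broadcasts to (batch, num_heads, q_len, kv_len)
    score = torch.einsum("bmhd,bnhd->bhmn", q, k) * scale  # stand-in for qkv_gemm_4d_kernel
    score = score + beta * alibi                           # ALiBi bias added after the GEMM
    # softmax in fp32 for stability, then cast back to the input dtype
    probs = F.softmax(score, dim=-1, dtype=torch.float32).to(q.dtype)
    if 0 < drop_out < 1:
        # mirrors the diff's dropout call: inplace=False instead of inplace=True
        probs = F.dropout(probs, p=drop_out, training=False, inplace=False)
    if head_mask is not None:
        probs = probs * head_mask
    return probs
```

Switching the dropout call from `inplace=True` to `inplace=False` presumably avoids mutating the softmax output in place, at the cost of one extra allocation.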
2 changes: 1 addition & 1 deletion colossalai/kernel/triton/qkv_matmul_kernel.py
@@ -210,8 +210,8 @@ def qkv_gemm_4d_kernel_alibi(
alibi_ptrs = (alibi_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
stride_cn * offs_accumu_n[None, :])
alibi_vals = tl.load(alibi_ptrs, mask=accumulator_mask, other=0.)
accumulator += (alibi_vals * beta.to(c_ptr.dtype.element_ty))

accumulator += (alibi_vals * beta).to(c_ptr.dtype.element_ty)
accumulator = accumulator.to(c_ptr.dtype.element_ty)

tl.store(c_ptrs, accumulator, mask=accumulator_mask)
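The one-line change above moves the cast: the product `alibi_vals * beta` is now computed first and cast to `c_ptr.dtype.element_ty` as a whole, instead of casting `beta` before the multiply. A small stand-alone PyTorch illustration of why the order can matter for a low-precision output type (the concrete values and fp16 dtype are assumptions for demonstration only):

```python
import torch

alibi_vals = torch.full((4,), 1000.0, dtype=torch.float32)  # stand-in for the loaded alibi tile
beta = torch.tensor(1e-4, dtype=torch.float32)              # stand-in for the beta scalar
out_dtype = torch.float16                                    # stand-in for c_ptr.dtype.element_ty

old = alibi_vals * beta.to(out_dtype)      # beta rounded to fp16 before the multiply
new = (alibi_vals * beta).to(out_dtype)    # multiply in fp32, round the product once

# the two results can disagree in the low-order bits because beta itself
# is not exactly representable in fp16
print(old, new)
```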
2 changes: 1 addition & 1 deletion examples/tutorial/fastfold/FastFold
Submodule FastFold updated 1 file
+18 −9 benchmark/perf.py