Merged
Changes from all commits
Commits
19 commits
b8fa73e
Merge pull request #106 from jamesthesnake/best
jamesthesnake Jul 26, 2023
150e469
Merge pull request #110 from jamesthesnake/ra
jamesthesnake Jul 27, 2023
f0c04d2
Merge pull request #111 from jamesthesnake/l
jamesthesnake Jul 27, 2023
12dc1e7
Merge pull request #113 from jamesthesnake/ra
jamesthesnake Jul 29, 2023
0bdcd22
Merge pull request #114 from jamesthesnake/best
jamesthesnake Jul 29, 2023
d1a33cd
Merge pull request #115 from jamesthesnake/l
jamesthesnake Jul 29, 2023
ce95ae2
Merge pull request #116 from jamesthesnake/main
jamesthesnake Jul 29, 2023
d37fb67
Merge pull request #118 from jamesthesnake/ra
jamesthesnake Aug 3, 2023
1038c2b
Merge pull request #119 from jamesthesnake/ra
jamesthesnake Aug 4, 2023
c299445
Merge pull request #120 from jamesthesnake/co
jamesthesnake Aug 4, 2023
d177105
Merge pull request #121 from jamesthesnake/best
jamesthesnake Aug 4, 2023
78d5362
Merge pull request #122 from hpcaitech/main
jamesthesnake Aug 5, 2023
e981a57
Merge pull request #125 from jamesthesnake/ra
jamesthesnake Aug 5, 2023
3ddd630
Merge pull request #126 from jamesthesnake/l
jamesthesnake Aug 5, 2023
a9e470b
Merge pull request #127 from jamesthesnake/co
jamesthesnake Aug 5, 2023
458ae33
[kernel] updated unittests for coloattention (#4389)
flybird11111 Aug 9, 2023
355eec5
Merge pull request #130 from hpcaitech/main
jamesthesnake Aug 10, 2023
8cce6f5
Merge pull request #131 from jamesthesnake/l
jamesthesnake Aug 10, 2023
b3945e1
Merge pull request #132 from jamesthesnake/co
jamesthesnake Aug 10, 2023
3 changes: 2 additions & 1 deletion requirements/requirements-test.txt
@@ -13,6 +13,7 @@ torchrec==0.2.0
contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
ninja
flash_attn>=2.0
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -10,4 +10,5 @@ contexttimer
ninja
torch>=1.11
safetensors
flash_attn>=2.0
einops
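
Both requirements files now pull the prebuilt flash_attn>=2.0 wheel in place of the commit-pinned source install removed above. A minimal sketch (not part of this PR) of how downstream code might treat the dependency as optional:

# Illustration only: guard the optional flash_attn>=2.0 dependency so callers
# can fall back to other attention kernels when the wheel is not installed.
try:
    import flash_attn  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
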
142 changes: 91 additions & 51 deletions tests/test_utils/test_flash_attention.py
@@ -1,4 +1,4 @@
import random
import math

import pytest
import torch
@@ -13,118 +13,158 @@
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

DTYPE = [torch.float16, torch.bfloat16, torch.float32]
FLASH_DTYPE = [torch.float16, torch.bfloat16]


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    for z in range(Z):
        for h in range(H):
            p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    return ref_out
def attention_ref(q, k, v, attn_mask=None, causal=False):
    """
    Reference attention implementation (the control group), computed with plain PyTorch ops.
    """
    dtype_og = q.dtype
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    # (B, T, H, D) x (B, S, H, D) -> (B, H, T, S) scaled attention scores
    scores = torch.einsum('bthd,bshd->bhts', q * scale, k)

    if attn_mask is not None:
        scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)

    output = torch.einsum('bhts,bshd->bthd', attention, v)
    output = rearrange(output, "b s h d -> b s (h d)")

    # Zero out the output rows at padded (masked-out) positions
    if attn_mask is not None:
        output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)

    return output.to(dtype=dtype_og)
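
The attention_ref helper above computes plain softmax attention in the (batch, seq, heads, head_dim) layout and flattens the heads on the way out. A small smoke test, assuming attention_ref and the test module's torch/einops imports are in scope (illustration only, not part of the PR):

# Tiny CPU-only check of the reference implementation's output shape.
import torch

B, S, H, D_HEAD = 2, 8, 4, 16
q = torch.randn(B, S, H, D_HEAD)
k = torch.randn(B, S, H, D_HEAD)
v = torch.randn(B, S, H, D_HEAD)
mask = torch.ones(B, S, dtype=torch.bool)    # True = valid token, False = padding

out = attention_ref(q, k, v, attn_mask=mask, causal=True)
assert out.shape == (B, S, H * D_HEAD)       # heads are flattened into one dim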


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(1, 8, 4, 16)])
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_gpt(proj_shape, dtype):
# TODO check output value
@parameterize('dropout', [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)
    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=True)

    # check gradients
    dy = torch.rand_like(y)
    y.backward(dy)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
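
For reference, the variable-length mask built with pad_sequence in test_attention_gpt marks real tokens True and right-padding False; a tiny standalone example (illustration only, not part of the PR):

# B = 3 sequences of lengths 4, 3, 2 padded to S = 4; pad_sequence fills with False.
import torch

S, B = 4, 3
mask = [torch.ones(S - i, dtype=torch.bool) for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
print(mask)
# tensor([[ True,  True,  True,  True],
#         [ True,  True,  True, False],
#         [ True,  True, False, False]])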


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_bert(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)
    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    # attention mask of shape [B, S] with zero padding to max length S
    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")

    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, mask, causal=False)

    dy = torch.rand_like(y)
    y.backward(dy)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_no_mask(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
    (B, S, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
    attn = ColoAttention(D, H, dropout=0.1)
    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    x = torch.randn((B, S, D), dtype=dtype, device="cuda")
    qkv = c_attn(x)
    q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v)

    assert list(y.shape) == [B, S, D]

    out_ref = attention_ref(q, k, v, None, causal=False)

    dy = torch.rand_like(y)
    y.backward(dy)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_cross_attention(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
    (B, S, T, H, D_HEAD) = proj_shape
    D = H * D_HEAD

    q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
    kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

    attn = ColoAttention(D, H, dropout=0.1)

    src = torch.randn((B, S, D), dtype=dtype, device="cuda")
    tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

    q = q_attn(tgt)
    kv = kv_attn(src)
    q = rearrange(q, 'b s (h d) -> b s h d', h=H)
    k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
    attn = ColoAttention(D, H, dropout=dropout)
    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

    assert list(y.shape) == [B, T, D]

    out_ref = attention_ref(q, k, v, None, causal=True)

    dy = torch.rand_like(y)
    y.backward(dy)
    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

    assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
    assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
    assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
    assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
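
Assuming a CUDA device and flash_attn>=2.0 (or xformers) are available, and that colossalai's @parameterize decorator sweeps its declared values when the wrapped function is called with no arguments, the tests above could presumably also be driven directly; a hedged sketch, not part of the PR:

if __name__ == '__main__':
    # Each call runs every parameterized (proj_shape, dtype, dropout) combination.
    test_attention_gpt()
    test_attention_bert()
    test_attention_no_mask()
    test_cross_attention()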