Merged #131
3 changes: 2 additions & 1 deletion requirements/requirements-test.txt
@@ -13,6 +13,7 @@ torchrec==0.2.0
contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
ninja
flash_attn>=2.0
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -10,4 +10,5 @@ contexttimer
ninja
torch>=1.11
safetensors
flash_attn>=2.0
einops
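
Both requirement files now list `flash_attn>=2.0`. A minimal availability check before running the test suite, illustrative only; it simply confirms the package imports and reports a version string if one is exposed:

```python
# Quick availability check for flash-attn (illustrative sketch, not part of the test suite).
try:
    import flash_attn
    # Not every build exposes __version__, so fall back gracefully.
    print("flash_attn available:", getattr(flash_attn, "__version__", "unknown"))
except ImportError:
    print("flash_attn not installed; the fused-attention tests will skip themselves")
```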
142 changes: 91 additions & 51 deletions tests/test_utils/test_flash_attention.py
@@ -1,4 +1,4 @@
import random
import math

import pytest
import torch
@@ -13,118 +13,158 @@
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

DTYPE = [torch.float16, torch.bfloat16, torch.float32]
FLASH_DTYPE = [torch.float16, torch.bfloat16]


def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
return ref_out
def attention_ref(q, k, v, attn_mask=None, causal=False):
"""
Reference (eager) attention output, used as the control group for the fused kernels.
"""
dtype_og = q.dtype
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
d = q.shape[-1]
scale = 1.0 / math.sqrt(d)
scores = torch.einsum('bthd,bshd->bhts', q * scale, k)

if attn_mask is not None:
scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
if causal:
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
scores.masked_fill_(causal_mask, float('-inf'))
attention = torch.softmax(scores, dim=-1)

output = torch.einsum('bhts,bshd->bthd', attention, v)
output = rearrange(output, "b s h d -> b s (h d)")

# Zero out the output at padded (masked) query positions
if attn_mask is not None:
output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0)

return output.to(dtype=dtype_og)
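
The reference helper above expects `q`, `k`, `v` laid out as `(batch, seqlen, heads, head_dim)` and an optional boolean `attn_mask` of shape `(batch, seqlen)` where `True` marks a real token; this follows from its einsum patterns and mask rearranges. A small self-contained call, with arbitrary illustrative shapes:

```python
# Illustrative use of attention_ref; shapes and dtype are arbitrary choices.
# Assumes attention_ref from above is in scope.
import torch

B, S, H, D_HEAD = 2, 5, 4, 16
q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
# True = keep the token, False = padding.
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool, device="cuda")

out = attention_ref(q, k, v, attn_mask=mask, causal=True)
print(out.shape)  # torch.Size([2, 5, 64]) -> (batch, seqlen, heads * head_dim)
```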


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
@parameterize('proj_shape', [(1, 8, 4, 16)])
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_gpt(proj_shape, dtype):
# TODO check output value
@parameterize('dropout', [0.0])
def test_attention_gpt(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD

c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

x = torch.randn((B, S, D), dtype=dtype, device="cuda")

qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)

mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

assert list(y.shape) == [B, S, D]

out_ref = attention_ref(q, k, v, mask, causal=True)

# check gradients
dy = torch.rand_like(y)
y.backward(dy)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
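
In test_attention_gpt the per-sample masks have decreasing valid lengths (S, S-1, ..., S-B+1), and `pad_sequence(batch_first=True)` right-pads them with `False`, producing the `(B, S)` boolean padding mask that is passed along with `AttnMaskType.paddedcausal`. A small standalone illustration of that construction, with B and S shrunk for readability:

```python
# How the padded boolean mask in test_attention_gpt is built (B=3, S=4 for brevity).
import torch

B, S = 3, 4
mask = [torch.ones(S - i, dtype=torch.bool) for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
print(mask)
# tensor([[ True,  True,  True,  True],
#         [ True,  True,  True, False],
#         [ True,  True, False, False]])
```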


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_bert(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_attention_bert(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD

c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

x = torch.randn((B, S, D), dtype=dtype, device="cuda")
# attention mask of shape [B, S] with zero padding to max length S
mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")

qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)

assert list(y.shape) == [B, S, D]

out_ref = attention_ref(q, k, v, mask, causal=False)

dy = torch.rand_like(y)
y.backward(dy)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_attention_no_mask(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_attention_no_mask(proj_shape, dtype, dropout):
(B, S, H, D_HEAD) = proj_shape
D = H * D_HEAD

c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
attn = ColoAttention(D, H, dropout=0.1)
q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

x = torch.randn((B, S, D), dtype=dtype, device="cuda")
qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v)

assert list(y.shape) == [B, S, D]

out_ref = attention_ref(q, k, v, None, causal=False)

dy = torch.rand_like(y)
y.backward(dy)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"


@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="neither xformers nor flash-attn is available")
@clear_cache_before_run()
@parameterize('proj_shape', [(6, 24, 8, 4, 16)])
@parameterize('dtype', DTYPE)
def test_cross_attention(proj_shape, dtype):
@parameterize('dropout', [0.0])
def test_cross_attention(proj_shape, dtype, dropout):
(B, S, T, H, D_HEAD) = proj_shape
D = H * D_HEAD

q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)

attn = ColoAttention(D, H, dropout=0.1)

src = torch.randn((B, S, D), dtype=dtype, device="cuda")
tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")

q = q_attn(tgt)
kv = kv_attn(src)
q = rearrange(q, 'b s (h d) -> b s h d', h=H)
k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
attn = ColoAttention(D, H, dropout=dropout)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

assert list(y.shape) == [B, T, D]

out_ref = attention_ref(q, k, v, None, causal=True)

dy = torch.rand_like(y)
y.backward(dy)
grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)

assert torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
assert torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
assert torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
assert torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"