15 changes: 14 additions & 1 deletion examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -400,7 +400,20 @@ void cpu_flash_attention(
static_cast<accum_t>(0),
qk_data,
kvBlockSize);
// Apply causal mask, fill unused with -inf
// Apply causal mask, fill unused, i.e. future, positions with -inf
// Say you have q @ k.T of size [16, 32]
// With qBlockSize = 4, say you are processing
// q seq len positions 8:11, i.e. m = 8.
// Say kvSplitSize = 4
// Then for the causal mask, the entries that need to be
// ignored are
// [8, 9:31], [9, 10:31], [10, 11:31], [11, 12:31]
// The following condition uses num_keys = 8 + 4 = 12:
// (num_keys - n) <= kvSplitSize
// num_keys <= n + kvSplitSize
// If n + kvSplitSize reaches num_keys (12 here), some entries
// of this key block must be masked out. In our example
// n = 8 qualifies.
if (is_causal && num_keys - n <= kvSplitSize) {
for (int32_t row = 0; row < qBlockSize; ++row) {
int64_t last_col = m + row - n;
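For concreteness, here is a small Python sketch (not part of the diff) that evaluates the same block-qualification condition the comment above describes. The values mirror the comment's example (m = 8, qBlockSize = 4, kvSplitSize = 4, kvSize = 32 assumed); the snake_case names exist only for this sketch.

# Which key blocks hit the causal-mask branch for the example above.
m, q_block_size, kv_size, kv_split_size = 8, 4, 32, 4
num_keys = min(m + q_block_size, kv_size)  # 12 for a causal kernel

for n in range(0, num_keys, kv_split_size):  # key-block starts: 0, 4, 8
    needs_mask = (num_keys - n) <= kv_split_size
    print(f"key block at n={n}: needs causal masking? {needs_mask}")

# n=0 -> False (keys 0:3 are visible to all queries 8:11)
# n=4 -> False (keys 4:7 are still in the past for every query in the block)
# n=8 -> True  (keys 8:11 overlap the diagonal, so rows 8:10 mask part of it)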
121 changes: 39 additions & 82 deletions examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -12,54 +12,52 @@
from .sdpa_with_kv_cache import custom_ops_lib # noqa


def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, mask, start_pos):
seq_len = q.size(1)
attn_mask = mask[start_pos : start_pos + seq_len, :]
attn_mask = attn_mask[:, : start_pos + seq_len]
q = q.transpose(1, 2)
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)

num_heads_q = q.size(1)
num_heads_kv = sliced_k_cache.size(1)
if num_heads_q != num_heads_kv:
assert (
num_heads_q % num_heads_kv == 0
), f"{num_heads_q} not divisible by {num_heads_kv}"
n_reps = num_heads_q // num_heads_kv
if n_reps > 1:
sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
out = F.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
return out
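A minimal usage sketch (illustration only, not part of this PR) of the new module-level reference above: each call writes the step's k/v into the caches in place at start_pos and attends over everything cached so far. Shapes and mask construction mirror SDPATest.setUp below; the three-step decode loop is assumed for illustration.

import torch

# Cache layout matches the tests: [batch, max_seq_len, n_heads, head_dim]
k_cache = torch.zeros((1, 10, 8, 4))
v_cache = torch.zeros((1, 10, 8, 4))
# Causal mask: -inf above the diagonal, 0 elsewhere
mask = torch.triu(torch.full((10, 10), float("-inf")), diagonal=1)

for start_pos in range(3):  # three single-token decode steps (assumed)
    q = torch.rand((1, 1, 8, 4))
    k = torch.rand((1, 1, 8, 4))
    v = torch.rand((1, 1, 8, 4))
    out = _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, mask, start_pos)
    assert out.shape == (1, 1, 8, 4)  # same layout as q

For the MQA/GQA tests further down, q carries more heads than k/v; the reference then repeat_interleaves the sliced caches along the head dimension so plain scaled_dot_product_attention can be applied.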


class SDPATest(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 5, 8, 4))
self.v_cache = torch.zeros((1, 5, 8, 4))
self.k_cache = torch.zeros((1, 10, 8, 4))
self.v_cache = torch.zeros((1, 10, 8, 4))
self.mask = torch.full(
(5, 5),
(10, 10),
float("-inf"),
)
self.mask = torch.triu(self.mask, diagonal=1)

def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
print(f"at start_pos:{start_pos}")
print(q)
print(k)
print(v)
attn_mask = mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + 1]
q = q.transpose(1, 2)
k_cache[:, start_pos] = k
v_cache[:, start_pos] = v
sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
# print(sliced_k_cache.size())
# print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
# print("q @ k")
# qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
# qk_softmax = torch.softmax(qk, dim=-1)
# qkv = torch.matmul(qk_softmax, sliced_v_cache)
# print(qk)
# print(qk_softmax)
# print(qkv)
out = F.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
print(out)
print(f"-------- start pos {start_pos} done -----")
return out

def test_sdpa_with_cache_no_mqa_1(self):
q = torch.rand((1, 1, 8, 4))
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))
ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 0
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -72,7 +70,7 @@ def test_sdpa_with_cache_no_mqa_2(self):
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -85,7 +83,7 @@ def test_sdpa_with_cache_no_mqa_3(self):
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 2
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -98,7 +96,7 @@ def test_sdpa_with_cache_no_mqa_4(self):
k = torch.rand((1, 1, 8, 4))
v = torch.rand((1, 1, 8, 4))

ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 3
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -124,52 +122,11 @@ def setUp(self):
)
self.mask = torch.triu(self.mask, diagonal=1)

def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
print(f"at start_pos:{start_pos}")
print(q)
print(k)
print(v)
attn_mask = mask[start_pos].view((1, -1))
attn_mask = attn_mask[:, : start_pos + 1]
q = q.transpose(1, 2)
k_cache[:, start_pos] = k
v_cache[:, start_pos] = v
sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
sliced_k_cache = sliced_k_cache.transpose(1, 2)
sliced_v_cache = sliced_v_cache.transpose(1, 2)
# print(sliced_k_cache.size())
# print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
# print("q @ k")
# qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
# qk_softmax = torch.softmax(qk, dim=-1)
# qkv = torch.matmul(qk_softmax, sliced_v_cache)
# print(qk)
# print(qk_softmax)
# print(qkv)
num_heads_q = q.size(1)
num_heads_kv = sliced_k_cache.size(1)
if num_heads_q != num_heads_kv:
assert (
num_heads_q % num_heads_kv == 0
), f"{num_heads_q} not divisible by {num_heads_kv}"
n_reps = num_heads_q // num_heads_kv
if n_reps > 1:
sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
out = F.scaled_dot_product_attention(
q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
)
out = out.transpose(1, 2)
print(out)
print(f"-------- start pos {start_pos} done -----")
return out

def test_sdpa_with_cache_mqa_1(self):
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 0
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -181,7 +138,7 @@ def test_sdpa_with_cache_mqa_2(self):
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -196,7 +153,7 @@ def test_sdpa_with_cache_mqa_3(self):
q = torch.rand((1, 1, self.n_heads_q, 4))
k = torch.rand((1, 1, self.n_heads_kv, 4))
v = torch.rand((1, 1, self.n_heads_kv, 4))
ref_output = self._sdpa_with_kv_cache_ref(
ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, self.mask, 1
)
op_output = torch.ops.llama.sdpa_with_kv_cache(