57 changes: 52 additions & 5 deletions examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -177,6 +177,42 @@ inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
}
}

/*
Note on start_pos as a parameter:
What is start_pos?
- start_pos is the position of the first element of the current query. That is,
in LLMs during the generate phase, when we generate one token at a time, the
query corresponds to a monotonically increasing start_pos: the first token is
at start_pos = 0, the second token is at start_pos = 1, and so on.
If we prefill with a prompt of 4 tokens, then during the decode phase
start_pos = 4.

Why is start_pos needed?
- Attention should not need to know start_pos. To apply a causal mask we can
use the is_causal parameter (which the ATen SDPA API is considering removing).
However, the current handling of is_causal assumes start_pos = 0. That means
that when we have a query during decode at start_pos = 4, it is a single
vector of shape [1, head_dim] for a given head, while the key, derived from the
kv cache, is of shape [start_pos + 1, head_dim], i.e. all the past tokens
contained in the kv cache. If we apply the causal mask naively, the query is
assumed to be at start_pos = 0, and thus all the "future" tokens (indices
1...start_pos) in q @ k.T, of shape [1, start_pos + 1], are masked out of the
attention calculation. That is not right: since the query is at position
start_pos = 4, it should attend to all previous tokens in the cache, i.e.
positions 0...start_pos. Thus we need to pass start_pos.

Can we use attn_mask?
- Yes. An attention mask can be used for the same purpose; however, at the
moment the attention mask for our llama model is a boolean mask, which requires
conversion to -inf for the masked-out sections. That change may have
performance implications, although we have not validated this and it is
possible there are none. If the mask were a float mask, things would work
out of the box. In our llama definition each layer stores the mask, so moving
to a float mask could increase the memory footprint, which is currently
optimized away since sdpa_with_kv_cache does not use attn_mask.

TODO: Just handle conversion of bool mask to float mask.
*/
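To make the note above concrete, here is a rough PyTorch sketch (illustrative only, not part of this diff; the helper names and the "True = masked out" boolean convention are assumptions) of the shifted causal cutoff and of the bool-to-float conversion mentioned in the TODO:

import torch

def causal_mask_with_offset(q_len, num_keys, start_pos):
    # Additive float mask: query row i sits at absolute position start_pos + i,
    # so it may attend to keys 0 .. start_pos + i and to nothing beyond that.
    mask = torch.zeros(q_len, num_keys)
    for i in range(q_len):
        mask[i, start_pos + i + 1 :] = float("-inf")
    return mask

def bool_to_float_mask(bool_mask):
    # The TODO above: turn a boolean mask (assumed convention: True = masked out)
    # into the additive -inf form that can be added directly to q @ k.T.
    return torch.zeros(bool_mask.shape).masked_fill(bool_mask, float("-inf"))

With start_pos = 4 and a single decoded query (q_len = 1, num_keys = 5), causal_mask_with_offset returns all zeros, so the token attends to every cached position; the naive start_pos = 0 mask would instead wipe out indices 1...4, which is exactly the problem described above.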
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
void cpu_flash_attention(
Tensor& output,
@@ -187,7 +223,8 @@ void cpu_flash_attention(
bool is_causal,
const optional<Tensor>& attn_mask,
const optional<double>& scale,
bool is_with_kv_cache = false) {
bool is_with_kv_cache = false,
const int64_t start_pos = 0) {
(void)dropout_p;
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
@@ -416,7 +453,7 @@ void cpu_flash_attention(
// will qualify for that
if (is_causal && num_keys - n <= kvSplitSize) {
for (int32_t row = 0; row < qBlockSize; ++row) {
int64_t last_col = m + row - n;
int64_t last_col = m + (row + start_pos) - n;
accum_t* row_ptr = qk_data + row * kvBlockSize;
fill_stub(
row_ptr + last_col + 1,
@@ -760,6 +797,13 @@ Tensor& sdpa_with_kv_cache_out(
InvalidArgument,
output);

ET_KERNEL_CHECK_MSG(
ctx,
!attn_mask.has_value() || !is_causal,
InvalidArgument,
output,
"attn_mask and is_causal cannot be set at the same time");

ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");

update_cache(k_projected, key_cache, start_pos, seq_len);
@@ -844,7 +888,8 @@ Tensor& sdpa_with_kv_cache_out(
is_causal,
attn_mask,
scale,
true);
true,
start_pos);
} else if (q_seq_len >= 192) {
cpu_flash_attention<CTYPE, 64, 512>(
output,
@@ -855,7 +900,8 @@ Tensor& sdpa_with_kv_cache_out(
is_causal,
attn_mask,
scale,
true);
true,
start_pos);
} else {
cpu_flash_attention<CTYPE, 32, 512>(
output,
@@ -866,7 +912,8 @@ Tensor& sdpa_with_kv_cache_out(
is_causal,
attn_mask,
scale,
true);
true,
start_pos);
}
});
return output;
12 changes: 5 additions & 7 deletions examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -73,13 +73,11 @@ attn_mask=attn_mask)

/*
Missing tests:
2. Add different seq lengths on Q
3. Test for different batch sizes
4. Apply causal attention
5. Mix 2 with attention_mask
6. No bool attention_mask
7. apply scaling
8. Different dtypes, fp16, bf16, double (or expect throw)
1. Test for different batch sizes
2. Mix different seq lengths on Q with attention_mask
3. No bool attention_mask
4. Apply scaling
5. Different dtypes: fp16, bf16, double (or expect throw)
*/
TEST(OpScaledDotProductAttentionTest, BasicTest) {
torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
56 changes: 49 additions & 7 deletions examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -50,6 +50,7 @@ def setUp(self):
)
self.mask = torch.triu(self.mask, diagonal=1)
self.use_mask_with_custom_op = False
self.is_causal = False

def test_sdpa_with_cache_no_mqa_1(self):
q = torch.rand((1, 1, 8, 4))
@@ -78,7 +79,16 @@ def test_sdpa_with_cache_no_mqa_1(self):
)
else:
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
q,
k,
v,
self.k_cache,
self.v_cache,
start_pos,
seq_len,
None,
0,
self.is_causal,
)
self.assertTrue(torch.allclose(ref_output, op_output))

@@ -110,7 +120,16 @@ def test_sdpa_with_cache_no_mqa_2(self):
)
else:
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
q,
k,
v,
self.k_cache,
self.v_cache,
start_pos,
seq_len,
None,
0,
self.is_causal,
)

self.assertTrue(torch.allclose(ref_output, op_output))
@@ -143,7 +162,16 @@ def test_sdpa_with_cache_no_mqa_3(self):
)
else:
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
q,
k,
v,
self.k_cache,
self.v_cache,
start_pos,
seq_len,
None,
0,
self.is_causal,
)
self.assertTrue(torch.allclose(ref_output, op_output))

@@ -175,24 +203,38 @@ def test_sdpa_with_cache_no_mqa_4(self):
)
else:
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
q,
k,
v,
self.k_cache,
self.v_cache,
start_pos,
seq_len,
None,
0,
self.is_causal,
)
self.assertTrue(torch.allclose(ref_output, op_output))


class SDPAWithAttentionMaskTest(SDPATest):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 10, 8, 4))
self.v_cache = torch.zeros((1, 10, 8, 4))
SDPATest.setUp(self)
self.mask = torch.full(
(10, 10),
100.642,
)
self.use_mask_with_custom_op = True


class SDPAWithCausalTest(SDPATest):

def setUp(self):
SDPATest.setUp(self)
self.is_causal = True


class SDPATestWithMQA(unittest.TestCase):

def setup_caches(self):
Expand Down