diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index b2cd7fb0d0a..e3b3eba5416 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -177,6 +177,42 @@ inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) {
   }
 }
 
+/*
+Note on start_pos as a parameter:
+What is start_pos?
+- start_pos is the position of the first element of the current query. That is,
+in LLMs during the generate phase, when we generate one token at a time, the
+query will correspond to a monotonically increasing start_pos, e.g. the first
+token is at start_pos = 0, the second token is at start_pos = 1, and so on.
+If we do prefill with a prompt which has 4 tokens, then during the decode
+phase, start_pos = 4.
+
+Why is start_pos needed?
+- Attention should not need to know start_pos. However, to apply the causal
+mask, we can use the is_causal parameter (the aten API for SDPA is considering
+getting rid of it). However, the current handling of is_causal assumes that
+start_pos = 0. Meaning, when we have a query during decode at start_pos = 4,
+it will be a single vector of [1, head_dim] for a given head. The key, derived
+from the kv cache, will be of size [start_pos + 1, head_dim], i.e. all the
+past tokens contained in the kv cache. If we apply the causal mask naively,
+the query is assumed to be at start_pos = 0, and thus all the future tokens
+(indices 1...start_pos) in q @ k.T, of shape [1, start_pos + 1], will be
+masked out of the attention calculation. However, that is not right. Since
+the query is at pos 4, it should attend to all previous tokens in the cache,
+that is, positions 0...start_pos. Thus we need to pass start_pos.
+
+Can we use attn_mask?
+- Yes, an attention mask can be used for the same purpose. However, at the
+moment the attention mask for our llama model is a boolean mask, which would
+require conversion to -inf for the masked-out sections. That change may have
+perf implications, though we haven't really validated this; it is possible
+there are none. If the mask were a float mask, things would work
+out-of-the-box. In our llama definition each layer is storing the mask, and
+if we move to a float mask, that can increase the memory footprint, which is
+right now optimized away since sdpa_with_kv_cache does not use attn_mask.
+
+TODO: Just handle conversion of bool mask to float
+*/
 template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
 void cpu_flash_attention(
     Tensor& output,
@@ -187,7 +223,8 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    bool is_with_kv_cache = false) {
+    bool is_with_kv_cache = false,
+    const int64_t start_pos = 0) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
   // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
@@ -416,7 +453,7 @@ void cpu_flash_attention(
         // will qualify for that
         if (is_causal && num_keys - n <= kvSplitSize) {
           for (int32_t row = 0; row < qBlockSize; ++row) {
-            int64_t last_col = m + row - n;
+            int64_t last_col = m + (row + start_pos) - n;
             accum_t* row_ptr = qk_data + row * kvBlockSize;
             fill_stub(
                 row_ptr + last_col + 1,
@@ -760,6 +797,13 @@ Tensor& sdpa_with_kv_cache_out(
       InvalidArgument,
       output);
 
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !attn_mask.has_value() || !is_causal,
+      InvalidArgument,
+      output,
+      "attn_mask and is_causal cannot be set at the same time");
+
   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
 
   update_cache(k_projected, key_cache, start_pos, seq_len);
@@ -844,7 +888,8 @@ Tensor& sdpa_with_kv_cache_out(
           is_causal,
           attn_mask,
           scale,
-          true);
+          true,
+          start_pos);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention(
           output,
@@ -855,7 +900,8 @@ Tensor& sdpa_with_kv_cache_out(
           is_causal,
           attn_mask,
           scale,
-          true);
+          true,
+          start_pos);
     } else {
       cpu_flash_attention(
           output,
@@ -866,7 +912,8 @@ Tensor& sdpa_with_kv_cache_out(
           is_causal,
           attn_mask,
           scale,
-          true);
+          true,
+          start_pos);
     }
   });
   return output;
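To make the start_pos reasoning in the comment above concrete, here is a minimal sketch, in plain PyTorch rather than the custom kernel, of how the additive mask row for a single decode-step query changes once the query's absolute position is taken into account. The helper name decode_step_mask and its shapes are illustrative assumptions, not code from this diff.

import torch

def decode_step_mask(start_pos: int, num_keys: int) -> torch.Tensor:
    # Illustrative helper: additive mask for one query token sitting at
    # absolute position `start_pos`, attending over `num_keys` cached keys
    # (num_keys == start_pos + 1 during decode).
    key_positions = torch.arange(num_keys)
    mask = torch.zeros(1, num_keys)
    # Keys strictly after the query's position are masked with -inf.
    mask[:, key_positions > start_pos] = float("-inf")
    return mask

# Naive causal masking treats the query as if it were at position 0, so every
# cached key except the first gets masked out:
print(decode_step_mask(start_pos=0, num_keys=5))  # [[0., -inf, -inf, -inf, -inf]]
# With the offset applied, the decode-step query at position 4 attends to all
# cached positions 0...start_pos:
print(decode_step_mask(start_pos=4, num_keys=5))  # [[0., 0., 0., 0., 0.]]

Shifting last_col by start_pos in the kernel's causal-mask loop above has the same effect.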
diff --git a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
index 8496f15d781..5ea7856dba9 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -73,13 +73,11 @@ attn_mask=attn_mask)
 
 /*
 Missing tests:
-2. Add different seq lengths on Q
-3. Test for different batch sizes
-4. Apply causal attention
-5. Mix 2 with attention_mask
-6. No bool attention_mask
-7. apply scaling
-8. Different dtypes, fp16, bf16, double (or expect throw)
+1. Test for different batch sizes
+2. Mix different seq lengths on Q with attention_mask
+3. No bool attention_mask
+4. Apply scaling
+5. Different dtypes, fp16, bf16, double (or expect throw)
 */
 TEST(OpScaledDotProductAttentionTest, BasicTest) {
   torch::executor::testing::TensorFactory<ScalarType::Float> tfFloat;
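The TODO in op_sdpa.cpp above is about turning the model's boolean mask into a float mask that can simply be added to q @ k.T before the softmax. A rough sketch of that conversion, assuming True marks a position that should not be attended to (flip the predicate if the model uses the opposite convention); the function name bool_to_float_mask is an illustrative assumption, not part of this change:

import torch

def bool_to_float_mask(bool_mask: torch.Tensor) -> torch.Tensor:
    # Sketch only: True -> -inf (dropped from attention), False -> 0.0.
    float_mask = torch.zeros(bool_mask.shape, dtype=torch.float32)
    float_mask.masked_fill_(bool_mask, float("-inf"))
    return float_mask

# Example: an upper-triangular boolean causal mask for 4 positions.
causal_bool = torch.triu(torch.ones(4, 4), diagonal=1).bool()
print(bool_to_float_mask(causal_bool))
# Row i keeps 0.0 for columns 0..i and -inf for the future columns i+1..3.

If the mask stayed in float form end to end, the -inf conversion cost would disappear, at the price of the larger per-layer mask storage mentioned in the note.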
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
index 9005021dec5..5bd05e3db05 100644
--- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -50,6 +50,7 @@ def setUp(self):
         )
         self.mask = torch.triu(self.mask, diagonal=1)
         self.use_mask_with_custom_op = False
+        self.is_causal = False
 
     def test_sdpa_with_cache_no_mqa_1(self):
         q = torch.rand((1, 1, 8, 4))
@@ -78,7 +79,16 @@ def test_sdpa_with_cache_no_mqa_1(self):
             )
         else:
             op_output = torch.ops.llama.sdpa_with_kv_cache(
-                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                None,
+                0,
+                self.is_causal,
             )
         self.assertTrue(torch.allclose(ref_output, op_output))
 
@@ -110,7 +120,16 @@ def test_sdpa_with_cache_no_mqa_2(self):
             )
         else:
             op_output = torch.ops.llama.sdpa_with_kv_cache(
-                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                None,
+                0,
+                self.is_causal,
             )
         self.assertTrue(torch.allclose(ref_output, op_output))
 
@@ -143,7 +162,16 @@ def test_sdpa_with_cache_no_mqa_3(self):
             )
         else:
             op_output = torch.ops.llama.sdpa_with_kv_cache(
-                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                None,
+                0,
+                self.is_causal,
             )
         self.assertTrue(torch.allclose(ref_output, op_output))
 
@@ -175,7 +203,16 @@ def test_sdpa_with_cache_no_mqa_4(self):
             )
         else:
             op_output = torch.ops.llama.sdpa_with_kv_cache(
-                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                None,
+                0,
+                self.is_causal,
             )
         self.assertTrue(torch.allclose(ref_output, op_output))
 
@@ -183,9 +220,7 @@ def test_sdpa_with_cache_no_mqa_4(self):
 
 class SDPAWithAttentionMaskTest(SDPATest):
     def setUp(self):
-        torch.manual_seed(42)
-        self.k_cache = torch.zeros((1, 10, 8, 4))
-        self.v_cache = torch.zeros((1, 10, 8, 4))
+        SDPATest.setUp(self)
         self.mask = torch.full(
             (10, 10),
             100.642,
@@ -193,6 +228,13 @@ def setUp(self):
         )
         self.use_mask_with_custom_op = True
 
+class SDPAWithCausalTest(SDPATest):
+
+    def setUp(self):
+        SDPATest.setUp(self)
+        self.is_causal = True
+
+
 class SDPATestWithMQA(unittest.TestCase):
     def setup_caches(self):
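The new SDPAWithCausalTest relies on the kernel applying the causal mask relative to start_pos; a reference computed with is_causal=True in eager PyTorch would instead assume the query starts at position 0. One way a reference output could be built for such a test is to hand F.scaled_dot_product_attention an explicit float mask offset by start_pos, as sketched below. The helper causal_reference, and the [batch, heads, seq, head_dim] layout it expects, are assumptions for illustration and not part of this diff:

import torch
import torch.nn.functional as F

def causal_reference(q, k, v, start_pos):
    # q: [batch, heads, 1, head_dim], the decode-step query at absolute
    # position start_pos. k, v: [batch, heads, start_pos + 1, head_dim],
    # i.e. the populated slice of the kv cache.
    num_keys = k.size(2)
    key_positions = torch.arange(num_keys)
    attn_mask = torch.zeros(1, num_keys)
    # Mask only keys past the query's absolute position.
    attn_mask[:, key_positions > start_pos] = float("-inf")
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

q = torch.rand(1, 8, 1, 4)
k = torch.rand(1, 8, 5, 4)
v = torch.rand(1, 8, 5, 4)
ref = causal_reference(q, k, v, start_pos=4)  # attends to all 5 cached positions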