diff --git a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
index fa2d164fe3d..8496f15d781 100644
--- a/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -73,7 +73,6 @@
 attn_mask=attn_mask)
 /*
 Missing tests:
-1. Add back tests with attention masks
 2. Add different seq lengths on Q
 3. Test for different batch sizes
 4. Apply causal attention
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
index bf857204e1c..9005021dec5 100644
--- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -12,10 +12,7 @@
 from .sdpa_with_kv_cache import custom_ops_lib  # noqa


-def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, mask, start_pos):
-    seq_len = q.size(1)
-    attn_mask = mask[start_pos : start_pos + seq_len, :]
-    attn_mask = attn_mask[:, : start_pos + seq_len]
+def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
     q = q.transpose(1, 2)
     k_cache[:, start_pos : start_pos + seq_len, :, :] = k
     v_cache[:, start_pos : start_pos + seq_len, :, :] = v
@@ -52,59 +49,150 @@ def setUp(self):
             float("-inf"),
         )
         self.mask = torch.triu(self.mask, diagonal=1)
+        self.use_mask_with_custom_op = False

     def test_sdpa_with_cache_no_mqa_1(self):
         q = torch.rand((1, 1, 8, 4))
         k = torch.rand((1, 1, 8, 4))
         v = torch.rand((1, 1, 8, 4))
+        start_pos = 0
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 0
-        )
-        op_output = torch.ops.llama.sdpa_with_kv_cache(
-            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
+        if self.use_mask_with_custom_op:
+            attn_mask = attn_mask.contiguous()
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                attn_mask,
+                0,
+                False,
+            )
+        else:
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+            )
         self.assertTrue(torch.allclose(ref_output, op_output))

     def test_sdpa_with_cache_no_mqa_2(self):
         q = torch.rand((1, 1, 8, 4))
         k = torch.rand((1, 1, 8, 4))
         v = torch.rand((1, 1, 8, 4))
+        start_pos = 1
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 1
-        )
-        op_output = torch.ops.llama.sdpa_with_kv_cache(
-            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
+        if self.use_mask_with_custom_op:
+            attn_mask = attn_mask.contiguous()
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                attn_mask,
+                0,
+                False,
+            )
+        else:
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+            )
+
         self.assertTrue(torch.allclose(ref_output, op_output))

     def test_sdpa_with_cache_no_mqa_3(self):
         q = torch.rand((1, 1, 8, 4))
         k = torch.rand((1, 1, 8, 4))
         v = torch.rand((1, 1, 8, 4))
+        start_pos = 2
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 2
-        )
-        op_output = torch.ops.llama.sdpa_with_kv_cache(
-            q, k, v, self.k_cache, self.v_cache, 2, 1, None, 0, False
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
+        if self.use_mask_with_custom_op:
+            attn_mask = attn_mask.contiguous()
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                attn_mask,
+                0,
+                False,
+            )
+        else:
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+            )
         self.assertTrue(torch.allclose(ref_output, op_output))

     def test_sdpa_with_cache_no_mqa_4(self):
         q = torch.rand((1, 1, 8, 4))
         k = torch.rand((1, 1, 8, 4))
         v = torch.rand((1, 1, 8, 4))
+        start_pos = 3
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 3
-        )
-        op_output = torch.ops.llama.sdpa_with_kv_cache(
-            q, k, v, self.k_cache, self.v_cache, 3, 1, None, 0, False
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
+        if self.use_mask_with_custom_op:
+            attn_mask = attn_mask.contiguous()
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q,
+                k,
+                v,
+                self.k_cache,
+                self.v_cache,
+                start_pos,
+                seq_len,
+                attn_mask,
+                0,
+                False,
+            )
+        else:
+            op_output = torch.ops.llama.sdpa_with_kv_cache(
+                q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, False
+            )
         self.assertTrue(torch.allclose(ref_output, op_output))


+class SDPAWithAttentionMaskTest(SDPATest):
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.k_cache = torch.zeros((1, 10, 8, 4))
+        self.v_cache = torch.zeros((1, 10, 8, 4))
+        self.mask = torch.full(
+            (10, 10),
+            100.642,
+        )
+        self.use_mask_with_custom_op = True
+
+
 class SDPATestWithMQA(unittest.TestCase):

     def setup_caches(self):
@@ -126,8 +214,12 @@ def test_sdpa_with_cache_mqa_1(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
+        start_pos = 0
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 0
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(
             q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
@@ -138,8 +230,12 @@ def test_sdpa_with_cache_mqa_2(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
+        start_pos = 1
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(
             q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
@@ -153,8 +249,12 @@ def test_sdpa_with_cache_mqa_3(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
+        start_pos = 1
+        seq_len = q.size(1)
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
         ref_output = _sdpa_with_kv_cache_ref(
-            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(
             q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
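The pattern the updated tests share is: build one full (max_seq_len, max_seq_len) causal mask in setUp, slice out the seq_len rows starting at start_pos and the first start_pos + seq_len columns for each call, and pass that slice to the reference helper; when use_mask_with_custom_op is set, the same slice is made contiguous and passed to torch.ops.llama.sdpa_with_kv_cache as its attn_mask argument. The sketch below shows only that slicing against plain PyTorch scaled_dot_product_attention under assumed shapes (batch 1, one query token, 8 heads, head dim 4); it is not the repo's _sdpa_with_kv_cache_ref, and the KV-cache update step is omitted.

import torch
import torch.nn.functional as F

# Hypothetical standalone sketch: shapes mirror the tests (batch=1, seq_len=1,
# n_heads=8, head_dim=4, max_seq_len=10), but this is not _sdpa_with_kv_cache_ref
# and the cache-write step is left out.
max_seq_len = 10
start_pos = 2
q = torch.rand((1, 1, 8, 4))  # (batch, seq_len, n_heads, head_dim)
seq_len = q.size(1)

# Full causal mask built once (as in SDPATest.setUp), then sliced per decode step:
# the current query rows against every position written so far.
mask = torch.full((max_seq_len, max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
attn_mask = mask[start_pos : start_pos + seq_len, :]
attn_mask = attn_mask[:, : start_pos + seq_len]  # (seq_len, start_pos + seq_len)

# Stand-ins for the cached keys/values up to and including this step.
k_so_far = torch.rand((1, 8, start_pos + seq_len, 4))  # (batch, n_heads, kv_len, head_dim)
v_so_far = torch.rand((1, 8, start_pos + seq_len, 4))

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k_so_far, v_so_far, attn_mask=attn_mask
)
print(out.shape)  # torch.Size([1, 8, 1, 4])

Running the sketch needs only PyTorch 2.0 or newer; exercising the custom-op path shown in the diff additionally requires custom_ops_lib to be loaded so that torch.ops.llama.sdpa_with_kv_cache resolves.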