diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index e3b3eba5416..5bda2506460 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -664,7 +664,6 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_CHECK_MSG(seq_length == 1, "seq_length must be 1");
   ET_CHECK_MSG(
       projected_value.size(0) == 1,
       "projected_value must have batch size of 1");
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
index 5bd05e3db05..1b8f425b673 100644
--- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -235,6 +235,69 @@ def setUp(self):
         self.is_causal = True
 
 
+class SDPAWithDynamicShape(unittest.TestCase):
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.k_cache = torch.zeros((1, 10, 8, 4))
+        self.v_cache = torch.zeros((1, 10, 8, 4))
+        self.mask = torch.full(
+            (10, 10),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+        self.use_mask_with_custom_op = False
+        self.is_causal = False
+
+    def test_sdpa_with_cache_dynamic_shape_0(self):
+        q = torch.rand((1, 4, 8, 4))
+        k = torch.rand((1, 4, 8, 4))
+        v = torch.rand((1, 4, 8, 4))
+        seq_len = q.size(1)
+        start_pos = 0
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_dynamic_shape_2(self):
+        q = torch.rand((1, 3, 8, 4))
+        k = torch.rand((1, 3, 8, 4))
+        v = torch.rand((1, 3, 8, 4))
+        seq_len = q.size(1)
+        start_pos = 2
+        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
+        attn_mask = attn_mask[:, : start_pos + seq_len]
+
+        ref_output = _sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
+        )
+
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    @unittest.skip("This test will expect failure but runtime is not bubbling it up.")
+    def test_sdpa_with_cache_dynamic_shape_4(self):
+        q = torch.rand((1, 11, 8, 4))
+        k = torch.rand((1, 11, 8, 4))
+        v = torch.rand((1, 11, 8, 4))
+        seq_len = q.size(1)
+        start_pos = 4
+
+        torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
+        )
+
+
 class SDPATestWithMQA(unittest.TestCase):
 
     def setup_caches(self):