From 6d4db9998dff3dfc918e40472091a35a015396b4 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 25 Jun 2024 15:31:31 -0700 Subject: [PATCH] Enable dynamic shape tests for sdpa with kv cache Note that in practice during prefill we will always see start_pos=0 regardless of prompt size. However, the tests added here also simulate batched (along the seq dim) inference performed not just at the beginning. This is useful for speculative decoding where you want the larger model to run batched inference for efficiency. Differential Revision: [D58874163](https://our.internmc.facebook.com/intern/diff/D58874163/) [ghstack-poisoned] --- examples/models/llama2/custom_ops/op_sdpa.cpp | 1 - .../custom_ops/test_sdpa_with_kv_cache.py | 63 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp index 1ca6ed08ab2..e553803f8a7 100644 --- a/examples/models/llama2/custom_ops/op_sdpa.cpp +++ b/examples/models/llama2/custom_ops/op_sdpa.cpp @@ -665,7 +665,6 @@ void update_cache( const Tensor& cache, int64_t start_pos, int64_t seq_length) { - ET_CHECK_MSG(seq_length == 1, "seq_length must be 1"); ET_CHECK_MSG( projected_value.size(0) == 1, "projected_value must have batch size of 1"); diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py index 5bd05e3db05..1b8f425b673 100644 --- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py +++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py @@ -235,6 +235,69 @@ def setUp(self): self.is_causal = True +class SDPAWithDynamicShape(unittest.TestCase): + + def setUp(self): + torch.manual_seed(42) + self.k_cache = torch.zeros((1, 10, 8, 4)) + self.v_cache = torch.zeros((1, 10, 8, 4)) + self.mask = torch.full( + (10, 10), + float("-inf"), + ) + self.mask = torch.triu(self.mask, diagonal=1) + self.use_mask_with_custom_op = False + 
self.is_causal = False + + def test_sdpa_with_cache_dynamic_shape_0(self): + q = torch.rand((1, 4, 8, 4)) + k = torch.rand((1, 4, 8, 4)) + v = torch.rand((1, 4, 8, 4)) + seq_len = q.size(1) + start_pos = 0 + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + + ref_output = _sdpa_with_kv_cache_ref( + q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len + ) + + op_output = torch.ops.llama.sdpa_with_kv_cache( + q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True + ) + self.assertTrue(torch.allclose(ref_output, op_output)) + + def test_sdpa_with_cache_dynamic_shape_2(self): + q = torch.rand((1, 3, 8, 4)) + k = torch.rand((1, 3, 8, 4)) + v = torch.rand((1, 3, 8, 4)) + seq_len = q.size(1) + start_pos = 2 + attn_mask = self.mask[start_pos : start_pos + seq_len, :] + attn_mask = attn_mask[:, : start_pos + seq_len] + + ref_output = _sdpa_with_kv_cache_ref( + q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len + ) + + op_output = torch.ops.llama.sdpa_with_kv_cache( + q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True + ) + self.assertTrue(torch.allclose(ref_output, op_output)) + + @unittest.skip("This test will expect failure but runtime is not bubbling it up.") + def test_sdpa_with_cache_dynamic_shape_4(self): + q = torch.rand((1, 11, 8, 4)) + k = torch.rand((1, 11, 8, 4)) + v = torch.rand((1, 11, 8, 4)) + seq_len = q.size(1) + start_pos = 4 + + torch.ops.llama.sdpa_with_kv_cache( + q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True + ) + + class SDPATestWithMQA(unittest.TestCase): def setup_caches(self):