From edea73ee95b4089fc06e9d699640df08cd88b29b Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 25 Jun 2024 15:31:21 -0700
Subject: [PATCH 1/3] Refactor sdpa_with_kv_cache tests

Refactors so as to
- Test attention mask
- Test is_causal
- Enable dynamic shape tests

Differential Revision: [D58874166](https://our.internmc.facebook.com/intern/diff/D58874166/)

[ghstack-poisoned]
---
 examples/models/llama2/custom_ops/op_sdpa.cpp |  14 ++
 .../custom_ops/test_sdpa_with_kv_cache.py     | 121 ++++++------------
 2 files changed, 53 insertions(+), 82 deletions(-)

diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index 01edddb069c..c4f5fa3cf10 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -401,6 +401,20 @@ void cpu_flash_attention(
             qk_data,
             kvBlockSize);
         // Apply causal mask, fill unused with -inf
+        // Apply causal mask, fill unused, i.e. future values, with -inf
+        // Say you have q @ k.T size = [16, 32]
+        // With qblock size = 4, say you are processing
+        // q seq len dim = 8:11.
+        // Say kvSplitSize = 4
+        // Then for causal mask, the entries that need to be
+        // ignored are
+        // [8, 9:31], [9, 10:31], [10, 11:31], [11, 12:31]
+        // The following condition says that num_keys = 8 + 4 = 12:
+        // (num_keys - n) <= kvSplitSize
+        // num_keys <= n + kvSplitSize
+        // If n + kvSplitSize is at least 12, then some entries
+        // of this kv block need to be masked out. In our example
+        // n = 8 qualifies.
         if (is_causal && num_keys - n <= kvSplitSize) {
           for (int32_t row = 0; row < qBlockSize; ++row) {
             int64_t last_col = m + row - n;
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
index abf3abc0284..bf857204e1c 100644
--- a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -12,54 +12,52 @@ from .sdpa_with_kv_cache import custom_ops_lib  # noqa
 
 
+def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, mask, start_pos):
+    seq_len = q.size(1)
+    attn_mask = mask[start_pos : start_pos + seq_len, :]
+    attn_mask = attn_mask[:, : start_pos + seq_len]
+    q = q.transpose(1, 2)
+    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
+    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
+    sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
+    sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]
+    sliced_k_cache = sliced_k_cache.transpose(1, 2)
+    sliced_v_cache = sliced_v_cache.transpose(1, 2)
+
+    num_heads_q = q.size(1)
+    num_heads_kv = sliced_k_cache.size(1)
+    if num_heads_q != num_heads_kv:
+        assert (
+            num_heads_q % num_heads_kv == 0
+        ), f"{num_heads_q} not divisible by {num_heads_kv}"
+    n_reps = num_heads_q // num_heads_kv
+    if n_reps > 1:
+        sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
+        sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
+    out = F.scaled_dot_product_attention(
+        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+    )
+    out = out.transpose(1, 2)
+    return out
+
+
 class SDPATest(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(42)
-        self.k_cache = torch.zeros((1, 5, 8, 4))
-        self.v_cache = torch.zeros((1, 5, 8, 4))
+        self.k_cache = torch.zeros((1, 10, 8, 4))
+        self.v_cache = torch.zeros((1, 10, 8, 4))
         self.mask = torch.full(
-            (5, 5),
+            (10, 10),
             float("-inf"),
         )
         self.mask = torch.triu(self.mask, diagonal=1)
 
-    def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
print(f"at start_pos:{start_pos}") - print(q) - print(k) - print(v) - attn_mask = mask[start_pos].view((1, -1)) - attn_mask = attn_mask[:, : start_pos + 1] - q = q.transpose(1, 2) - k_cache[:, start_pos] = k - v_cache[:, start_pos] = v - sliced_k_cache = k_cache[:, : start_pos + 1, :, :] - sliced_v_cache = v_cache[:, : start_pos + 1, :, :] - sliced_k_cache = sliced_k_cache.transpose(1, 2) - sliced_v_cache = sliced_v_cache.transpose(1, 2) - # print(sliced_k_cache.size()) - # print(torch.matmul(q, sliced_k_cache.transpose(2, 3))) - # print("q @ k") - # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3)) - # qk_softmax = torch.softmax(qk, dim=-1) - # qkv = torch.matmul(qk_softmax, sliced_v_cache) - # print(qk) - # print(qk_softmax) - # print(qkv) - out = F.scaled_dot_product_attention( - q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask - ) - out = out.transpose(1, 2) - print(out) - print(f"-------- start pos {start_pos} done -----") - return out - def test_sdpa_with_cache_no_mqa_1(self): q = torch.rand((1, 1, 8, 4)) k = torch.rand((1, 1, 8, 4)) v = torch.rand((1, 1, 8, 4)) - ref_output = self._sdpa_with_kv_cache_ref( + ref_output = _sdpa_with_kv_cache_ref( q, k, v, self.k_cache, self.v_cache, self.mask, 0 ) op_output = torch.ops.llama.sdpa_with_kv_cache( @@ -72,7 +70,7 @@ def test_sdpa_with_cache_no_mqa_2(self): k = torch.rand((1, 1, 8, 4)) v = torch.rand((1, 1, 8, 4)) - ref_output = self._sdpa_with_kv_cache_ref( + ref_output = _sdpa_with_kv_cache_ref( q, k, v, self.k_cache, self.v_cache, self.mask, 1 ) op_output = torch.ops.llama.sdpa_with_kv_cache( @@ -85,7 +83,7 @@ def test_sdpa_with_cache_no_mqa_3(self): k = torch.rand((1, 1, 8, 4)) v = torch.rand((1, 1, 8, 4)) - ref_output = self._sdpa_with_kv_cache_ref( + ref_output = _sdpa_with_kv_cache_ref( q, k, v, self.k_cache, self.v_cache, self.mask, 2 ) op_output = torch.ops.llama.sdpa_with_kv_cache( @@ -98,7 +96,7 @@ def test_sdpa_with_cache_no_mqa_4(self): k = torch.rand((1, 1, 8, 4)) v = torch.rand((1, 1, 8, 4)) - ref_output = self._sdpa_with_kv_cache_ref( + ref_output = _sdpa_with_kv_cache_ref( q, k, v, self.k_cache, self.v_cache, self.mask, 3 ) op_output = torch.ops.llama.sdpa_with_kv_cache( @@ -124,52 +122,11 @@ def setUp(self): ) self.mask = torch.triu(self.mask, diagonal=1) - def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos): - print(f"at start_pos:{start_pos}") - print(q) - print(k) - print(v) - attn_mask = mask[start_pos].view((1, -1)) - attn_mask = attn_mask[:, : start_pos + 1] - q = q.transpose(1, 2) - k_cache[:, start_pos] = k - v_cache[:, start_pos] = v - sliced_k_cache = k_cache[:, : start_pos + 1, :, :] - sliced_v_cache = v_cache[:, : start_pos + 1, :, :] - sliced_k_cache = sliced_k_cache.transpose(1, 2) - sliced_v_cache = sliced_v_cache.transpose(1, 2) - # print(sliced_k_cache.size()) - # print(torch.matmul(q, sliced_k_cache.transpose(2, 3))) - # print("q @ k") - # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3)) - # qk_softmax = torch.softmax(qk, dim=-1) - # qkv = torch.matmul(qk_softmax, sliced_v_cache) - # print(qk) - # print(qk_softmax) - # print(qkv) - num_heads_q = q.size(1) - num_heads_kv = sliced_k_cache.size(1) - if num_heads_q != num_heads_kv: - assert ( - num_heads_q % num_heads_kv == 0 - ), f"{num_heads_q} not divisible by {num_heads_kv}" - n_reps = num_heads_q // num_heads_kv - if n_reps > 1: - sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1) - sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1) - out = F.scaled_dot_product_attention( - q, 
-        )
-        out = out.transpose(1, 2)
-        print(out)
-        print(f"-------- start pos {start_pos} done -----")
-        return out
-
     def test_sdpa_with_cache_mqa_1(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
-        ref_output = self._sdpa_with_kv_cache_ref(
+        ref_output = _sdpa_with_kv_cache_ref(
             q, k, v, self.k_cache, self.v_cache, self.mask, 0
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -181,7 +138,7 @@ def test_sdpa_with_cache_mqa_2(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
-        ref_output = self._sdpa_with_kv_cache_ref(
+        ref_output = _sdpa_with_kv_cache_ref(
             q, k, v, self.k_cache, self.v_cache, self.mask, 1
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(
@@ -196,7 +153,7 @@ def test_sdpa_with_cache_mqa_3(self):
         q = torch.rand((1, 1, self.n_heads_q, 4))
         k = torch.rand((1, 1, self.n_heads_kv, 4))
         v = torch.rand((1, 1, self.n_heads_kv, 4))
-        ref_output = self._sdpa_with_kv_cache_ref(
+        ref_output = _sdpa_with_kv_cache_ref(
             q, k, v, self.k_cache, self.v_cache, self.mask, 1
         )
         op_output = torch.ops.llama.sdpa_with_kv_cache(

From be02bb6a5d8603b25f6f25b902a56cefa5fcb219 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 26 Jun 2024 14:15:38 -0700
Subject: [PATCH 2/3] Update on "Refactor sdpa_with_kv_cache tests"

Refactors so as to
- Test attention mask
- Test is_causal
- Enable dynamic shape tests

Differential Revision: [D58874166](https://our.internmc.facebook.com/intern/diff/D58874166/)

[ghstack-poisoned]
---
 examples/models/llama2/custom_ops/op_sdpa.cpp |  1 -
 examples/models/llama2/runner/runner.cpp      | 19 +++++++++++++++++++
 examples/models/llama2/runner/runner.h        |  1 +
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index c4f5fa3cf10..b2cd7fb0d0a 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -400,7 +400,6 @@ void cpu_flash_attention(
             static_cast<accum_t>(0),
             qk_data,
             kvBlockSize);
-        // Apply causal mask, fill unused with -inf
         // Apply causal mask, fill unused, i.e. future values, with -inf
         // Say you have q @ k.T size = [16, 32]
         // With qblock size = 4, say you are processing
diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 559c71fd81a..094a8b283ca 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -138,6 +138,24 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
+void Runner::warmup() {
+  std::vector<int64_t> token_data = {1}; // allocate space for the tokens
+  ManagedTensor tokens_managed(
+      token_data.data(), {1, 1}, ScalarType::Long);
+  std::vector<int64_t> start_pos_data = {0}; // allocate space for the start pos
+  ManagedTensor start_pos_managed(
+      start_pos_data.data(), {1}, ScalarType::Long);
+  std::vector<EValue> inputs;
+  auto tokens_tensor = tokens_managed.get_aliasing_tensor();
+  auto start_pos = start_pos_managed.get_aliasing_tensor();
+
+  // inputs: [tokens, start_pos]
+  inputs.push_back(tokens_tensor);
+  inputs.push_back(start_pos);
+
+  Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
+}
+
 // Given an input token, set up the inputs for the model and execute a single
 // step, returning the logits tensor.
 Result<exec_aten::Tensor> Runner::run_model_step(
@@ -222,6 +240,7 @@ Error Runner::generate(
 
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
+  warmup();
   stats_.inference_start_ms = util::time_in_ms();
 
   shouldStop_ = false;
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 4e200d5e6ca..624998b6258 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -73,6 +73,7 @@ class Runner {
   template <typename T>
   int32_t logitsToToken(
       const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
+  void warmup();
   Result<exec_aten::Tensor> run_model_step(
       int64_t input_token,
       ManagedTensor& tokens,

From 76ac850336d08d3348c05e8160c924d7f451d11b Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 26 Jun 2024 14:46:41 -0700
Subject: [PATCH 3/3] Update on "Refactor sdpa_with_kv_cache tests"

Refactors so as to
- Test attention mask
- Test is_causal
- Enable dynamic shape tests

Differential Revision: [D58874166](https://our.internmc.facebook.com/intern/diff/D58874166/)

[ghstack-poisoned]
---
 examples/models/llama2/runner/runner.cpp | 19 -------------------
 examples/models/llama2/runner/runner.h   |  1 -
 2 files changed, 20 deletions(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 094a8b283ca..559c71fd81a 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -138,24 +138,6 @@ int32_t Runner::logitsToToken(
   return sampler_->sample(logits_last);
 }
 
-void Runner::warmup() {
-  std::vector<int64_t> token_data = {1}; // allocate space for the tokens
-  ManagedTensor tokens_managed(
-      token_data.data(), {1, 1}, ScalarType::Long);
-  std::vector<int64_t> start_pos_data = {0}; // allocate space for the start pos
-  ManagedTensor start_pos_managed(
-      start_pos_data.data(), {1}, ScalarType::Long);
-  std::vector<EValue> inputs;
-  auto tokens_tensor = tokens_managed.get_aliasing_tensor();
-  auto start_pos = start_pos_managed.get_aliasing_tensor();
-
-  // inputs: [tokens, start_pos]
-  inputs.push_back(tokens_tensor);
-  inputs.push_back(start_pos);
-
-  Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
-}
-
 // Given an input token, set up the inputs for the model and execute a single
 // step, returning the logits tensor.
 Result<exec_aten::Tensor> Runner::run_model_step(
@@ -240,7 +222,6 @@ Error Runner::generate(
 
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
-  warmup();
   stats_.inference_start_ms = util::time_in_ms();
 
   shouldStop_ = false;
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 624998b6258..4e200d5e6ca 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -73,7 +73,6 @@ class Runner {
   template <typename T>
   int32_t logitsToToken(
       const exec_aten::Tensor& logits_tensor, int64_t pos, T _);
-  void warmup();
   Result<exec_aten::Tensor> run_model_step(
       int64_t input_token,
       ManagedTensor& tokens,
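
For reference, the causal-mask block condition documented in op_sdpa.cpp above can be checked with a minimal Python sketch (not part of the patches; the numbers follow the comment's example of q @ k.T of size [16, 32], qBlockSize = kvSplitSize = 4, and the q block starting at m = 8):

    qBlockSize = 4
    kvSplitSize = 4
    kvSize = 32
    m = 8  # first q position in the current q block (rows 8..11)
    # Keys beyond m + qBlockSize - 1 are never attended to under causal
    # masking, so the kernel only iterates kv blocks up to num_keys.
    num_keys = min(m + qBlockSize, kvSize)  # 12

    # Walk the kv blocks the same way the kernel does and report which
    # ones satisfy (num_keys - n) <= kvSplitSize, i.e. need -inf fill.
    for n in range(0, num_keys, kvSplitSize):
        print(n, num_keys - n <= kvSplitSize)
    # Prints: 0 False, 4 False, 8 True -- only the kv block starting at
    # n = 8 (keys 8..11) overlaps the causal diagonal and needs masking.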