examples/models/llama2/custom_ops/op_sdpa.cpp (1 change: 0 additions & 1 deletion)

@@ -664,7 +664,6 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_CHECK_MSG(seq_length == 1, "seq_length must be 1");
   ET_CHECK_MSG(
       projected_value.size(0) == 1,
       "projected_value must have batch size of 1");
examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py (63 changes: 63 additions & 0 deletions)

@@ -235,6 +235,69 @@ def setUp(self):
self.is_causal = True


class SDPAWithDynamicShape(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.k_cache = torch.zeros((1, 10, 8, 4))
self.v_cache = torch.zeros((1, 10, 8, 4))
self.mask = torch.full(
(10, 10),
float("-inf"),
)
self.mask = torch.triu(self.mask, diagonal=1)
self.use_mask_with_custom_op = False
self.is_causal = False

def test_sdpa_with_cache_dynamic_shape_0(self):
q = torch.rand((1, 4, 8, 4))
k = torch.rand((1, 4, 8, 4))
v = torch.rand((1, 4, 8, 4))
seq_len = q.size(1)
start_pos = 0
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
attn_mask = attn_mask[:, : start_pos + seq_len]

ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
)

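        # Trailing positional args: attn_mask=None, dropout_p=0, is_causal=True.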
op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
)
self.assertTrue(torch.allclose(ref_output, op_output))

def test_sdpa_with_cache_dynamic_shape_2(self):
q = torch.rand((1, 3, 8, 4))
k = torch.rand((1, 3, 8, 4))
v = torch.rand((1, 3, 8, 4))
seq_len = q.size(1)
start_pos = 2
attn_mask = self.mask[start_pos : start_pos + seq_len, :]
attn_mask = attn_mask[:, : start_pos + seq_len]

ref_output = _sdpa_with_kv_cache_ref(
q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
)

op_output = torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
)
self.assertTrue(torch.allclose(ref_output, op_output))
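        # For reference, with start_pos=2 and seq_len=3 the two slices above
        # reduce the 10x10 upper-triangular -inf mask to mask[2:5, :5]:
        #   [[0., 0., 0., -inf, -inf],
        #    [0., 0., 0.,   0., -inf],
        #    [0., 0., 0.,   0.,   0.]]
        # i.e. each new query attends to the 2 cached positions plus the new
        # keys up to and including itself.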

    @unittest.skip("Expected to fail (inputs exceed the cache capacity), but the runtime does not bubble the failure up.")
def test_sdpa_with_cache_dynamic_shape_4(self):
q = torch.rand((1, 11, 8, 4))
k = torch.rand((1, 11, 8, 4))
v = torch.rand((1, 11, 8, 4))
seq_len = q.size(1)
start_pos = 4

torch.ops.llama.sdpa_with_kv_cache(
q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
)


class SDPATestWithMQA(unittest.TestCase):

def setup_caches(self):
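Both active tests compare their output against _sdpa_with_kv_cache_ref, a reference helper defined earlier in this file and not shown in the hunk above. As a minimal sketch of what such a reference plausibly does, assuming it writes the new keys/values into the caches in place and then runs stock SDPA over the valid prefix (the body below is an assumption, not the file's actual helper):

import torch
import torch.nn.functional as F

def _sdpa_with_kv_cache_ref_sketch(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    # Write the new keys/values into cache slots [start_pos, start_pos + seq_len).
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # Attend over every position cached so far (0 .. start_pos + seq_len - 1).
    full_k = k_cache[:, : start_pos + seq_len]
    full_v = v_cache[:, : start_pos + seq_len]
    # Inputs are (batch, seq, heads, dim); SDPA expects (batch, heads, seq, dim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2),
        full_k.transpose(1, 2),
        full_v.transpose(1, 2),
        attn_mask=attn_mask,
    )
    return out.transpose(1, 2)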