diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index 4413dba642b8..9fafd480a956 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -32,6 +32,8 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 @@ -50,3 +52,7 @@ def test_rotary_emb(): # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + test_rotary_emb()