diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index a7623a236..4df7da938 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -83,7 +83,7 @@ def test_correctness( cos, sin = rotary_emb(k1, pos_ids) # validate forward pass - hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) + hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin) tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin) assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol)