From be9d8d828abe1bc8a9012cbeb0e4f363741e0519 Mon Sep 17 00:00:00 2001
From: Harry Hu
Date: Mon, 30 Jun 2025 16:06:21 -0400
Subject: [PATCH] [RELAX] Fix rotary embedding buffer size calculation

* Change head_dim//2 to rotary_dim//2 in LongRope scaling
* Fixes buffer size when rotary_dim differs from head_dim
---
 python/tvm/relax/frontend/nn/llm/position_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index fc82148be1a9..1a1659b29e18 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -493,7 +493,7 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((head_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {