diff --git a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py index e1a2f4dd5b7f..c5f57f3376da 100644 --- a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -218,6 +218,9 @@ def forward( # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T] scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale + # ReLU matches the reference fp8_index kernel: T.max(logits, 0) before weighting + scores = torch.nn.functional.relu(scores) + # Weight per head and sum across heads → [B, S, T] index_scores = torch.einsum("bsht,bsh->bst", scores, weights) diff --git a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py index 96ff01d0cc9e..52c56cb9cb18 100644 --- a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py @@ -346,6 +346,9 @@ def forward( # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T] scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale + # ReLU matches the reference fp8_index kernel: T.max(logits, 0) before weighting + scores = torch.nn.functional.relu(scores) + # Weight per head and sum across heads → [B, S, T] index_scores = torch.einsum("bsht,bsh->bst", scores, weights)