From d9d1d0fc451df4787b6f9201761b3a5a55070cc1 Mon Sep 17 00:00:00 2001
From: gambletan
Date: Sat, 14 Mar 2026 11:43:39 +0800
Subject: [PATCH] fix: add missing ReLU in GLM-MOE-DSA indexer scoring

The DSA indexer was missing a ReLU activation on the per-head
dot-product scores before the weighted sum across heads. The reference
DeepSeek V3.2 implementation applies ReLU inside the fp8_index kernel
via `T.max(logits, 0)` before multiplying by head weights. Without this,
negative attention scores incorrectly contribute to the index scoring,
which can affect top-k token selection for sparse attention.

Fixes huggingface/transformers#44360

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py | 3 +++
 src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py  | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py
index e1a2f4dd5b7f..c5f57f3376da 100644
--- a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py
+++ b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py
@@ -218,6 +218,9 @@ def forward(
         # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T]
         scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale
 
+        # ReLU matches the reference fp8_index kernel: T.max(logits, 0) before weighting
+        scores = torch.nn.functional.relu(scores)
+
         # Weight per head and sum across heads → [B, S, T]
         index_scores = torch.einsum("bsht,bsh->bst", scores, weights)
 
diff --git a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py
index 96ff01d0cc9e..52c56cb9cb18 100644
--- a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py
+++ b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py
@@ -346,6 +346,9 @@ def forward(
         # q·k^T per head: [B, S, H, D] @ [B, T, D]^T → [B, S, H, T]
         scores = torch.einsum("bshd,btd->bsht", q.float(), k_cached.float()) * self.softmax_scale
 
+        # ReLU matches the reference fp8_index kernel: T.max(logits, 0) before weighting
+        scores = torch.nn.functional.relu(scores)
+
         # Weight per head and sum across heads → [B, S, T]
         index_scores = torch.einsum("bsht,bsh->bst", scores, weights)
 