diff --git a/README.md b/README.md
index bb774be8..6c2d41a4 100644
--- a/README.md
+++ b/README.md
@@ -94,7 +94,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
Sparse Attention
- - Under Development
+ - Stem (Sparse Token Estimation Module)
diff --git a/README_cn.md b/README_cn.md
index 6d16e363..ae1fc891 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -95,7 +95,7 @@
稀疏注意力
diff --git a/angelslim/compressor/sparsity/__init__.py b/angelslim/compressor/sparsity/__init__.py
index eca77b00..eaa4c7a2 100644
--- a/angelslim/compressor/sparsity/__init__.py
+++ b/angelslim/compressor/sparsity/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from .stem import StemInference # noqa: F401
+
+__all__ = ["StemInference"]
diff --git a/angelslim/compressor/sparsity/stem/__init__.py b/angelslim/compressor/sparsity/stem/__init__.py
new file mode 100644
index 00000000..5ab6b1ed
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Stem — Sparse Token Estimation Module for long-context LLM inference.
+
+Public API:
+ StemInference: Callable that patches a HuggingFace model to use Stem
+ sparse attention during the prefill stage.
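+
+Usage sketch (model construction elided)::
+
+    from angelslim.compressor.sparsity.stem import StemInference
+
+    model = StemInference(attn_kwargs={"backend": "torch"})(model)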
+"""
+
+from .stem import StemInference
+
+__all__ = ["StemInference"]
diff --git a/angelslim/compressor/sparsity/stem/backends/__init__.py b/angelslim/compressor/sparsity/stem/backends/__init__.py
new file mode 100644
index 00000000..b152ab84
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/backends/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Stem backend implementations (torch / HPC)."""
+
+from .dispatcher import stem_forward
+
+__all__ = ["stem_forward"]
diff --git a/angelslim/compressor/sparsity/stem/backends/dispatcher.py b/angelslim/compressor/sparsity/stem/backends/dispatcher.py
new file mode 100644
index 00000000..93582489
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/backends/dispatcher.py
@@ -0,0 +1,59 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Backend dispatcher: routes Stem prefill to the correct implementation."""
+
+from __future__ import annotations
+
+import torch
+
+from .torch_impl import stem_forward_torch
+
+
+def stem_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ prefill_kwargs: dict,
+) -> torch.Tensor:
+ """Dispatch a Stem prefill call to the appropriate backend.
+
+ Args:
+ query_states: Query tensor of shape ``(B, H_q, L_q, D)``.
+ key_states: Key tensor of shape ``(B, H_kv, L_kv, D)``.
+ value_states: Value tensor of shape ``(B, H_kv, L_kv, D)``.
+ prefill_kwargs: Must contain ``"attn_forward_config"`` (with a
+ ``"backend"`` key) and ``"layer_idx"``.
+
+ Returns:
+ Attention output tensor of shape ``(B, H_q, L_q, D)``.
+
+ Raises:
+ ValueError: If the requested backend is not ``"torch"`` or ``"hpc"``.
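+
+    Example:
+        A minimal call, assuming ``q``, ``k`` and ``v`` are ``(B, H, L, D)``
+        CUDA tensors::
+
+            out = stem_forward(q, k, v, {
+                "attn_forward_config": {"backend": "torch", "block_size": 128},
+                "layer_idx": 0,
+            })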
+ """
+ config = prefill_kwargs["attn_forward_config"]
+ backend = config.get("backend", "torch")
+
+ if backend == "torch":
+ return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs)
+
+ if backend == "hpc":
+ # Lazy import to avoid hard dependency on the ``hpc`` C++ extension
+ # when only the pure-torch path is needed.
+ from .hpc_impl import stem_forward_hpc
+
+ return stem_forward_hpc(query_states, key_states, value_states, prefill_kwargs)
+
+ raise ValueError(f"Unknown stem backend: {backend!r}")
diff --git a/angelslim/compressor/sparsity/stem/backends/hpc_impl.py b/angelslim/compressor/sparsity/stem/backends/hpc_impl.py
new file mode 100644
index 00000000..04dba8c4
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/backends/hpc_impl.py
@@ -0,0 +1,541 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""HPC (C++ extension) backend for Stem sparse prefill.
+
+Provides three execution paths:
+
+* **bf16 dense** — calls ``hpc.attention_prefill_bf16`` directly.
+* **fp8 varlen** — Triton scoring → ``hpc.stem_tpd`` mask →
+ ``hpc.attention_blocksparse_prefill_fp8``.
+* **fp8 paged** — ``hpc.stem_paged_kv`` mask →
+ ``hpc.attention_with_kvcache_blocksparse_prefill_fp8``.
+"""
+
+from __future__ import annotations
+
+import hpc
+import torch
+
+from .torch_impl import _compute_triton_block_logits, stem_forward_torch
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+FP8_DTYPE = torch.float8_e4m3fn
+HPC_FP8_BLOCK_SIZE = 128
+HPC_INITIAL_BLOCKS = 4
+HPC_WINDOW_SIZE = 4
+
+
+# ---------------------------------------------------------------------------
+# Tensor packing / quantisation helpers
+# ---------------------------------------------------------------------------
+
+
+def _pack_bhld_to_varlen(x: torch.Tensor) -> torch.Tensor:
+ """Reshape ``(B, H, L, D)`` → ``(B*L, H, D)`` for variable-length HPC kernels."""
+ B, H, L, D = x.shape
+ return x.transpose(1, 2).reshape(B * L, H, D).contiguous()
+
+
+def _uniform_seq_metadata(
+ batch_size: int,
+ seq_len: int,
+ device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Return ``(seqlens, cu_seqlens)`` for a uniform-length batch.
+
+ Parameters
+ ----------
+ batch_size : int
+ seq_len : int
+ device : torch.device
+
+ Returns
+ -------
+ seqlens : torch.Tensor — ``(batch_size,)``, all equal to *seq_len*.
+ cu_seqlens : torch.Tensor — ``(batch_size + 1,)``, cumulative offsets.
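+
+    For example, ``batch_size=2`` and ``seq_len=4`` give ``seqlens = [4, 4]``
+    and ``cu_seqlens = [0, 4, 8]``.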
+ """
+ seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device)
+ cu_seqlens = torch.arange(
+ 0,
+ (batch_size + 1) * seq_len,
+ step=seq_len,
+ dtype=torch.int32,
+ device=device,
+ )
+ return seqlens, cu_seqlens
+
+
+def _quantize_per_tensor_fp8(
+ x: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Symmetric per-tensor FP8 (E4M3) quantisation.
+
+ Returns
+ -------
+ x_fp8 : torch.Tensor — quantised tensor (same shape, dtype ``float8_e4m3fn``).
+ scale : torch.Tensor — scalar scale factor, shape ``(1,)``.
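+
+    Dequantisation recovers the input approximately as ``x ≈ x_fp8.float() * scale``.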
+ """
+ fp8_max = torch.finfo(FP8_DTYPE).max
+ scale = x.abs().amax().float().clamp(min=1e-12) / fp8_max
+ x_fp8 = torch.clamp(x.float() / scale, min=-fp8_max, max=fp8_max).to(FP8_DTYPE)
+ return x_fp8.contiguous(), scale.view(1)
+
+
+def _quantize_query_for_paged_fp8(
+ query_states: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Quantise the query tensor for the paged FP8 attention kernel.
+
+ Uses HPC's per-token-group quantisation (``group_size = head_dim``).
+
+ Args:
+ query_states: Query tensor of shape ``(B, H, L, D)``.
+
+ Returns:
+ A tuple of ``(q_fp8, q_scale)`` where *q_fp8* has shape
+ ``(B*L, H, D)`` in ``float8_e4m3fn`` and *q_scale* has shape
+ ``(B, H, L_padded)`` with per-token scales.
+
+ Raises:
+ ValueError: If ``head_dim != 128``.
+ """
+ B, H, L, D = query_states.shape
+ if D != HPC_FP8_BLOCK_SIZE:
+ raise ValueError(f"HPC fp8 query quant only supports head_dim=128, got {D}.")
+
+ q_rows = query_states.transpose(1, 2).reshape(B * L, H * D).contiguous()
+ q_fp8_rows, q_scale_rows = hpc.quant.per_token_group_fp8_quant(q_rows, group_size=D)
+
+ q_fp8 = q_fp8_rows.view(B * L, H, D).contiguous()
+ q_scale = q_scale_rows.view(B, L, H).permute(0, 2, 1).contiguous()
+
+ # Pad the sequence dimension to a multiple of the block size.
+ padded_L = ((L + HPC_FP8_BLOCK_SIZE - 1) // HPC_FP8_BLOCK_SIZE) * HPC_FP8_BLOCK_SIZE
+ if padded_L != L:
+ q_scale = torch.nn.functional.pad(q_scale, (0, padded_L - L))
+ return q_fp8, q_scale
+
+
+def _pack_paged_cache(
+ x_varlen: torch.Tensor,
+ batch_size: int,
+ seq_len: int,
+ block_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Pack a varlen KV tensor into paged-cache layout.
+
+ Parameters
+ ----------
+ x_varlen : torch.Tensor — ``(B*L, H, D)``
+ batch_size, seq_len, block_size : int
+
+ Returns
+ -------
+ cache : torch.Tensor — ``(total_blocks, block_size, H, D)``
+ kv_indices : torch.Tensor — ``(B, blocks_per_seq)`` block-table indices.
+ """
+ _, H, D = x_varlen.shape
+ padded_L = ((seq_len + block_size - 1) // block_size) * block_size
+ blocks_per_seq = padded_L // block_size
+
+ x_seq = x_varlen.view(batch_size, seq_len, H, D)
+ if padded_L > seq_len:
+ pad = torch.zeros(
+ (batch_size, padded_L - seq_len, H, D),
+ dtype=x_varlen.dtype,
+ device=x_varlen.device,
+ )
+ x_seq = torch.cat([x_seq, pad], dim=1)
+
+ cache = (
+ x_seq.view(batch_size, blocks_per_seq, block_size, H, D)
+ .reshape(batch_size * blocks_per_seq, block_size, H, D)
+ .contiguous()
+ )
+ kv_indices = torch.arange(
+ batch_size * blocks_per_seq, dtype=torch.int32, device=x_varlen.device
+ ).view(batch_size, blocks_per_seq)
+ return cache, kv_indices
+
+
+# ---------------------------------------------------------------------------
+# Head-repeat helper (GQA → full Q heads)
+# ---------------------------------------------------------------------------
+
+
+def _repeat_to_q_heads(x: torch.Tensor, num_q_heads: int) -> torch.Tensor:
+ """Repeat KV heads to match the query head count (GQA support).
+
+ Parameters
+ ----------
+ x : torch.Tensor — ``(B, H_kv, L, D)``
+ num_q_heads : int — target number of heads (``H_q``).
+
+ Returns
+ -------
+ torch.Tensor — ``(B, H_q, L, D)``
+ """
+ B, H_kv, L, D = x.shape
+ if num_q_heads == H_kv:
+ return x
+ if num_q_heads % H_kv != 0:
+ raise ValueError(f"Cannot repeat kv heads from {H_kv} to {num_q_heads}.")
+ n_rep = num_q_heads // H_kv
+ x = x[:, :, None, :, :].expand(B, H_kv, n_rep, L, D)
+ return x.reshape(B, num_q_heads, L, D)
+
+
+# ---------------------------------------------------------------------------
+# Fallback to the pure-torch backend
+# ---------------------------------------------------------------------------
+
+
+def _fallback_to_torch(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ prefill_kwargs: dict,
+) -> torch.Tensor:
+ """Fall back to :func:`stem_forward_torch` after repeating KV heads."""
+ H_q = query_states.shape[1]
+ if key_states.shape[1] != H_q:
+ key_states = _repeat_to_q_heads(key_states, H_q)
+ if value_states.shape[1] != H_q:
+ value_states = _repeat_to_q_heads(value_states, H_q)
+ return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Runtime parameter extraction
+# ---------------------------------------------------------------------------
+
+
+def _stem_runtime_params(config: dict, layer_idx: int) -> dict:
+ """Extract Stem runtime hyper-parameters from the user config dict."""
+ alpha_cfg = config.get("stem_alpha", 1.0)
+ alpha = alpha_cfg[layer_idx] if isinstance(alpha_cfg, (list, tuple)) else alpha_cfg
+ return {
+ "block_size": int(config.get("block_size", HPC_FP8_BLOCK_SIZE)),
+ "stem_stride": int(config.get("hpc_stem_stride", config.get("stride", 16))),
+ "chunk_size": int(config.get("chunk_size", 2048)),
+ "norm": float(config.get("norm", 1.0)),
+ "initial_blocks": int(config.get("initial_blocks", HPC_INITIAL_BLOCKS)),
+ "window_size": int(config.get("window_size", HPC_WINDOW_SIZE)),
+ "lambda_mag": float(config.get("lambda_mag", 0.3)),
+ "alpha": float(alpha),
+ "k_block_num_rate": float(config.get("k_block_num_rate", 0.1)),
+ "k_block_num_bias": int(config.get("k_block_num_bias", 30)),
+ }
+
+
+# ===================================================================== #
+# Execution paths #
+# ===================================================================== #
+
+# --- Path 1: bf16 dense ------------------------------------------------
+
+
+def _run_hpc_bf16_dense(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+) -> torch.Tensor:
+ """BF16 dense prefill via ``hpc.attention_prefill_bf16``."""
+ B, H, Lq, _ = query_states.shape
+ D_v = value_states.shape[-1]
+
+ q_varlen = _pack_bhld_to_varlen(query_states)
+ k_varlen = _pack_bhld_to_varlen(key_states)
+ v_varlen = _pack_bhld_to_varlen(value_states)
+ seqlens_q, cu_seqlens_q = _uniform_seq_metadata(B, Lq, query_states.device)
+
+ output = hpc.attention_prefill_bf16(q_varlen, k_varlen, v_varlen, seqlens_q, cu_seqlens_q, Lq)
+ return output.view(B, Lq, H, D_v).transpose(1, 2).contiguous()
+
+
+# --- Path 2: fp8 varlen ------------------------------------------------
+
+
+def _build_varlen_stem_mask(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ params: dict,
+) -> torch.Tensor:
+ """Build the block-sparse mask using Triton scoring + ``hpc.stem_tpd``."""
+ H_q = query_states.shape[1]
+ if key_states.shape[1] != H_q:
+ key_states = _repeat_to_q_heads(key_states, H_q)
+ if value_states.shape[1] != H_q:
+ value_states = _repeat_to_q_heads(value_states, H_q)
+
+ block_logits = _compute_triton_block_logits(
+ query_states,
+ key_states,
+ value_states,
+ block_size=params["block_size"],
+ stride=params["stem_stride"],
+ chunk_size=params["chunk_size"],
+ norm=params["norm"],
+ causal=True,
+ )
+
+ B, _, Lq, _ = query_states.shape
+ Lkv = key_states.shape[2]
+ q_seq_lens = torch.full((B,), Lq, dtype=torch.int32, device=query_states.device)
+ kv_seq_lens = torch.full((B,), Lkv, dtype=torch.int32, device=query_states.device)
+
+ return hpc.stem_tpd(
+ block_logits,
+ q_seq_lens,
+ kv_seq_lens,
+ params["block_size"],
+ params["alpha"],
+ params["initial_blocks"],
+ params["window_size"],
+ params["k_block_num_rate"],
+ params["k_block_num_bias"],
+ )
+
+
+def _run_hpc_varlen_stem(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ params: dict,
+) -> torch.Tensor:
+ """FP8 varlen sparse prefill via ``hpc.attention_blocksparse_prefill_fp8``."""
+ B, H_q, Lq, _ = query_states.shape
+ _, _, Lkv, D_v = value_states.shape
+
+ q_varlen = _pack_bhld_to_varlen(query_states)
+ k_varlen = _pack_bhld_to_varlen(key_states)
+ v_varlen = _pack_bhld_to_varlen(value_states)
+
+ q_fp8, q_scale = _quantize_per_tensor_fp8(q_varlen)
+ k_fp8, k_scale = _quantize_per_tensor_fp8(k_varlen)
+ v_fp8, v_scale = _quantize_per_tensor_fp8(v_varlen)
+
+ _, cu_q = _uniform_seq_metadata(B, Lq, query_states.device)
+ _, cu_kv = _uniform_seq_metadata(B, Lkv, query_states.device)
+
+ mask = _build_varlen_stem_mask(query_states, key_states, value_states, params)
+
+ output = hpc.attention_blocksparse_prefill_fp8(
+ q_fp8,
+ k_fp8,
+ v_fp8,
+ cu_q,
+ cu_kv,
+ Lq,
+ Lkv,
+ q_scale,
+ k_scale,
+ v_scale,
+ block_mask=mask,
+ )
+ return output.view(B, Lq, H_q, D_v).transpose(1, 2).contiguous()
+
+
+# --- Path 3: fp8 paged -------------------------------------------------
+
+
+def _run_hpc_paged_stem(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ params: dict,
+) -> torch.Tensor:
+ """Paged FP8 prefill path for chunked prefill with existing KV history.
+
+ Key semantics
+ -------------
+ * ``hpc.stem_paged_kv`` expects ``kv_seq_lens`` = **total visible KV length**
+ (history + current Q).
+ * ``hpc.attention_with_kvcache_blocksparse_prefill_fp8`` expects
+ ``seqlens_kvcache`` = **history length only** (tokens *before* the
+ current Q chunk).
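+
+    For instance, a 2048-token query chunk on top of 4096 cached tokens gives
+    ``kv_seq_lens = 6144`` and ``seqlens_kvcache = 4096``.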
+ """
+ B, H_q, Lq, _ = query_states.shape
+ _, _, Lkv, D_v = value_states.shape
+ if Lkv < Lq:
+ raise ValueError(
+ f"Paged HPC prefill requires kv_len >= q_len, " f"got kv_len={Lkv}, q_len={Lq}."
+ )
+
+ history_len = Lkv - Lq
+
+ # --- FP8 quantisation -------------------------------------------------
+ q_fp8, q_scale = _quantize_query_for_paged_fp8(query_states)
+ k_varlen = _pack_bhld_to_varlen(key_states)
+ v_varlen = _pack_bhld_to_varlen(value_states)
+ k_fp8, k_scale = _quantize_per_tensor_fp8(k_varlen)
+ v_fp8, v_scale = _quantize_per_tensor_fp8(v_varlen)
+ kcache, kv_indices = _pack_paged_cache(k_fp8, B, Lkv, params["block_size"])
+ vcache, _ = _pack_paged_cache(v_fp8, B, Lkv, params["block_size"])
+
+ visible_kv_lens = torch.full((B,), Lkv, dtype=torch.int32, device=query_states.device)
+ history_kv_lens = torch.full((B,), history_len, dtype=torch.int32, device=query_states.device)
+ _, cu_q = _uniform_seq_metadata(B, Lq, query_states.device)
+
+ # --- Step 1: generate sparse mask -------------------------------------
+ mask = hpc.stem_paged_kv(
+ q_fp8,
+ kcache,
+ vcache,
+ q_scale,
+ k_scale,
+ v_scale,
+ kv_indices,
+ cu_q,
+ visible_kv_lens,
+ lambda_mag=params["lambda_mag"],
+ alpha=params["alpha"],
+ stem_block_size=params["block_size"],
+ stem_stride=params["stem_stride"],
+ causal=True,
+ initial_blocks=params["initial_blocks"],
+ window_size=params["window_size"],
+ k_block_num_rate=params["k_block_num_rate"],
+ k_block_num_bias=params["k_block_num_bias"],
+ )
+
+ # --- Step 2: block-sparse attention -----------------------------------
+ output = hpc.attention_with_kvcache_blocksparse_prefill_fp8(
+ q_fp8,
+ kcache,
+ vcache,
+ q_scale,
+ k_scale,
+ v_scale,
+ cu_q,
+ kv_indices,
+ history_kv_lens, # history length (not total visible!)
+ Lq, # max_seqlens_q
+ mask,
+ )
+ return output.view(B, Lq, H_q, D_v).transpose(1, 2).contiguous()
+
+
+# ===================================================================== #
+# Top-level HPC dispatcher #
+# ===================================================================== #
+
+
+def stem_forward_hpc(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ prefill_kwargs: dict,
+) -> torch.Tensor:
+ """HPC backend entry point — selects among bf16 / fp8-varlen / fp8-paged.
+
+ Falls back to :func:`stem_forward_torch` when the hardware or
+ configuration is not compatible with the requested HPC path.
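+
+    The execution path is selected from ``attn_forward_config``, e.g.::
+
+        {"backend": "hpc", "hpc_dtype": "fp8", "hpc_fp8_path": "varlen",
+         "hpc_strict": False}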
+ """
+ config = prefill_kwargs["attn_forward_config"]
+ layer_idx = prefill_kwargs["layer_idx"]
+ strict_hpc = config.get("hpc_strict", False)
+ hpc_dtype = config.get("hpc_dtype", "bf16")
+ fp8_path = config.get("hpc_fp8_path", "varlen")
+
+ # --- Guard: CUDA required ---------------------------------------------
+ if query_states.device.type != "cuda":
+ if strict_hpc:
+ raise RuntimeError("HPC stem backend requires CUDA tensors.")
+ if layer_idx == 0:
+ print("[Stem][HPC] CUDA tensors are required; falling back to torch backend.")
+ return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs)
+
+ params = _stem_runtime_params(config, layer_idx)
+
+ # --- BF16 dense path --------------------------------------------------
+ if hpc_dtype == "bf16":
+ try:
+ if layer_idx == 0:
+ print("[Stem][HPC] using bf16 dense prefill path.")
+ return _run_hpc_bf16_dense(query_states, key_states, value_states)
+ except Exception as exc:
+ if strict_hpc:
+ raise RuntimeError(f"HPC bf16 backend failed: {exc}") from exc
+ if layer_idx == 0:
+ print(
+ f"[Stem][HPC] bf16 dense path failed ({exc}); falling back to torch backend."
+ )
+ return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs)
+
+ # --- FP8 sparse path: validate dimensions -----------------------------
+ dim_qk = query_states.shape[-1]
+ dim_v = value_states.shape[-1]
+ block_size = params["block_size"]
+
+ if block_size != HPC_FP8_BLOCK_SIZE:
+ if strict_hpc:
+ raise RuntimeError(
+ f"HPC fp8 sparse prefill only supports block_size=128, got {block_size}."
+ )
+ if layer_idx == 0:
+ print(
+ f"[Stem][HPC] fp8 sparse prefill only supports block_size=128, "
+ f"got {block_size}; falling back to torch backend."
+ )
+ return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs)
+
+ if dim_qk != 128 or dim_v != 128:
+ if strict_hpc:
+ raise RuntimeError(f"Unsupported HPC fp8 head dims: dim_qk={dim_qk}, dim_v={dim_v}.")
+ if layer_idx == 0:
+ print(
+ f"[Stem][HPC] unsupported fp8 head dims dim_qk={dim_qk}, "
+ f"dim_v={dim_v}; falling back to torch backend."
+ )
+ return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs)
+
+ # --- Execute FP8 path -------------------------------------------------
+ try:
+ if fp8_path == "paged":
+ if key_states.shape[2] == query_states.shape[2]:
+ # First prefill chunk (no KV history). The paged attention
+ # kernel needs seqlens_kvcache > 0; use varlen instead.
+ if layer_idx == 0:
+ print(
+ "[Stem][HPC] first prefill chunk (q_len == kv_len, no history); "
+ "using varlen fp8 path."
+ )
+ return _run_hpc_varlen_stem(query_states, key_states, value_states, params)
+ if layer_idx == 0:
+ print("[Stem][HPC] using paged fp8 prefill path with stem_paged_kv mask.")
+ return _run_hpc_paged_stem(query_states, key_states, value_states, params)
+
+ if fp8_path == "varlen":
+ if layer_idx == 0:
+ print("[Stem][HPC] using varlen fp8 prefill path with hpc tpd mask.")
+ return _run_hpc_varlen_stem(query_states, key_states, value_states, params)
+
+ raise ValueError(f"Unsupported hpc_fp8_path={fp8_path!r}; expected 'paged' or 'varlen'.")
+
+ except Exception as exc:
+ if strict_hpc:
+ raise RuntimeError(f"HPC stem backend failed: {exc}") from exc
+ if layer_idx == 0:
+ print(
+ f"[Stem][HPC] {fp8_path} fp8 path failed ({exc}); "
+ "falling back to torch backend."
+ )
+ return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs)
diff --git a/angelslim/compressor/sparsity/stem/backends/torch_impl.py b/angelslim/compressor/sparsity/stem/backends/torch_impl.py
new file mode 100644
index 00000000..056fe46d
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/backends/torch_impl.py
@@ -0,0 +1,468 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Pure-PyTorch implementation of the Stem sparse prefill.
+
+This module provides two main entry points:
+
+* :func:`_compute_triton_block_logits` — compute block-level importance
+ scores using a Triton-accelerated strided group GEMM.
+* :func:`stem_forward_torch` — full Stem prefill: score → schedule →
+ top-k mask → block-sparse (or pseudo-sparse) attention.
+"""
+
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+try:
+ from block_sparse_attn import block_sparse_attn_func
+
+ HAS_BLOCK_SPARSE_KERNEL = True
+except ImportError:
+ print(
+ "⚠️ [Stem] 'block_sparse_attn' not found. " "Falling back to pseudo-sparse implementation."
+ )
+ HAS_BLOCK_SPARSE_KERNEL = False
+
+
+# ---------------------------------------------------------------------------
+# Per-layer sparsity schedule
+# ---------------------------------------------------------------------------
+
+# Default per-layer keep-ratio: first 2 layers keep 100%, remaining layers 20%.
+_DEFAULT_LAYER_KEEP_RATIOS: list[float] = [1.0, 1.0] + [0.2] * 36
+
+# Short-sequence thresholds — below these lengths the sparsity schedule is
+# relaxed to avoid losing too much context.
+_SHORT_SEQ_THRESHOLD_FULL = 56 # keep-all threshold
+_SHORT_SEQ_THRESHOLD_LINEAR = 160 # linear blend threshold
+_SHORT_SEQ_LINEAR_RATE = 0.2
+_SHORT_SEQ_LINEAR_BIAS = 30
+_LONG_SEQ_RATE = 0.1
+_LONG_SEQ_BIAS = 30
+
+
+def generate_exact_k_schedule(
+ num_blocks: int,
+ alpha: float,
+ layer_idx: int,
+ device: torch.device,
+ num_heads: int | None = None,
+) -> torch.Tensor:
+ """Generate the per-block top-k budget schedule.
+
+ For each query-block position, the schedule tells how many key-blocks
+ should be kept (before the alpha decay).
+
+ Parameters
+ ----------
+ num_blocks : int
+ Number of query blocks (``Qb``).
+ alpha : float
+ Decay factor applied beyond the initial keep region.
+ layer_idx : int
+ Transformer layer index (controls the base keep ratio).
+ device : torch.device
+ Target device for the returned tensor.
+ num_heads : int | None
+ If given, return a ``(num_heads, num_blocks)`` tensor with
+ per-head schedules; otherwise return ``(num_blocks,)``.
+
+ Returns
+ -------
+ torch.Tensor
+ Budget schedule — ``(num_blocks,)`` or ``(num_heads, num_blocks)``.
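+
+    Example
+    -------
+    For a layer whose keep ratio is below 1, ``num_blocks=200`` and
+    ``alpha=0.5``: the long-sequence budget is ``int(200 * 0.1) + 30 = 50``,
+    so positions 0-49 keep 50 key-blocks each and positions 50-199 decay
+    linearly from 50 down to 25.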
+ """
+    # Reuse the last table entry for layers beyond the default schedule length.
+    keep_ratio = _DEFAULT_LAYER_KEEP_RATIOS[min(layer_idx, len(_DEFAULT_LAYER_KEEP_RATIOS) - 1)]
+ per_head_ratios = [keep_ratio] * (num_heads or 1)
+
+ def _build_single_schedule(k_val: int, n: int) -> torch.Tensor:
+ # Relax budget for short sequences.
+ if k_val != n:
+ if n < _SHORT_SEQ_THRESHOLD_FULL:
+ k_val = n
+ elif n < _SHORT_SEQ_THRESHOLD_LINEAR:
+ k_val = int(n * _SHORT_SEQ_LINEAR_RATE + _SHORT_SEQ_LINEAR_BIAS)
+ else:
+ k_val = int(n * _LONG_SEQ_RATE) + _LONG_SEQ_BIAS
+
+ schedule = torch.full((n,), k_val, dtype=torch.long, device=device)
+ if n > k_val:
+ decay_len = n - k_val
+ k_end = k_val * alpha
+ ideal_vals = torch.linspace(
+ float(k_val),
+ float(k_end),
+ steps=decay_len,
+ dtype=torch.float64,
+ device=device,
+ )
+ schedule[k_val:] = torch.clamp(torch.floor(ideal_vals).long(), min=1, max=k_val)
+ return schedule
+
+ schedules = [
+ _build_single_schedule(int(ratio * num_blocks), num_blocks) for ratio in per_head_ratios
+ ]
+ stacked = torch.stack(schedules, dim=0)
+ return stacked.squeeze(0) if stacked.shape[0] == 1 else stacked
+
+
+# ---------------------------------------------------------------------------
+# Block-level importance scoring
+# ---------------------------------------------------------------------------
+
+
+def _block_downsample(
+ x: torch.Tensor,
+ seq_len: int,
+ num_blocks: int,
+ block_size: int,
+) -> torch.Tensor:
+ """Downsample *x* along the sequence dimension via per-block max-pooling.
+
+ Args:
+ x: Input tensor of shape ``(..., seq_len, D)``.
+ seq_len: Original sequence length.
+ num_blocks: Target block count.
+ block_size: Block size.
+
+ Returns:
+ Tensor of shape ``(..., num_blocks, D)`` — the maximum value in
+ each block.
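+
+    For example, with ``seq_len=5``, ``block_size=2`` and ``num_blocks=3`` the
+    blocks are ``[x0, x1]``, ``[x2, x3]`` and ``[x4, pad]`` (zero-padded)
+    before the per-block max is taken.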
+ """
+ padded_len = num_blocks * block_size
+ if seq_len % block_size != 0:
+ pad = torch.zeros(
+ x.shape[:-2] + (padded_len - seq_len, x.shape[-1]),
+ dtype=x.dtype,
+ device=x.device,
+ )
+ x = torch.cat([x, pad], dim=-2)
+ return x.view(x.shape[:-2] + (num_blocks, block_size, x.shape[-1])).max(dim=-2).values
+
+
+def _compute_triton_block_logits(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ block_size: int,
+ stride: int,
+ chunk_size: int,
+ norm: float,
+ causal: bool,
+) -> torch.Tensor:
+ """Compute block-level importance logits using the Triton strided GEMM.
+
+ The scoring process:
+
+ 1. Pad Q / K / V to a multiple of ``chunk_size``.
+ 2. Compute strided Q·K^T via the Triton kernel (chunk by chunk).
+ 3. Add a value-norm bonus term (log-normalised, ReLU-gated).
+ 4. Reduce to ``(Qb, Kb)`` block logits by averaging over sub-blocks.
+ 5. Apply a causal block mask.
+
+ Parameters
+ ----------
+ query_states : torch.Tensor — ``(B, H, L_q, D)``
+ key_states : torch.Tensor — ``(B, H, L_kv, D)``
+ value_states : torch.Tensor — ``(B, H, L_kv, D)``
+ block_size : int — size of each block (typically 128)
+ stride : int — striding factor for the GEMM
+ chunk_size : int — chunk width for iterating over Q
+ norm : float — additional scaling denominator
+ causal : bool — whether to apply causal masking
+
+ Returns
+ -------
+ torch.Tensor
+ Block logits of shape ``(B, H, Qb, Kb)``.
+ """
+ from ..ops.stem_kernel import flat_group_gemm_fuse_reshape
+
+ B, H, k_len, head_dim = key_states.shape
+ _, _, q_len, _ = query_states.shape
+ assert H == query_states.shape[1], "Q and K must have the same number of heads."
+ dtype = query_states.dtype
+ device = query_states.device
+
+ # --- Step 1: pad to a multiple of chunk_size --------------------------
+ target_seq_len = max(k_len, q_len)
+ target_seq_len = ((target_seq_len + chunk_size - 1) // chunk_size) * chunk_size
+ k_pad = target_seq_len - k_len
+ q_pad = target_seq_len - q_len
+
+ if k_pad > 0:
+        pad_key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0).to(device)
+        value_norm = torch.norm(value_states, p=2, dim=-1, keepdim=True).to(device)
+        pad_value_states = F.pad(value_norm, (0, 0, 0, k_pad), value=0).to(device)
+ else:
+ pad_key_states = key_states
+ value_norm = torch.norm(value_states, p=2, dim=-1, keepdim=True).to(device)
+ pad_value_states = value_norm
+
+ if q_pad > 0:
+        pad_query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0).to(device)
+ else:
+ pad_query_states = query_states
+
+ # --- Derived dimensions -----------------------------------------------
+ reshaped_chunk_size = chunk_size // stride
+ reshaped_block_size = block_size // stride
+
+ pad_q_len = pad_query_states.shape[2]
+ pad_k_len = pad_key_states.shape[2]
+ assert pad_q_len == pad_k_len == target_seq_len
+
+ q_down_len = pad_q_len // stride
+ k_down_len = pad_k_len // stride
+ pad_Qb = pad_q_len // block_size
+ pad_Kb = pad_k_len // block_size
+ chunk_base = (pad_Kb - pad_Qb) * reshaped_block_size
+
+ # --- Step 2: value-norm bonus term ------------------------------------
+ v_down = _block_downsample(pad_value_states, pad_k_len, k_down_len, stride).squeeze(-1)
+ v_log_norm = torch.log(v_down + 1e-6)
+
+ LAMBDA_MAG = 0.2 # magnitude of the value-norm bonus
+
+ valid_len_down = k_len // stride
+ if valid_len_down > 0:
+ valid_v = v_log_norm[:, :, :valid_len_down]
+ v_mean = valid_v.mean(dim=-1, keepdim=True)
+ v_std = valid_v.std(dim=-1, keepdim=True)
+ else:
+ v_mean = v_log_norm.mean(dim=-1, keepdim=True)
+ v_std = v_log_norm.std(dim=-1, keepdim=True)
+
+ v_log_norm = (v_log_norm - v_mean) / (v_std + 1e-6)
+ v_log_norm = F.relu(v_log_norm)
+ v_bonus = LAMBDA_MAG * v_log_norm
+
+ # --- Step 3: chunked strided Q·K^T via Triton -------------------------
+ scores = torch.zeros((B, H, q_down_len, k_down_len), dtype=dtype, device=device)
+ scale = query_states.new_tensor(1.0 / (math.sqrt(head_dim) * stride * norm), dtype=dtype)
+ q_chunk_num = pad_q_len // chunk_size
+
+ for chunk_idx in range(q_chunk_num):
+ chunk_q_start = chunk_idx * reshaped_chunk_size
+ chunk_q_end = chunk_q_start + reshaped_chunk_size
+ q_slice_start = chunk_q_start * stride
+ q_slice_end = chunk_q_end * stride
+
+ attn_weights_slice = flat_group_gemm_fuse_reshape(
+ pad_query_states[:, :, q_slice_start:q_slice_end, :].contiguous(),
+ pad_key_states.contiguous(),
+ stride,
+ chunk_base + chunk_q_start,
+ chunk_base + chunk_q_end,
+ is_causal=causal,
+ )
+
+ attn_scores = attn_weights_slice * scale + v_bonus.unsqueeze(-2)
+ scores[:, :, chunk_q_start:chunk_q_end, :] = attn_scores
+ del attn_weights_slice
+
+ # --- Step 4: reduce to block logits -----------------------------------
+ scores = scores.view(B, H, pad_Qb, reshaped_block_size, pad_Kb, reshaped_block_size)
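+    # mean(dim=3) averages the query sub-blocks; after that reduction the key
+    # sub-block axis sits at dim=4, so the second mean yields (B, H, pad_Qb, pad_Kb).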
+ block_logits = scores.mean(dim=3).mean(dim=4)
+
+ # Trim to the actual (unpadded) block counts.
+ Qb = (q_len + block_size - 1) // block_size
+ Kb = (k_len + block_size - 1) // block_size
+ block_logits = block_logits[:, :, :Qb, :Kb]
+
+ # --- Step 5: causal block mask ----------------------------------------
+ if causal:
+ qb_idx = torch.arange(Qb, device=device).view(1, 1, -1, 1)
+ kb_idx = torch.arange(Kb, device=device).view(1, 1, 1, -1)
+ block_logits = block_logits.masked_fill(kb_idx > qb_idx, float("-inf"))
+
+ return block_logits.to(dtype)
+
+
+# ---------------------------------------------------------------------------
+# Stem prefill — torch backend
+# ---------------------------------------------------------------------------
+
+
+def stem_forward_torch(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ prefill_kwargs: dict,
+) -> torch.Tensor:
+ """Stem sparse-prefill implementation using pure PyTorch + optional Triton.
+
+ Workflow:
+ 1. Compute block-level importance scores via :func:`_compute_triton_block_logits`.
+ 2. Build a per-layer top-k schedule (:func:`generate_exact_k_schedule`).
+ 3. Select the top-k blocks, add initial-blocks and sliding-window blocks.
+ 4. Run block-sparse attention (real kernel if available, else pseudo-sparse).
+
+ Parameters
+ ----------
+ query_states : ``(B, H, L_q, D)``
+ key_states : ``(B, H, L_kv, D)``
+ value_states : ``(B, H, L_kv, D)``
+ prefill_kwargs : dict
+ Contains ``"attn_forward_config"`` and ``"layer_idx"``.
+
+ Returns
+ -------
+ torch.Tensor
+ Attention output — ``(B, H, L_q, D)``.
+ """
+ config = prefill_kwargs["attn_forward_config"]
+ layer_idx = prefill_kwargs["layer_idx"]
+
+ block_size: int = config.get("block_size", 128)
+ config_alpha = config.get("stem_alpha", 1.0)
+
+ B, H, Lq, head_dim = query_states.shape
+ _, _, Tk, _ = key_states.shape
+
+ scaling = head_dim**-0.5
+ Qb = (Lq + block_size - 1) // block_size
+ Kb = (Tk + block_size - 1) // block_size
+
+ stride: int = config.get("stride", 8)
+ chunk_size: int = config.get("chunk_size", 2048)
+ norm: float = config.get("norm", 1.0)
+
+ # --- 1. Block-level scoring -------------------------------------------
+ block_logits = _compute_triton_block_logits(
+ query_states,
+ key_states,
+ value_states,
+ block_size=block_size,
+ stride=stride,
+ chunk_size=chunk_size,
+ norm=norm,
+ causal=True,
+ )
+
+ # --- 2. Per-layer sparsity schedule -----------------------------------
+ mask_block = torch.zeros((B, H, Qb, Kb), device=query_states.device, dtype=torch.bool)
+
+ alpha = config_alpha[layer_idx] if isinstance(config_alpha, (list, tuple)) else config_alpha
+ sched = generate_exact_k_schedule(Qb, alpha, layer_idx, query_states.device, num_heads=H)
+ growth = max(1.0, alpha)
+
+ if sched.dim() == 1:
+ head_needed = torch.clamp(torch.ceil(sched[0].to(torch.float32) * growth), max=Kb)
+ needed_k = int(head_needed.item())
+ budget = sched.view(1, 1, -1, 1)
+ else:
+ if sched.shape[0] != H:
+ raise ValueError("Per-head k_start configuration does not match number of heads.")
+ head_start = sched[:, 0].to(torch.float32)
+ head_needed = torch.clamp(torch.ceil(head_start * growth), max=Kb)
+ needed_k = int(head_needed.max().item())
+ budget = sched.view(1, H, -1, 1)
+
+ needed_k = max(0, min(needed_k, Kb))
+
+ # --- 3. Top-k block selection -----------------------------------------
+ if needed_k > 0:
+ topk_vals, topk_idx = torch.topk(block_logits, needed_k, dim=-1)
+ rank = (
+ torch.arange(needed_k, device=query_states.device)
+ .view(1, 1, 1, -1)
+ .expand(B, H, Qb, needed_k)
+ )
+ keep = (rank < budget) & torch.isfinite(topk_vals)
+ if keep.any():
+ mask_block.scatter_(3, topk_idx, keep)
+
+ # Causal block mask.
+ q_range = torch.arange(Qb, device=query_states.device).view(1, 1, Qb, 1)
+ k_range = torch.arange(Kb, device=query_states.device).view(1, 1, 1, Kb)
+ causal_block_mask = k_range <= q_range
+
+ # Always keep the first few blocks (sink tokens).
+ initial_blocks: int = int(config.get("initial_blocks", 4))
+ if initial_blocks > 0:
+ mask_block |= (k_range < min(initial_blocks, Kb)) & causal_block_mask
+
+ # Sliding-window blocks.
+ window_size: int = int(config.get("window_size", 4))
+ if window_size > 0:
+ recent_start = torch.clamp(q_range - (window_size - 1), min=0)
+ mask_block |= (k_range >= recent_start) & causal_block_mask
+
+ mask_block &= causal_block_mask
+
+ # --- 4. Block-sparse attention ----------------------------------------
+ if HAS_BLOCK_SPARSE_KERNEL:
+ q_kernel = query_states.transpose(1, 2).reshape(B * Lq, H, head_dim)
+ k_kernel = key_states.transpose(1, 2).reshape(B * Tk, H, head_dim)
+ v_kernel = value_states.transpose(1, 2).reshape(B * Tk, H, head_dim)
+
+ q_cu = torch.arange(
+ 0,
+ (B + 1) * Lq,
+ step=Lq,
+ dtype=torch.int32,
+ device=query_states.device,
+ )
+ k_cu = torch.arange(
+ 0,
+ (B + 1) * Tk,
+ step=Tk,
+ dtype=torch.int32,
+ device=query_states.device,
+ )
+ head_mask_type = torch.ones(H, dtype=torch.int32, device=query_states.device)
+
+ torch.cuda.synchronize()
+ attn_output = block_sparse_attn_func(
+ q_kernel,
+ k_kernel,
+ v_kernel,
+ q_cu,
+ k_cu,
+ head_mask_type,
+ None,
+ mask_block.contiguous(),
+ Lq,
+ Tk,
+ p_dropout=0.0,
+ deterministic=True,
+ is_causal=True,
+ )
+ torch.cuda.synchronize()
+ return attn_output.view(B, Lq, H, head_dim).transpose(1, 2)
+
+ # Fallback: pseudo-sparse attention (expand block mask to full mask).
+ mask_full = (
+ mask_block.unsqueeze(-1)
+ .unsqueeze(-3)
+ .expand(-1, -1, -1, block_size, -1, block_size)
+ .reshape(B, H, Qb * block_size, Kb * block_size)
+ )
+ mask_full = mask_full[..., :Lq, :Tk]
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
+ if Lq > 1:
+ causal_bool = torch.ones((Lq, Tk), device=query_states.device, dtype=torch.bool).triu(1)
+ attn_weights = attn_weights.masked_fill(causal_bool[None, None, :, :], float("-inf"))
+
+ attn_weights = attn_weights.masked_fill(~mask_full, float("-inf"))
+ probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ return torch.matmul(probs, value_states)
diff --git a/angelslim/compressor/sparsity/stem/modules/__init__.py b/angelslim/compressor/sparsity/stem/modules/__init__.py
new file mode 100644
index 00000000..5ff7c06b
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/modules/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Stem attention forward module."""
+
+from .forward import attn_forward
+
+__all__ = ["attn_forward"]
diff --git a/angelslim/compressor/sparsity/stem/modules/forward.py b/angelslim/compressor/sparsity/stem/modules/forward.py
new file mode 100644
index 00000000..061ce8e3
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/modules/forward.py
@@ -0,0 +1,225 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Stem-patched attention forward pass.
+
+This module provides the replacement ``forward`` method that is bound to each
+attention layer by :func:`stem.patch.stem_patch`. During **prefill**
+(``q_len > 1``) it delegates to the Stem sparse backend; during **decode**
+(``q_len == 1``) it falls back to the model's original attention implementation
+(eager, FlashAttention-2, SDPA, etc.).
+
+The code mirrors the structure of
+``transformers.models.qwen3.modeling_qwen3.Qwen3Attention.forward``
+(Transformers >= 5.2) and should be kept in sync with upstream changes.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import torch
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.processing_utils import Unpack
+
+from ..backends import stem_forward
+
+# ---------------------------------------------------------------------------
+# Helper functions (identical to upstream Qwen3)
+# ---------------------------------------------------------------------------
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Rotate the last dimension by splitting and concatenating halves."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ unsqueeze_dim: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Apply Rotary Position Embedding (RoPE) to query and key tensors."""
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """Repeat KV heads to match the number of query heads (GQA support).
+
+ ``(B, num_kv_heads, L, D)`` -> ``(B, num_attention_heads, L, D)``
+ """
+ if n_rep == 1:
+ return hidden_states
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ hidden_states = hidden_states[:, :, None, :, :].expand(
+ batch, num_key_value_heads, n_rep, slen, head_dim
+ )
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# ---------------------------------------------------------------------------
+# Fallback eager attention (used in decode phase, mirrors upstream)
+# ---------------------------------------------------------------------------
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Eager (non-sparse) scaled dot-product attention.
+
+ Used as the **decode** fallback when ``q_len == 1`` and no specialised
+ attention implementation (e.g. FlashAttention-2) is configured.
+ Matches the upstream ``eager_attention_forward`` in Transformers >= 5.2.
+ """
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# ---------------------------------------------------------------------------
+# Validation
+# ---------------------------------------------------------------------------
+
+
+def _assert_no_padding_mask_for_stem(attention_mask: torch.Tensor, k_len: int) -> None:
+ """Verify that the attention mask has no padding (required by Stem prefill).
+
+ Raises
+ ------
+ ValueError
+ If the mask is not 4-D or if the last query row contains ``-inf``
+ entries (indicating padding tokens).
+ """
+ if attention_mask.ndim != 4:
+ raise ValueError(f"attention_mask must be 4-D, got shape={tuple(attention_mask.shape)}")
+ last_row = attention_mask[:, :, -1, :k_len]
+ if not torch.isfinite(last_row).all():
+ raise ValueError("Stem prefill requires no padding mask (last query row has -inf).")
+
+
+# ---------------------------------------------------------------------------
+# Patched attention forward
+# ---------------------------------------------------------------------------
+
+
+def attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_values: Cache | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+ """Stem-patched attention forward — drop-in replacement for
+ ``Qwen3Attention.forward`` (Transformers >= 5.2).
+
+ * **Prefill** (``q_len > 1``): delegates to :func:`stem_forward` which
+ computes block-sparse attention according to the configured backend.
+ * **Decode** (``q_len == 1``): uses the model's original attention
+ implementation (eager / FlashAttention-2 / SDPA / flex).
+ """
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ # --- QKV projection & RoPE (identical to upstream) --------------------
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ # --- KV cache update (Transformers >= 5.2 style) ----------------------
+ if past_key_values is not None:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+ q_len = query_states.shape[2]
+ k_len = key_states.shape[2]
+
+ # --- Prefill (Stem sparse attention) ----------------------------------
+ if q_len > 1:
+ if attention_mask is not None:
+ _assert_no_padding_mask_for_stem(attention_mask, k_len)
+
+ prefill_kwargs = {
+ "layer_idx": self.layer_idx,
+ "attn_forward_config": self.attn_forward_config,
+ }
+ backend = self.attn_forward_config.get("backend", "torch")
+
+ # HPC kernels (both bf16 and fp8) handle GQA internally;
+ # only the pure-torch path needs explicit KV head repeat.
+ if backend == "hpc":
+ stem_key_states = key_states
+ stem_value_states = value_states
+ else:
+ stem_key_states = repeat_kv(key_states, self.num_key_value_groups)
+ stem_value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_output = stem_forward(
+ query_states, stem_key_states, stem_value_states, prefill_kwargs
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_weights = None
+
+ # --- Decode (standard attention, mirrors upstream) ---------------------
+ else:
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+ self.config._attn_implementation, eager_attention_forward
+ )
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
diff --git a/angelslim/compressor/sparsity/stem/ops/__init__.py b/angelslim/compressor/sparsity/stem/ops/__init__.py
new file mode 100644
index 00000000..d7628048
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/ops/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Stem ops — Triton kernels for block-logit scoring."""
diff --git a/angelslim/compressor/sparsity/stem/ops/stem_kernel.py b/angelslim/compressor/sparsity/stem/ops/stem_kernel.py
new file mode 100644
index 00000000..c5325e7b
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/ops/stem_kernel.py
@@ -0,0 +1,191 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Triton kernel: strided group GEMM with fused reshape for block-logit scoring.
+
+This kernel computes a strided dot product between query and key blocks,
+producing a downsampled attention-score matrix used to estimate per-block
+importance in the Stem scoring pipeline.
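+
+Concretely, for downsampled indices ``(m, n)`` the kernel accumulates
+``sum_{s=0..STRIDE-1} dot(Q[m*STRIDE + STRIDE-1-s], K[n*STRIDE + s])``, i.e.
+one dot product per (query, key) pair whose in-group token offsets sum to
+``STRIDE - 1``.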
+"""
+
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def flat_group_gemm_fuse_reshape_kernel(
+ Q,
+ K,
+ Out,
+ stride_qz,
+ stride_qh,
+ stride_qn,
+ stride_kz,
+ stride_kh,
+ stride_kn,
+ stride_oz,
+ stride_oh,
+ stride_on,
+ chunk_start,
+ chunk_end,
+ H: tl.constexpr,
+ STRIDE: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ is_causal: tl.constexpr,
+):
+ """Triton kernel: one tile of the strided Q·K^T group GEMM.
+
+ Args:
+ Q: Query tensor pointer, shape ``(B, H, L_q, D)``.
+ K: Key tensor pointer, shape ``(B, H, L_kv, D)``.
+ Out: Output tensor pointer, shape ``(B, H, L_q // stride, L_kv // stride)``.
+ stride_qz: Stride of Q along the batch dimension.
+ stride_qh: Stride of Q along the head dimension.
+ stride_qn: Stride of Q along the sequence dimension.
+ stride_kz: Stride of K along the batch dimension.
+ stride_kh: Stride of K along the head dimension.
+ stride_kn: Stride of K along the sequence dimension.
+ stride_oz: Stride of Out along the batch dimension.
+ stride_oh: Stride of Out along the head dimension.
+ stride_on: Stride of Out along the row (downsampled query) dimension.
+ chunk_start: Logical chunk start boundary (downsampled coords) for causal masking.
+ chunk_end: Logical chunk end boundary (downsampled coords).
+ H: Number of attention heads (compile-time constant).
+ STRIDE: Striding factor for downsampling (compile-time constant).
+ HEAD_DIM: Per-head hidden dimension (compile-time constant).
+ BLOCK_M: Tile size along the query (M) axis (compile-time constant).
+ BLOCK_N: Tile size along the key (N) axis (compile-time constant).
+ is_causal: Whether to apply causal masking (compile-time constant).
+ """
+ block_m = tl.program_id(0).to(tl.int64)
+ block_n = tl.program_id(1).to(tl.int64)
+ batch_id = tl.program_id(2).to(tl.int64) // H
+ head_id = tl.program_id(2).to(tl.int64) % H
+
+ # Early exit for causal tiles that are entirely above the diagonal.
+ if is_causal and chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
+ return
+
+ Q_ptrs = (
+ Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
+ )
+ K_ptrs = (
+ K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
+ )
+
+ Q_ptrs = (
+ Q_ptrs
+ + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE)
+ + tl.arange(0, HEAD_DIM)[None, :]
+ + stride_qn * (STRIDE - 1)
+ )
+ K_ptrs = (
+ K_ptrs
+ + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE)
+ + tl.arange(0, HEAD_DIM)[:, None]
+ )
+
+ # Accumulate strided dot products.
+ acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ for _iter in range(STRIDE):
+ q = tl.load(Q_ptrs - _iter * stride_qn)
+ k = tl.load(K_ptrs + _iter * stride_kn)
+ acc += tl.dot(q, k)
+
+ O_ptrs = (
+ Out
+ + batch_id * stride_oz
+ + head_id * stride_oh
+ + block_m * BLOCK_M * stride_on
+ + block_n * BLOCK_N
+ )
+ O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
+ tl.store(O_ptrs, acc.to(Out.type.element_ty))
+
+
+def flat_group_gemm_fuse_reshape(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ stride: int,
+ chunk_start: int,
+ chunk_end: int,
+ is_causal: bool = True,
+) -> torch.Tensor:
+ """Launch the strided group-GEMM Triton kernel.
+
+ Args:
+ query_states: Query tensor of shape ``(B, H, L_q, D)``.
+ key_states: Key tensor of shape ``(B, H, L_kv, D)``.
+ stride: Striding factor.
+ chunk_start: Logical chunk start boundary (in downsampled coordinates)
+ used for the causal early-exit check inside the kernel.
+ chunk_end: Logical chunk end boundary (in downsampled coordinates).
+ is_causal: Whether to apply causal masking.
+
+ Returns:
+ Downsampled score matrix of shape ``(B, H, L_q // stride, L_kv // stride)``.
+ """
+ B, H, Lq, D = query_states.shape
+ Lkv = key_states.shape[2]
+
+ assert key_states.shape[0] == B
+ assert key_states.shape[1] == H
+ assert key_states.shape[3] == D
+
+ BLOCK_M = 128
+ BLOCK_N = 128
+ assert (
+ Lq % (stride * BLOCK_M) == 0
+ ), f"q_len ({Lq}) must be divisible by stride*BLOCK_M ({stride * BLOCK_M})"
+ assert (
+ Lkv % (stride * BLOCK_N) == 0
+ ), f"kv_len ({Lkv}) must be divisible by stride*BLOCK_N ({stride * BLOCK_N})"
+
+ output = torch.empty(
+ (B, H, Lq // stride, Lkv // stride),
+ dtype=query_states.dtype,
+ device=query_states.device,
+ )
+
+ grid = (Lq // stride // BLOCK_M, Lkv // stride // BLOCK_N, B * H)
+ flat_group_gemm_fuse_reshape_kernel[grid](
+ query_states,
+ key_states,
+ output,
+ query_states.stride(0),
+ query_states.stride(1),
+ query_states.stride(2),
+ key_states.stride(0),
+ key_states.stride(1),
+ key_states.stride(2),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ chunk_start,
+ chunk_end,
+ H,
+ stride,
+ D,
+ BLOCK_M,
+ BLOCK_N,
+ is_causal,
+ )
+
+ return output
diff --git a/angelslim/compressor/sparsity/stem/patch.py b/angelslim/compressor/sparsity/stem/patch.py
new file mode 100644
index 00000000..1ae1fe5b
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/patch.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Model-patching logic: replace the standard attention forward with Stem's."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .modules.forward import attn_forward
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+ from .stem_configuration import StemConfig
+
+
+def stem_patch(model: "PreTrainedModel", config: "StemConfig") -> "PreTrainedModel":
+ """Replace each attention layer's ``forward`` with :func:`attn_forward`.
+
+ Currently only **Qwen3** models are supported.
+
+ Args:
+ model: A HuggingFace causal-LM model (e.g. ``Qwen3ForCausalLM``).
+ config: Stem runtime configuration.
+
+ Returns:
+ The same *model* object, mutated in-place with Stem attention.
+
+ Raises:
+ ValueError: If the model's ``model_type`` does not contain ``"qwen3"``.
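+
+    Example (sketch)::
+
+        config = StemConfig(attn_kwargs={"backend": "torch"})
+        model = stem_patch(model, config)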
+ """
+ model_type = model.config.model_type.lower()
+ if "qwen3" not in model_type:
+ raise ValueError(f"Only Qwen3 is supported, got model_type={model_type!r}")
+
+ AttentionClass = model.model.layers[0].self_attn.__class__
+
+ # Ensure every layer carries its own index (used by schedule functions).
+ for i, layer in enumerate(model.model.layers):
+ layer.self_attn.layer_idx = i
+
+ def _apply_stem_forward(module: object) -> None:
+ """Bind the Stem ``attn_forward`` and config to each attention module."""
+ if isinstance(module, AttentionClass):
+ module.attn_forward_config = config.attn_kwargs
+ module.forward = attn_forward.__get__(module, AttentionClass)
+
+ model.apply(_apply_stem_forward)
+ return model
diff --git a/angelslim/compressor/sparsity/stem/stem.py b/angelslim/compressor/sparsity/stem/stem.py
new file mode 100644
index 00000000..2add3b04
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/stem.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""High-level entry point for applying the Stem patch to a HuggingFace model."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .patch import stem_patch
+from .stem_configuration import StemConfig
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+
+class StemInference:
+ """Callable object that patches a model to use Stem sparse attention.
+
+ Usage::
+
+ stem = StemInference(attn_kwargs={"backend": "hpc", "hpc_dtype": "fp8"})
+ model = stem(model)
+
+ Args:
+ attn_kwargs: Forwarded to ``StemConfig``. See its docstring for valid keys.
+ """
+
+ def __init__(self, attn_kwargs: dict | None = None) -> None:
+ self.config = StemConfig(attn_kwargs=attn_kwargs)
+
+ def __call__(self, model: "PreTrainedModel") -> "PreTrainedModel":
+ """Apply the Stem attention patch and return the modified model."""
+ return stem_patch(model, self.config)
+
+ def __repr__(self) -> str:
+ return f"StemInference(config={self.config!r})"
diff --git a/angelslim/compressor/sparsity/stem/stem_configuration.py b/angelslim/compressor/sparsity/stem/stem_configuration.py
new file mode 100644
index 00000000..2bf80013
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/stem_configuration.py
@@ -0,0 +1,61 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Configuration class for the Stem sparse attention module."""
+
+from __future__ import annotations
+
+# Supported option values.
+SUPPORTED_BACKENDS = {"torch", "hpc"}
+SUPPORTED_HPC_DTYPES = {"bf16", "fp8"}
+
+
+class StemConfig:
+    """Configuration container for Stem attention.
+
+ Args:
+ attn_kwargs: Dictionary of keyword arguments forwarded to the attention
+ backend. Recognised keys include:
+
+ - ``backend``: ``"torch"`` (default) or ``"hpc"``.
+ - ``hpc_dtype``: ``"bf16"`` (default) or ``"fp8"`` (only used when
+ ``backend="hpc"``).
+ - ``stem_alpha``, ``block_size``, ``stride``, ``chunk_size``, etc.
+
+ Raises:
+ ValueError: If *backend* or *hpc_dtype* is not in the supported set.
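+
+    Example::
+
+        config = StemConfig(attn_kwargs={"backend": "hpc", "hpc_dtype": "fp8"})
+        config.attn_kwargs["backend"]   # -> "hpc"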
+ """
+
+ def __init__(self, attn_kwargs: dict | None = None) -> None:
+ self.attn_kwargs: dict = dict(attn_kwargs or {})
+ self.attn_kwargs.setdefault("backend", "torch")
+ self.attn_kwargs.setdefault("hpc_dtype", "bf16")
+
+ backend = self.attn_kwargs["backend"]
+ hpc_dtype = self.attn_kwargs["hpc_dtype"]
+
+ if backend not in SUPPORTED_BACKENDS:
+ raise ValueError(
+ f"Unsupported stem backend: {backend!r}. "
+ f"Choose from {sorted(SUPPORTED_BACKENDS)}."
+ )
+ if hpc_dtype not in SUPPORTED_HPC_DTYPES:
+ raise ValueError(
+ f"Unsupported hpc_dtype: {hpc_dtype!r}. "
+ f"Choose from {sorted(SUPPORTED_HPC_DTYPES)}."
+ )
+
+ def __repr__(self) -> str:
+ return f"StemConfig(attn_kwargs={self.attn_kwargs!r})"
diff --git a/docs/source/features/sparse_attention/index.md b/docs/source/features/sparse_attention/index.md
new file mode 100644
index 00000000..1352d628
--- /dev/null
+++ b/docs/source/features/sparse_attention/index.md
@@ -0,0 +1,10 @@
+# Sparse Attention
+
+Sparse Attention is AngelSlim's prefill acceleration module for long-context LLM inference. Its core goal is to dynamically skip unimportant attention blocks during inference, significantly reducing the compute and latency of the prefill stage.
+
+:::{toctree}
+:caption: Contents
+:maxdepth: 1
+
+stem
+:::
diff --git a/docs/source/features/sparse_attention/stem.md b/docs/source/features/sparse_attention/stem.md
new file mode 100644
index 00000000..edecc726
--- /dev/null
+++ b/docs/source/features/sparse_attention/stem.md
@@ -0,0 +1,157 @@
+# Stem: Rethinking Causal Information Flow in Sparse Attention
+
+**Stem** is AngelSlim's sparse attention algorithm for accelerating the **prefill** stage of long-context LLM inference. It estimates attention importance at block granularity and dynamically selects the top-k key blocks for block-sparse attention, cutting prefill latency substantially while preserving generation quality.
+
+## 1. Motivation
+
+In long-context inference (e.g. 32K–128K tokens), full attention during the prefill stage is the main bottleneck:
+
+- Compute grows **quadratically** with sequence length, stressing both memory and latency
+- In practice, most attention blocks contribute very little to the final output, leaving a large amount of redundancy
+
+Stem's core idea: **first estimate each attention block's importance with low-cost block-level scoring, then run exact attention only on the important blocks**.
+
+## 2. How It Works
+
+Stem's prefill procedure has three steps:
+
+### 2.1 Block-Level Scoring
+
+A **Triton-accelerated strided group GEMM** computes a downsampled Q·K^T score matrix, which is combined with a value-norm bonus term to estimate how important each key block is for each query block:
+
+$$\text{score}(Q_i, K_j) = \frac{Q_i \cdot K_j^T}{\sqrt{d} \cdot s \cdot n} + \lambda \cdot \text{ReLU}(\bar{v}_j)$$
+
+where $s$ is the stride factor, $n$ is a normalization coefficient, and $\bar{v}_j$ is the standardized log of the value norm.
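+
+The snippet below is a rough pure-PyTorch approximation of this scoring step, shown only to make the formula concrete: the shipped implementation fuses the strided reduction into the Triton kernel, and the bonus weight `lam` here is a hypothetical stand-in for the value-norm term's coefficient.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def approx_block_scores(q, k, v, stride=8, norm=1.0, lam=0.1):
+    """Toy approximation of Stem's downsampled block scoring (not the shipped kernel)."""
+    B, H, L, D = q.shape
+    # Downsample queries/keys by averaging each group of `stride` tokens.
+    q_ds = q.view(B, H, L // stride, stride, D).mean(dim=3)
+    k_ds = k.view(B, H, L // stride, stride, D).mean(dim=3)
+    # Scaled dot product between downsampled positions.
+    scores = q_ds @ k_ds.transpose(-1, -2) / (D**0.5 * stride * norm)
+    # Value-norm bonus: standardized log norm per group, ReLU-clipped.
+    v_log = v.norm(dim=-1).view(B, H, L // stride, stride).mean(dim=-1).log()
+    v_bar = (v_log - v_log.mean(-1, keepdim=True)) / (v_log.std(-1, keepdim=True) + 1e-6)
+    return scores + lam * F.relu(v_bar)[:, :, None, :]
+```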
+
+### 2.2 Top-k Schedule
+
+Each layer turns a preset keep-ratio and alpha decay factor into a per-block top-k budget:
+
+- **First N layers (warmup)**: alpha=1.0, keeping more blocks so that low-level feature extraction stays intact
+- **Later layers**: alpha=0.7, sparsifying more aggressively to speed up compute
+- In addition, the **initial blocks** (sink tokens) and the **sliding window** blocks are always kept (a minimal selection sketch follows this list)
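+
+A minimal sketch of what such per-layer selection could look like; the budget rule and helper name are illustrative, not the library's internals:
+
+```python
+import torch
+
+
+def select_blocks(scores, alpha, initial_blocks=4, window_size=4):
+    """Pick the key blocks each query block may attend to (toy version)."""
+    B, H, Nq, Nk = scores.shape
+    # Hypothetical budget rule: keep an alpha-scaled fraction of key blocks.
+    keep = max(initial_blocks + window_size, int(alpha * Nk))
+    mask = torch.zeros(B, H, Nq, Nk, dtype=torch.bool, device=scores.device)
+    topk = scores.topk(min(keep, Nk), dim=-1).indices
+    mask.scatter_(-1, topk, True)
+    # Sink blocks and the local sliding window are always kept.
+    mask[..., :initial_blocks] = True
+    for i in range(Nq):
+        mask[..., i, max(0, i - window_size + 1) : i + 1] = True
+    return mask
+```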
+
+### 2.3 Block-Sparse Attention
+
+Sparse attention is executed according to the top-k mask:
+
+- If the `block-sparse-attn` library is installed, a true block-sparse kernel is used
+- Otherwise Stem automatically falls back to a pseudo-sparse implementation (the block mask is expanded and dense attention is run; see the sketch below)
+- The **HPC backend** supports bf16 dense prefill and fp8 block-sparse prefill (both varlen and paged paths)
+
+**The decode stage is unaffected**: it keeps using the model's original attention implementation (FlashAttention-2 / eager / SDPA).
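+
+For reference, the pseudo-sparse fallback can be thought of as the following sketch (illustrative; the actual fallback lives in the torch backend implementation):
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def pseudo_sparse_attention(q, k, v, block_mask, block_size=128):
+    """Expand a block-level keep mask to token level and run dense SDPA."""
+    # block_mask: (B, H, L // block_size, L // block_size), True = keep the block.
+    token_mask = block_mask.repeat_interleave(block_size, dim=-2)
+    token_mask = token_mask.repeat_interleave(block_size, dim=-1)
+    # Combine with the causal constraint, then fall back to dense attention.
+    L = q.shape[2]
+    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q.device))
+    return F.scaled_dot_product_attention(q, k, v, attn_mask=token_mask & causal)
+```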
+
+## 3. Supported Scope
+
+| Dimension | Support |
+|------|---------|
+| **Backend** | `torch` (pure PyTorch + Triton), `hpc` (HPC C++ extension) |
+| **HPC precision** | bf16 (dense prefill), fp8 (block-sparse prefill, varlen / paged) |
+| **Sequence length** | No hard limit; 4K+ tokens recommended to see the speedup |
+
+## 4. Quick Start
+
+Make sure AngelSlim is installed (`pip install -e .` or `uv sync`), then run from the project root:
+
+### Dense baseline (no Stem patch)
+
+```bash
+python tools/run_stem.py \
+ --mode dense \
+ --model-path /path/to/Qwen3-8B \
+ --prompt-file prompt.txt \
+ --max-new-tokens 160
+```
+
+### Stem + HPC bf16
+
+```bash
+python tools/run_stem.py \
+ --mode stem \
+ --stem-backend hpc \
+ --hpc-dtype bf16 \
+ --model-path /path/to/Qwen3-8B \
+ --prompt-file prompt.txt \
+ --max-new-tokens 160
+```
+
+### Stem + HPC fp8
+
+```bash
+python tools/run_stem.py \
+ --mode stem \
+ --stem-backend hpc \
+ --hpc-dtype fp8 \
+ --model-path /path/to/Qwen3-8B \
+ --prompt-file prompt.txt \
+ --max-new-tokens 160
+```
+
+### Using a custom prompt
+
+```bash
+python tools/run_stem.py \
+ --mode stem \
+ --stem-backend hpc \
+ --hpc-dtype bf16 \
+ --model-path /path/to/Qwen3-8B \
+ --prompt-file my_long_document.txt \
+ --max-new-tokens 256
+```
+
+You can also launch it via the wrapper script:
+
+```bash
+bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt stem
+```
+
+## 5. Parameters
+
+| Parameter | Default | Description |
+|------|--------|------|
+| `backend` | `"torch"` | Backend: `"torch"` or `"hpc"` |
+| `hpc_dtype` | `"bf16"` | HPC backend precision: `"bf16"` or `"fp8"` |
+| `hpc_fp8_path` | `"paged"` | FP8 execution path: `"varlen"` or `"paged"` |
+| `stem_alpha` | `1.0` | Per-layer alpha decay factor; pass a list for per-layer control |
+| `block_size` | `128` | Attention block size |
+| `stride` | `8` | Downsampling stride in the scoring stage |
+| `chunk_size` | `2048` | Chunk width in the scoring stage |
+| `norm` | `1.0` | Extra normalization coefficient in the scoring stage |
+| `initial_blocks` | `4` | Number of leading blocks always kept (sink tokens) |
+| `window_size` | `4` | Number of trailing blocks kept by the sliding window |
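+
+A minimal sketch of how these parameters map onto `attn_kwargs` (the values shown are illustrative, not tuned recommendations):
+
+```python
+from angelslim.compressor.sparsity import StemInference
+
+stem = StemInference(attn_kwargs={
+    "backend": "hpc",
+    "hpc_dtype": "bf16",
+    "stem_alpha": 0.7,      # or a per-layer list
+    "block_size": 128,
+    "stride": 8,
+    "chunk_size": 2048,
+    "norm": 1.0,
+    "initial_blocks": 4,
+    "window_size": 4,
+})
+```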
+
+## 6. Code Structure
+
+```
+angelslim/compressor/sparsity/
+├── __init__.py               # re-export StemInference
+└── stem/
+    ├── __init__.py           # package entry point
+    ├── stem.py               # StemInference class (main entry point)
+    ├── patch.py              # model patching logic
+    ├── stem_configuration.py # StemConfig configuration
+    ├── backends/
+    │   ├── dispatcher.py     # torch / hpc routing
+    │   ├── torch_impl.py     # PyTorch + Triton implementation
+    │   └── hpc_impl.py       # HPC C++ extension implementation
+    ├── modules/
+    │   └── forward.py        # patched attention forward
+    └── ops/
+        └── stem_kernel.py    # Triton kernel
+
+tools/run_stem.py                # inference entry point
+scripts/sparsity/run_stem.sh     # launch script
+```
+
+## 7. Python API
+
+```python
+from angelslim.compressor.sparsity import StemInference
+
+stem = StemInference(attn_kwargs={
+ "backend": "hpc",
+ "hpc_dtype": "fp8",
+    "stem_alpha": [1.0] * 5 + [0.7] * 31,  # 36-layer Qwen3-8B
+})
+model = stem(model)  # returns the same model object, patched in place
+```
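+
+A minimal end-to-end flow, mirroring what `tools/run_stem.py` does (the model path and generation settings below are illustrative):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from angelslim.compressor.sparsity import StemInference
+
+model_path = "/path/to/Qwen3-8B"  # illustrative local path
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    torch_dtype=torch.bfloat16,
+    device_map="cuda",
+    attn_implementation="flash_attention_2",
+    trust_remote_code=True,
+)
+
+# Patch prefill attention; decode keeps the original implementation.
+model = StemInference(attn_kwargs={"backend": "torch"})(model)
+
+inputs = tokenizer(open("prompt.txt").read(), return_tensors="pt").to(model.device)
+with torch.no_grad():
+    out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
+print(tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
+```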
diff --git a/docs/source/index.md b/docs/source/index.md
index be52d176..50fd915a 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -63,9 +63,10 @@ AngelSlim是腾讯自研的,致力于打造更易用、更全面和更高效
- Tequila
- - Eagle3
- SpecExit
- - - **稀疏注意力**
+ - - **稀疏注意力**
- - Minference(建设中)
+ - Stem
+ - Minference(建设中)
* - **图/视频生文(VLM)**
- - Hunyuan-VL
- HunyuanOCR
@@ -129,8 +130,8 @@ getting_started/quickstrat
features/quantization/index
features/speculative_decoding/index
+features/sparse_attention/index
features/diffusion/index
-features/token_compressor/index.md
:::
% Additional capabilities
diff --git a/scripts/sparsity/run_stem.sh b/scripts/sparsity/run_stem.sh
new file mode 100755
index 00000000..904aa907
--- /dev/null
+++ b/scripts/sparsity/run_stem.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Stem sparse-attention launch script.
+#
+# Usage:
+# bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt stem
+# bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt dense
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+MODEL_PATH="${1:-Qwen/Qwen3-8B}"
+PROMPT_FILE="${2:?Usage: $0 MODEL_PATH PROMPT_FILE [stem|dense] [extra args...]}"
+MODE="${3:-stem}"
+# Drop the consumed positional args so that only extra flags are forwarded to python.
+if [ "$#" -ge 3 ]; then shift 3; else shift "$#"; fi
+
+cd "$ROOT_DIR"
+exec python -u tools/run_stem.py \
+ --model-path "$MODEL_PATH" \
+ --model-name "Qwen/Qwen3-8B" \
+ --prompt-file "$PROMPT_FILE" \
+ --mode "$MODE" \
+ "$@"
diff --git a/tools/run_stem.py b/tools/run_stem.py
new file mode 100644
index 00000000..aecb4fc7
--- /dev/null
+++ b/tools/run_stem.py
@@ -0,0 +1,232 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Stem sparse attention inference script.
+
+Usage:
+ python tools/run_stem.py --mode stem --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+ python tools/run_stem.py --mode dense --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+ python tools/run_stem.py --mode stem --stem-backend hpc \
+ --hpc-dtype fp8 --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+"""
+
+import argparse
+import sys
+import time
+import traceback
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+from angelslim.compressor.sparsity.stem import StemInference
+
+DEFAULT_MODEL_PATH = "Qwen/Qwen3-8B"
+DEFAULT_MODEL_NAME = "Qwen/Qwen3-8B"
+DEFAULT_MAX_MODEL_LEN = 131072
+
+
+def build_stem_alpha_schedule(
+ num_layers: int,
+ warmup_layers: int = 5,
+ warmup_alpha: float = 1.0,
+ steady_alpha: float = 0.7,
+) -> list[float]:
+    """Return a per-layer alpha list: `warmup_alpha` for the first layers, `steady_alpha` after."""
+    if num_layers <= 0:
+ raise ValueError(f"num_layers must be positive, got {num_layers}")
+ if warmup_layers < 0 or warmup_layers > num_layers:
+ raise ValueError(f"warmup_layers must be in [0, {num_layers}], got {warmup_layers}")
+ return [warmup_alpha] * warmup_layers + [steady_alpha] * (num_layers - warmup_layers)
+
+
+def build_prompt(tokenizer, raw_prompt: str) -> str:
+    """Wrap the raw prompt in the model's chat template, disabling thinking when supported."""
+    messages = [{"role": "user", "content": raw_prompt}]
+ try:
+ return tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+    except TypeError:
+        # Older chat templates do not accept `enable_thinking`; retry without it.
+ return tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Stem sparse attention inference: Dense vs Stem on Qwen3.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="stem",
+ choices=["dense", "stem"],
+ help="Attention mode: 'dense' (no patch) or 'stem' (Stem sparse).",
+ )
+ parser.add_argument(
+ "--stem-backend",
+ type=str,
+ default="torch",
+ choices=["torch", "hpc"],
+ help="Stem backend: 'torch' (PyTorch + Triton) or 'hpc' (HPC C++ extension).",
+ )
+ parser.add_argument(
+ "--hpc-dtype",
+ type=str,
+ default="bf16",
+ choices=["bf16", "fp8"],
+ help="HPC backend precision.",
+ )
+ parser.add_argument(
+ "--hpc-fp8-path",
+ type=str,
+ default="paged",
+ choices=["paged", "varlen"],
+ help="HPC FP8 execution path.",
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default=DEFAULT_MODEL_PATH,
+ help="Path to the model directory or HuggingFace model ID.",
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default=DEFAULT_MODEL_NAME,
+ help="Model name (used to determine rope_scaling config).",
+ )
+ parser.add_argument(
+ "--max-model-len",
+ type=int,
+ default=DEFAULT_MAX_MODEL_LEN,
+ help="Maximum position embeddings length.",
+ )
+ parser.add_argument(
+ "--prompt-file",
+ type=str,
+ required=True,
+ help="Path to a text file containing the input prompt.",
+ )
+ parser.add_argument(
+ "--max-new-tokens", type=int, default=256, help="Maximum number of new tokens to generate."
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ print(f"[Env] Python executable: {sys.executable}")
+ print(f"Loading tokenizer from {args.model_path} ...")
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ print("Loading model config ...")
+ config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
+    # Extend the Qwen3 context window with YaRN rope scaling (32K native, factor 4.0).
+    if "qwen" in args.model_name.lower():
+ rope_theta = (
+ config.rope_scaling.get("rope_theta", 1000000) if config.rope_scaling else 1000000
+ )
+ config.rope_scaling = {
+ "rope_type": "yarn",
+ "rope_theta": rope_theta,
+ "factor": 4.0,
+ "original_max_position_embeddings": 32768,
+ }
+ config.max_position_embeddings = args.max_model_len
+
+ print("Loading model weights ...")
+ model = AutoModelForCausalLM.from_pretrained(
+ args.model_path,
+ config=config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda",
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2",
+ )
+
+ if args.mode == "stem":
+ num_layers = int(model.config.num_hidden_layers)
+ stem_alpha = build_stem_alpha_schedule(
+ num_layers=num_layers,
+ warmup_layers=5,
+ warmup_alpha=1.0,
+ steady_alpha=0.7,
+ )
+        stem = StemInference(
+ attn_kwargs={
+ "backend": args.stem_backend,
+ "hpc_dtype": args.hpc_dtype,
+ "hpc_fp8_path": args.hpc_fp8_path,
+ "stem_alpha": stem_alpha,
+ },
+ )
+        model = stem(model)
+ msg = (
+ f"[Stem] Patch applied. backend={args.stem_backend}, "
+ f"hpc_dtype={args.hpc_dtype}, num_layers={num_layers}"
+ )
+ print(msg)
+ else:
+ print("[Dense] No Stem patch applied. Using standard flash_attention_2.")
+
+ with open(args.prompt_file, "r", encoding="utf-8") as f:
+ raw_prompt = f.read()
+ print(f"Loaded prompt from: {args.prompt_file}")
+
+ prompt = build_prompt(tokenizer, raw_prompt)
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+ input_len = int(inputs.input_ids.shape[1])
+ print(f"[Input Stats] token_length={input_len}")
+
+ print("Generating ...")
+ start = time.time()
+ try:
+ with torch.no_grad():
+ out = model.generate(
+ **inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=False,
+ pad_token_id=tokenizer.eos_token_id,
+ )
+ torch.cuda.synchronize()
+ end = time.time()
+
+ gen_ids = out[0][input_len:]
+ gen_text = tokenizer.decode(
+ gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ print("=" * 80)
+ print(f"Mode: {args.mode}")
+ print(f"Time Taken: {end - start:.3f} s")
+ print(f"Generated Tokens: {len(gen_ids)}")
+ print("Model Output:")
+ print(gen_text.strip())
+ print("=" * 80)
+ except Exception as e:
+ print(f"[Error]: {e}")
+ traceback.print_exc()
+ raise
+
+
+if __name__ == "__main__":
+ main()