diff --git a/README.md b/README.md index bb774be8..6c2d41a4 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
  • Sparse Attention
  • diff --git a/README_cn.md b/README_cn.md index 6d16e363..ae1fc891 100644 --- a/README_cn.md +++ b/README_cn.md @@ -95,7 +95,7 @@
  • 稀疏注意力
  • diff --git a/angelslim/compressor/sparsity/__init__.py b/angelslim/compressor/sparsity/__init__.py index eca77b00..eaa4c7a2 100644 --- a/angelslim/compressor/sparsity/__init__.py +++ b/angelslim/compressor/sparsity/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .stem import StemInference # noqa: F401 + +__all__ = ["StemInference"] diff --git a/angelslim/compressor/sparsity/stem/__init__.py b/angelslim/compressor/sparsity/stem/__init__.py new file mode 100644 index 00000000..5ab6b1ed --- /dev/null +++ b/angelslim/compressor/sparsity/stem/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Stem — Sparse Token Estimation Module for long-context LLM inference. + +Public API: + StemInference: Callable that patches a HuggingFace model to use Stem + sparse attention during the prefill stage. +""" + +from .stem import StemInference + +__all__ = ["StemInference"] diff --git a/angelslim/compressor/sparsity/stem/backends/__init__.py b/angelslim/compressor/sparsity/stem/backends/__init__.py new file mode 100644 index 00000000..b152ab84 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/backends/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Stem backend implementations (torch / HPC).""" + +from .dispatcher import stem_forward + +__all__ = ["stem_forward"] diff --git a/angelslim/compressor/sparsity/stem/backends/dispatcher.py b/angelslim/compressor/sparsity/stem/backends/dispatcher.py new file mode 100644 index 00000000..93582489 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/backends/dispatcher.py @@ -0,0 +1,59 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Backend dispatcher: routes Stem prefill to the correct implementation.""" + +from __future__ import annotations + +import torch + +from .torch_impl import stem_forward_torch + + +def stem_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + prefill_kwargs: dict, +) -> torch.Tensor: + """Dispatch a Stem prefill call to the appropriate backend. + + Args: + query_states: Query tensor of shape ``(B, H_q, L_q, D)``. + key_states: Key tensor of shape ``(B, H_kv, L_kv, D)``. + value_states: Value tensor of shape ``(B, H_kv, L_kv, D)``. + prefill_kwargs: Must contain ``"attn_forward_config"`` (with a + ``"backend"`` key) and ``"layer_idx"``. + + Returns: + Attention output tensor of shape ``(B, H_q, L_q, D)``. + + Raises: + ValueError: If the requested backend is not ``"torch"`` or ``"hpc"``. + """ + config = prefill_kwargs["attn_forward_config"] + backend = config.get("backend", "torch") + + if backend == "torch": + return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs) + + if backend == "hpc": + # Lazy import to avoid hard dependency on the ``hpc`` C++ extension + # when only the pure-torch path is needed. + from .hpc_impl import stem_forward_hpc + + return stem_forward_hpc(query_states, key_states, value_states, prefill_kwargs) + + raise ValueError(f"Unknown stem backend: {backend!r}") diff --git a/angelslim/compressor/sparsity/stem/backends/hpc_impl.py b/angelslim/compressor/sparsity/stem/backends/hpc_impl.py new file mode 100644 index 00000000..04dba8c4 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/backends/hpc_impl.py @@ -0,0 +1,541 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""HPC (C++ extension) backend for Stem sparse prefill. + +Provides three execution paths: + +* **bf16 dense** — calls ``hpc.attention_prefill_bf16`` directly. +* **fp8 varlen** — Triton scoring → ``hpc.stem_tpd`` mask → + ``hpc.attention_blocksparse_prefill_fp8``. +* **fp8 paged** — ``hpc.stem_paged_kv`` mask → + ``hpc.attention_with_kvcache_blocksparse_prefill_fp8``. 
+""" + +from __future__ import annotations + +import hpc +import torch + +from .torch_impl import _compute_triton_block_logits, stem_forward_torch + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +FP8_DTYPE = torch.float8_e4m3fn +HPC_FP8_BLOCK_SIZE = 128 +HPC_INITIAL_BLOCKS = 4 +HPC_WINDOW_SIZE = 4 + + +# --------------------------------------------------------------------------- +# Tensor packing / quantisation helpers +# --------------------------------------------------------------------------- + + +def _pack_bhld_to_varlen(x: torch.Tensor) -> torch.Tensor: + """Reshape ``(B, H, L, D)`` → ``(B*L, H, D)`` for variable-length HPC kernels.""" + B, H, L, D = x.shape + return x.transpose(1, 2).reshape(B * L, H, D).contiguous() + + +def _uniform_seq_metadata( + batch_size: int, + seq_len: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return ``(seqlens, cu_seqlens)`` for a uniform-length batch. + + Parameters + ---------- + batch_size : int + seq_len : int + device : torch.device + + Returns + ------- + seqlens : torch.Tensor — ``(batch_size,)``, all equal to *seq_len*. + cu_seqlens : torch.Tensor — ``(batch_size + 1,)``, cumulative offsets. + """ + seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device=device) + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seq_len, + step=seq_len, + dtype=torch.int32, + device=device, + ) + return seqlens, cu_seqlens + + +def _quantize_per_tensor_fp8( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Symmetric per-tensor FP8 (E4M3) quantisation. + + Returns + ------- + x_fp8 : torch.Tensor — quantised tensor (same shape, dtype ``float8_e4m3fn``). + scale : torch.Tensor — scalar scale factor, shape ``(1,)``. + """ + fp8_max = torch.finfo(FP8_DTYPE).max + scale = x.abs().amax().float().clamp(min=1e-12) / fp8_max + x_fp8 = torch.clamp(x.float() / scale, min=-fp8_max, max=fp8_max).to(FP8_DTYPE) + return x_fp8.contiguous(), scale.view(1) + + +def _quantize_query_for_paged_fp8( + query_states: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantise the query tensor for the paged FP8 attention kernel. + + Uses HPC's per-token-group quantisation (``group_size = head_dim``). + + Args: + query_states: Query tensor of shape ``(B, H, L, D)``. + + Returns: + A tuple of ``(q_fp8, q_scale)`` where *q_fp8* has shape + ``(B*L, H, D)`` in ``float8_e4m3fn`` and *q_scale* has shape + ``(B, H, L_padded)`` with per-token scales. + + Raises: + ValueError: If ``head_dim != 128``. + """ + B, H, L, D = query_states.shape + if D != HPC_FP8_BLOCK_SIZE: + raise ValueError(f"HPC fp8 query quant only supports head_dim=128, got {D}.") + + q_rows = query_states.transpose(1, 2).reshape(B * L, H * D).contiguous() + q_fp8_rows, q_scale_rows = hpc.quant.per_token_group_fp8_quant(q_rows, group_size=D) + + q_fp8 = q_fp8_rows.view(B * L, H, D).contiguous() + q_scale = q_scale_rows.view(B, L, H).permute(0, 2, 1).contiguous() + + # Pad the sequence dimension to a multiple of the block size. + padded_L = ((L + HPC_FP8_BLOCK_SIZE - 1) // HPC_FP8_BLOCK_SIZE) * HPC_FP8_BLOCK_SIZE + if padded_L != L: + q_scale = torch.nn.functional.pad(q_scale, (0, padded_L - L)) + return q_fp8, q_scale + + +def _pack_paged_cache( + x_varlen: torch.Tensor, + batch_size: int, + seq_len: int, + block_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pack a varlen KV tensor into paged-cache layout. 
+ + Parameters + ---------- + x_varlen : torch.Tensor — ``(B*L, H, D)`` + batch_size, seq_len, block_size : int + + Returns + ------- + cache : torch.Tensor — ``(total_blocks, block_size, H, D)`` + kv_indices : torch.Tensor — ``(B, blocks_per_seq)`` block-table indices. + """ + _, H, D = x_varlen.shape + padded_L = ((seq_len + block_size - 1) // block_size) * block_size + blocks_per_seq = padded_L // block_size + + x_seq = x_varlen.view(batch_size, seq_len, H, D) + if padded_L > seq_len: + pad = torch.zeros( + (batch_size, padded_L - seq_len, H, D), + dtype=x_varlen.dtype, + device=x_varlen.device, + ) + x_seq = torch.cat([x_seq, pad], dim=1) + + cache = ( + x_seq.view(batch_size, blocks_per_seq, block_size, H, D) + .reshape(batch_size * blocks_per_seq, block_size, H, D) + .contiguous() + ) + kv_indices = torch.arange( + batch_size * blocks_per_seq, dtype=torch.int32, device=x_varlen.device + ).view(batch_size, blocks_per_seq) + return cache, kv_indices + + +# --------------------------------------------------------------------------- +# Head-repeat helper (GQA → full Q heads) +# --------------------------------------------------------------------------- + + +def _repeat_to_q_heads(x: torch.Tensor, num_q_heads: int) -> torch.Tensor: + """Repeat KV heads to match the query head count (GQA support). + + Parameters + ---------- + x : torch.Tensor — ``(B, H_kv, L, D)`` + num_q_heads : int — target number of heads (``H_q``). + + Returns + ------- + torch.Tensor — ``(B, H_q, L, D)`` + """ + B, H_kv, L, D = x.shape + if num_q_heads == H_kv: + return x + if num_q_heads % H_kv != 0: + raise ValueError(f"Cannot repeat kv heads from {H_kv} to {num_q_heads}.") + n_rep = num_q_heads // H_kv + x = x[:, :, None, :, :].expand(B, H_kv, n_rep, L, D) + return x.reshape(B, num_q_heads, L, D) + + +# --------------------------------------------------------------------------- +# Fallback to the pure-torch backend +# --------------------------------------------------------------------------- + + +def _fallback_to_torch( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + prefill_kwargs: dict, +) -> torch.Tensor: + """Fall back to :func:`stem_forward_torch` after repeating KV heads.""" + H_q = query_states.shape[1] + if key_states.shape[1] != H_q: + key_states = _repeat_to_q_heads(key_states, H_q) + if value_states.shape[1] != H_q: + value_states = _repeat_to_q_heads(value_states, H_q) + return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs) + + +# --------------------------------------------------------------------------- +# Runtime parameter extraction +# --------------------------------------------------------------------------- + + +def _stem_runtime_params(config: dict, layer_idx: int) -> dict: + """Extract Stem runtime hyper-parameters from the user config dict.""" + alpha_cfg = config.get("stem_alpha", 1.0) + alpha = alpha_cfg[layer_idx] if isinstance(alpha_cfg, (list, tuple)) else alpha_cfg + return { + "block_size": int(config.get("block_size", HPC_FP8_BLOCK_SIZE)), + "stem_stride": int(config.get("hpc_stem_stride", config.get("stride", 16))), + "chunk_size": int(config.get("chunk_size", 2048)), + "norm": float(config.get("norm", 1.0)), + "initial_blocks": int(config.get("initial_blocks", HPC_INITIAL_BLOCKS)), + "window_size": int(config.get("window_size", HPC_WINDOW_SIZE)), + "lambda_mag": float(config.get("lambda_mag", 0.3)), + "alpha": float(alpha), + "k_block_num_rate": float(config.get("k_block_num_rate", 0.1)), + "k_block_num_bias": 
int(config.get("k_block_num_bias", 30)), + } + + +# ===================================================================== # +# Execution paths # +# ===================================================================== # + +# --- Path 1: bf16 dense ------------------------------------------------ + + +def _run_hpc_bf16_dense( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, +) -> torch.Tensor: + """BF16 dense prefill via ``hpc.attention_prefill_bf16``.""" + B, H, Lq, _ = query_states.shape + D_v = value_states.shape[-1] + + q_varlen = _pack_bhld_to_varlen(query_states) + k_varlen = _pack_bhld_to_varlen(key_states) + v_varlen = _pack_bhld_to_varlen(value_states) + seqlens_q, cu_seqlens_q = _uniform_seq_metadata(B, Lq, query_states.device) + + output = hpc.attention_prefill_bf16(q_varlen, k_varlen, v_varlen, seqlens_q, cu_seqlens_q, Lq) + return output.view(B, Lq, H, D_v).transpose(1, 2).contiguous() + + +# --- Path 2: fp8 varlen ------------------------------------------------ + + +def _build_varlen_stem_mask( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + params: dict, +) -> torch.Tensor: + """Build the block-sparse mask using Triton scoring + ``hpc.stem_tpd``.""" + H_q = query_states.shape[1] + if key_states.shape[1] != H_q: + key_states = _repeat_to_q_heads(key_states, H_q) + if value_states.shape[1] != H_q: + value_states = _repeat_to_q_heads(value_states, H_q) + + block_logits = _compute_triton_block_logits( + query_states, + key_states, + value_states, + block_size=params["block_size"], + stride=params["stem_stride"], + chunk_size=params["chunk_size"], + norm=params["norm"], + causal=True, + ) + + B, _, Lq, _ = query_states.shape + Lkv = key_states.shape[2] + q_seq_lens = torch.full((B,), Lq, dtype=torch.int32, device=query_states.device) + kv_seq_lens = torch.full((B,), Lkv, dtype=torch.int32, device=query_states.device) + + return hpc.stem_tpd( + block_logits, + q_seq_lens, + kv_seq_lens, + params["block_size"], + params["alpha"], + params["initial_blocks"], + params["window_size"], + params["k_block_num_rate"], + params["k_block_num_bias"], + ) + + +def _run_hpc_varlen_stem( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + params: dict, +) -> torch.Tensor: + """FP8 varlen sparse prefill via ``hpc.attention_blocksparse_prefill_fp8``.""" + B, H_q, Lq, _ = query_states.shape + _, _, Lkv, D_v = value_states.shape + + q_varlen = _pack_bhld_to_varlen(query_states) + k_varlen = _pack_bhld_to_varlen(key_states) + v_varlen = _pack_bhld_to_varlen(value_states) + + q_fp8, q_scale = _quantize_per_tensor_fp8(q_varlen) + k_fp8, k_scale = _quantize_per_tensor_fp8(k_varlen) + v_fp8, v_scale = _quantize_per_tensor_fp8(v_varlen) + + _, cu_q = _uniform_seq_metadata(B, Lq, query_states.device) + _, cu_kv = _uniform_seq_metadata(B, Lkv, query_states.device) + + mask = _build_varlen_stem_mask(query_states, key_states, value_states, params) + + output = hpc.attention_blocksparse_prefill_fp8( + q_fp8, + k_fp8, + v_fp8, + cu_q, + cu_kv, + Lq, + Lkv, + q_scale, + k_scale, + v_scale, + block_mask=mask, + ) + return output.view(B, Lq, H_q, D_v).transpose(1, 2).contiguous() + + +# --- Path 3: fp8 paged ------------------------------------------------- + + +def _run_hpc_paged_stem( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + params: dict, +) -> torch.Tensor: + """Paged FP8 prefill path for chunked prefill with existing KV history. 
+ + Key semantics + ------------- + * ``hpc.stem_paged_kv`` expects ``kv_seq_lens`` = **total visible KV length** + (history + current Q). + * ``hpc.attention_with_kvcache_blocksparse_prefill_fp8`` expects + ``seqlens_kvcache`` = **history length only** (tokens *before* the + current Q chunk). + """ + B, H_q, Lq, _ = query_states.shape + _, _, Lkv, D_v = value_states.shape + if Lkv < Lq: + raise ValueError( + f"Paged HPC prefill requires kv_len >= q_len, " f"got kv_len={Lkv}, q_len={Lq}." + ) + + history_len = Lkv - Lq + + # --- FP8 quantisation ------------------------------------------------- + q_fp8, q_scale = _quantize_query_for_paged_fp8(query_states) + k_varlen = _pack_bhld_to_varlen(key_states) + v_varlen = _pack_bhld_to_varlen(value_states) + k_fp8, k_scale = _quantize_per_tensor_fp8(k_varlen) + v_fp8, v_scale = _quantize_per_tensor_fp8(v_varlen) + kcache, kv_indices = _pack_paged_cache(k_fp8, B, Lkv, params["block_size"]) + vcache, _ = _pack_paged_cache(v_fp8, B, Lkv, params["block_size"]) + + visible_kv_lens = torch.full((B,), Lkv, dtype=torch.int32, device=query_states.device) + history_kv_lens = torch.full((B,), history_len, dtype=torch.int32, device=query_states.device) + _, cu_q = _uniform_seq_metadata(B, Lq, query_states.device) + + # --- Step 1: generate sparse mask ------------------------------------- + mask = hpc.stem_paged_kv( + q_fp8, + kcache, + vcache, + q_scale, + k_scale, + v_scale, + kv_indices, + cu_q, + visible_kv_lens, + lambda_mag=params["lambda_mag"], + alpha=params["alpha"], + stem_block_size=params["block_size"], + stem_stride=params["stem_stride"], + causal=True, + initial_blocks=params["initial_blocks"], + window_size=params["window_size"], + k_block_num_rate=params["k_block_num_rate"], + k_block_num_bias=params["k_block_num_bias"], + ) + + # --- Step 2: block-sparse attention ----------------------------------- + output = hpc.attention_with_kvcache_blocksparse_prefill_fp8( + q_fp8, + kcache, + vcache, + q_scale, + k_scale, + v_scale, + cu_q, + kv_indices, + history_kv_lens, # history length (not total visible!) + Lq, # max_seqlens_q + mask, + ) + return output.view(B, Lq, H_q, D_v).transpose(1, 2).contiguous() + + +# ===================================================================== # +# Top-level HPC dispatcher # +# ===================================================================== # + + +def stem_forward_hpc( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + prefill_kwargs: dict, +) -> torch.Tensor: + """HPC backend entry point — selects among bf16 / fp8-varlen / fp8-paged. + + Falls back to :func:`stem_forward_torch` when the hardware or + configuration is not compatible with the requested HPC path. 
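+
+    Config keys consulted here (defaults in parentheses): ``hpc_dtype``
+    ("bf16"), ``hpc_fp8_path`` ("varlen") and ``hpc_strict`` (False), plus
+    the runtime parameters extracted by :func:`_stem_runtime_params`.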
+ """ + config = prefill_kwargs["attn_forward_config"] + layer_idx = prefill_kwargs["layer_idx"] + strict_hpc = config.get("hpc_strict", False) + hpc_dtype = config.get("hpc_dtype", "bf16") + fp8_path = config.get("hpc_fp8_path", "varlen") + + # --- Guard: CUDA required --------------------------------------------- + if query_states.device.type != "cuda": + if strict_hpc: + raise RuntimeError("HPC stem backend requires CUDA tensors.") + if layer_idx == 0: + print("[Stem][HPC] CUDA tensors are required; falling back to torch backend.") + return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs) + + params = _stem_runtime_params(config, layer_idx) + + # --- BF16 dense path -------------------------------------------------- + if hpc_dtype == "bf16": + try: + if layer_idx == 0: + print("[Stem][HPC] using bf16 dense prefill path.") + return _run_hpc_bf16_dense(query_states, key_states, value_states) + except Exception as exc: + if strict_hpc: + raise RuntimeError(f"HPC bf16 backend failed: {exc}") from exc + if layer_idx == 0: + print( + f"[Stem][HPC] bf16 dense path failed ({exc}); falling back to torch backend." + ) + return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs) + + # --- FP8 sparse path: validate dimensions ----------------------------- + dim_qk = query_states.shape[-1] + dim_v = value_states.shape[-1] + block_size = params["block_size"] + + if block_size != HPC_FP8_BLOCK_SIZE: + if strict_hpc: + raise RuntimeError( + f"HPC fp8 sparse prefill only supports block_size=128, got {block_size}." + ) + if layer_idx == 0: + print( + f"[Stem][HPC] fp8 sparse prefill only supports block_size=128, " + f"got {block_size}; falling back to torch backend." + ) + return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs) + + if dim_qk != 128 or dim_v != 128: + if strict_hpc: + raise RuntimeError(f"Unsupported HPC fp8 head dims: dim_qk={dim_qk}, dim_v={dim_v}.") + if layer_idx == 0: + print( + f"[Stem][HPC] unsupported fp8 head dims dim_qk={dim_qk}, " + f"dim_v={dim_v}; falling back to torch backend." + ) + return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs) + + # --- Execute FP8 path ------------------------------------------------- + try: + if fp8_path == "paged": + if key_states.shape[2] == query_states.shape[2]: + # First prefill chunk (no KV history). The paged attention + # kernel needs seqlens_kvcache > 0; use varlen instead. + if layer_idx == 0: + print( + "[Stem][HPC] first prefill chunk (q_len == kv_len, no history); " + "using varlen fp8 path." + ) + return _run_hpc_varlen_stem(query_states, key_states, value_states, params) + if layer_idx == 0: + print("[Stem][HPC] using paged fp8 prefill path with stem_paged_kv mask.") + return _run_hpc_paged_stem(query_states, key_states, value_states, params) + + if fp8_path == "varlen": + if layer_idx == 0: + print("[Stem][HPC] using varlen fp8 prefill path with hpc tpd mask.") + return _run_hpc_varlen_stem(query_states, key_states, value_states, params) + + raise ValueError(f"Unsupported hpc_fp8_path={fp8_path!r}; expected 'paged' or 'varlen'.") + + except Exception as exc: + if strict_hpc: + raise RuntimeError(f"HPC stem backend failed: {exc}") from exc + if layer_idx == 0: + print( + f"[Stem][HPC] {fp8_path} fp8 path failed ({exc}); " + "falling back to torch backend." 
+ ) + return _fallback_to_torch(query_states, key_states, value_states, prefill_kwargs) diff --git a/angelslim/compressor/sparsity/stem/backends/torch_impl.py b/angelslim/compressor/sparsity/stem/backends/torch_impl.py new file mode 100644 index 00000000..056fe46d --- /dev/null +++ b/angelslim/compressor/sparsity/stem/backends/torch_impl.py @@ -0,0 +1,468 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Pure-PyTorch implementation of the Stem sparse prefill. + +This module provides two main entry points: + +* :func:`_compute_triton_block_logits` — compute block-level importance + scores using a Triton-accelerated strided group GEMM. +* :func:`stem_forward_torch` — full Stem prefill: score → schedule → + top-k mask → block-sparse (or pseudo-sparse) attention. +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn.functional as F + +try: + from block_sparse_attn import block_sparse_attn_func + + HAS_BLOCK_SPARSE_KERNEL = True +except ImportError: + print( + "⚠️ [Stem] 'block_sparse_attn' not found. " "Falling back to pseudo-sparse implementation." + ) + HAS_BLOCK_SPARSE_KERNEL = False + + +# --------------------------------------------------------------------------- +# Per-layer sparsity schedule +# --------------------------------------------------------------------------- + +# Default per-layer keep-ratio: first 2 layers keep 100%, remaining layers 20%. +_DEFAULT_LAYER_KEEP_RATIOS: list[float] = [1.0, 1.0] + [0.2] * 36 + +# Short-sequence thresholds — below these lengths the sparsity schedule is +# relaxed to avoid losing too much context. +_SHORT_SEQ_THRESHOLD_FULL = 56 # keep-all threshold +_SHORT_SEQ_THRESHOLD_LINEAR = 160 # linear blend threshold +_SHORT_SEQ_LINEAR_RATE = 0.2 +_SHORT_SEQ_LINEAR_BIAS = 30 +_LONG_SEQ_RATE = 0.1 +_LONG_SEQ_BIAS = 30 + + +def generate_exact_k_schedule( + num_blocks: int, + alpha: float, + layer_idx: int, + device: torch.device, + num_heads: int | None = None, +) -> torch.Tensor: + """Generate the per-block top-k budget schedule. + + For each query-block position, the schedule tells how many key-blocks + should be kept (before the alpha decay). + + Parameters + ---------- + num_blocks : int + Number of query blocks (``Qb``). + alpha : float + Decay factor applied beyond the initial keep region. + layer_idx : int + Transformer layer index (controls the base keep ratio). + device : torch.device + Target device for the returned tensor. + num_heads : int | None + If given, return a ``(num_heads, num_blocks)`` tensor with + per-head schedules; otherwise return ``(num_blocks,)``. + + Returns + ------- + torch.Tensor + Budget schedule — ``(num_blocks,)`` or ``(num_heads, num_blocks)``. + """ + keep_ratio = _DEFAULT_LAYER_KEEP_RATIOS[layer_idx] + per_head_ratios = [keep_ratio] * (num_heads or 1) + + def _build_single_schedule(k_val: int, n: int) -> torch.Tensor: + # Relax budget for short sequences. 
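+        # Worked example (defaults above, keep-ratio 0.2, n = 100 blocks):
+        # the initial k_val of 20 falls in the linear regime (56 <= n < 160),
+        # so it is relaxed to int(100 * 0.2 + 30) = 50 kept blocks; for
+        # n >= 160 the long-sequence rule int(n * 0.1) + 30 applies instead.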
+ if k_val != n: + if n < _SHORT_SEQ_THRESHOLD_FULL: + k_val = n + elif n < _SHORT_SEQ_THRESHOLD_LINEAR: + k_val = int(n * _SHORT_SEQ_LINEAR_RATE + _SHORT_SEQ_LINEAR_BIAS) + else: + k_val = int(n * _LONG_SEQ_RATE) + _LONG_SEQ_BIAS + + schedule = torch.full((n,), k_val, dtype=torch.long, device=device) + if n > k_val: + decay_len = n - k_val + k_end = k_val * alpha + ideal_vals = torch.linspace( + float(k_val), + float(k_end), + steps=decay_len, + dtype=torch.float64, + device=device, + ) + schedule[k_val:] = torch.clamp(torch.floor(ideal_vals).long(), min=1, max=k_val) + return schedule + + schedules = [ + _build_single_schedule(int(ratio * num_blocks), num_blocks) for ratio in per_head_ratios + ] + stacked = torch.stack(schedules, dim=0) + return stacked.squeeze(0) if stacked.shape[0] == 1 else stacked + + +# --------------------------------------------------------------------------- +# Block-level importance scoring +# --------------------------------------------------------------------------- + + +def _block_downsample( + x: torch.Tensor, + seq_len: int, + num_blocks: int, + block_size: int, +) -> torch.Tensor: + """Downsample *x* along the sequence dimension via per-block max-pooling. + + Args: + x: Input tensor of shape ``(..., seq_len, D)``. + seq_len: Original sequence length. + num_blocks: Target block count. + block_size: Block size. + + Returns: + Tensor of shape ``(..., num_blocks, D)`` — the maximum value in + each block. + """ + padded_len = num_blocks * block_size + if seq_len % block_size != 0: + pad = torch.zeros( + x.shape[:-2] + (padded_len - seq_len, x.shape[-1]), + dtype=x.dtype, + device=x.device, + ) + x = torch.cat([x, pad], dim=-2) + return x.view(x.shape[:-2] + (num_blocks, block_size, x.shape[-1])).max(dim=-2).values + + +def _compute_triton_block_logits( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + block_size: int, + stride: int, + chunk_size: int, + norm: float, + causal: bool, +) -> torch.Tensor: + """Compute block-level importance logits using the Triton strided GEMM. + + The scoring process: + + 1. Pad Q / K / V to a multiple of ``chunk_size``. + 2. Compute strided Q·K^T via the Triton kernel (chunk by chunk). + 3. Add a value-norm bonus term (log-normalised, ReLU-gated). + 4. Reduce to ``(Qb, Kb)`` block logits by averaging over sub-blocks. + 5. Apply a causal block mask. + + Parameters + ---------- + query_states : torch.Tensor — ``(B, H, L_q, D)`` + key_states : torch.Tensor — ``(B, H, L_kv, D)`` + value_states : torch.Tensor — ``(B, H, L_kv, D)`` + block_size : int — size of each block (typically 128) + stride : int — striding factor for the GEMM + chunk_size : int — chunk width for iterating over Q + norm : float — additional scaling denominator + causal : bool — whether to apply causal masking + + Returns + ------- + torch.Tensor + Block logits of shape ``(B, H, Qb, Kb)``. + """ + from ..ops.stem_kernel import flat_group_gemm_fuse_reshape + + B, H, k_len, head_dim = key_states.shape + _, _, q_len, _ = query_states.shape + assert H == query_states.shape[1], "Q and K must have the same number of heads." 
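+    # Size sketch (assuming block_size=128, stride=8, chunk_size=2048):
+    # q_len = k_len = 8192 needs no padding, the strided GEMM produces a
+    # 1024 x 1024 downsampled score grid (8192 / 8), and step 4 reduces it
+    # to 64 x 64 block logits (8192 / 128).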
+ dtype = query_states.dtype + device = query_states.device + + # --- Step 1: pad to a multiple of chunk_size -------------------------- + target_seq_len = max(k_len, q_len) + target_seq_len = ((target_seq_len + chunk_size - 1) // chunk_size) * chunk_size + k_pad = target_seq_len - k_len + q_pad = target_seq_len - q_len + + if k_pad > 0: + pad_key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0).to("cuda") + value_norm = torch.norm(value_states, p=2, dim=-1, keepdim=True).to(device) + pad_value_states = F.pad(value_norm, (0, 0, 0, k_pad), value=0).to("cuda") + else: + pad_key_states = key_states + value_norm = torch.norm(value_states, p=2, dim=-1, keepdim=True).to(device) + pad_value_states = value_norm + + if q_pad > 0: + pad_query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0).to("cuda") + else: + pad_query_states = query_states + + # --- Derived dimensions ----------------------------------------------- + reshaped_chunk_size = chunk_size // stride + reshaped_block_size = block_size // stride + + pad_q_len = pad_query_states.shape[2] + pad_k_len = pad_key_states.shape[2] + assert pad_q_len == pad_k_len == target_seq_len + + q_down_len = pad_q_len // stride + k_down_len = pad_k_len // stride + pad_Qb = pad_q_len // block_size + pad_Kb = pad_k_len // block_size + chunk_base = (pad_Kb - pad_Qb) * reshaped_block_size + + # --- Step 2: value-norm bonus term ------------------------------------ + v_down = _block_downsample(pad_value_states, pad_k_len, k_down_len, stride).squeeze(-1) + v_log_norm = torch.log(v_down + 1e-6) + + LAMBDA_MAG = 0.2 # magnitude of the value-norm bonus + + valid_len_down = k_len // stride + if valid_len_down > 0: + valid_v = v_log_norm[:, :, :valid_len_down] + v_mean = valid_v.mean(dim=-1, keepdim=True) + v_std = valid_v.std(dim=-1, keepdim=True) + else: + v_mean = v_log_norm.mean(dim=-1, keepdim=True) + v_std = v_log_norm.std(dim=-1, keepdim=True) + + v_log_norm = (v_log_norm - v_mean) / (v_std + 1e-6) + v_log_norm = F.relu(v_log_norm) + v_bonus = LAMBDA_MAG * v_log_norm + + # --- Step 3: chunked strided Q·K^T via Triton ------------------------- + scores = torch.zeros((B, H, q_down_len, k_down_len), dtype=dtype, device=device) + scale = query_states.new_tensor(1.0 / (math.sqrt(head_dim) * stride * norm), dtype=dtype) + q_chunk_num = pad_q_len // chunk_size + + for chunk_idx in range(q_chunk_num): + chunk_q_start = chunk_idx * reshaped_chunk_size + chunk_q_end = chunk_q_start + reshaped_chunk_size + q_slice_start = chunk_q_start * stride + q_slice_end = chunk_q_end * stride + + attn_weights_slice = flat_group_gemm_fuse_reshape( + pad_query_states[:, :, q_slice_start:q_slice_end, :].contiguous(), + pad_key_states.contiguous(), + stride, + chunk_base + chunk_q_start, + chunk_base + chunk_q_end, + is_causal=causal, + ) + + attn_scores = attn_weights_slice * scale + v_bonus.unsqueeze(-2) + scores[:, :, chunk_q_start:chunk_q_end, :] = attn_scores + del attn_weights_slice + + # --- Step 4: reduce to block logits ----------------------------------- + scores = scores.view(B, H, pad_Qb, reshaped_block_size, pad_Kb, reshaped_block_size) + block_logits = scores.mean(dim=3).mean(dim=4) + + # Trim to the actual (unpadded) block counts. 
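+    # (e.g. q_len=5000, block_size=128 -> Qb=40, the last block partial).
+    # NOTE: the causal mask applied in step 5 below compares block indices
+    # without a history offset, i.e. it assumes the query chunk starts at
+    # key position 0 (q_len == k_len) up to block granularity.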
+ Qb = (q_len + block_size - 1) // block_size + Kb = (k_len + block_size - 1) // block_size + block_logits = block_logits[:, :, :Qb, :Kb] + + # --- Step 5: causal block mask ---------------------------------------- + if causal: + qb_idx = torch.arange(Qb, device=device).view(1, 1, -1, 1) + kb_idx = torch.arange(Kb, device=device).view(1, 1, 1, -1) + block_logits = block_logits.masked_fill(kb_idx > qb_idx, float("-inf")) + + return block_logits.to(dtype) + + +# --------------------------------------------------------------------------- +# Stem prefill — torch backend +# --------------------------------------------------------------------------- + + +def stem_forward_torch( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + prefill_kwargs: dict, +) -> torch.Tensor: + """Stem sparse-prefill implementation using pure PyTorch + optional Triton. + + Workflow: + 1. Compute block-level importance scores via :func:`_compute_triton_block_logits`. + 2. Build a per-layer top-k schedule (:func:`generate_exact_k_schedule`). + 3. Select the top-k blocks, add initial-blocks and sliding-window blocks. + 4. Run block-sparse attention (real kernel if available, else pseudo-sparse). + + Parameters + ---------- + query_states : ``(B, H, L_q, D)`` + key_states : ``(B, H, L_kv, D)`` + value_states : ``(B, H, L_kv, D)`` + prefill_kwargs : dict + Contains ``"attn_forward_config"`` and ``"layer_idx"``. + + Returns + ------- + torch.Tensor + Attention output — ``(B, H, L_q, D)``. + """ + config = prefill_kwargs["attn_forward_config"] + layer_idx = prefill_kwargs["layer_idx"] + + block_size: int = config.get("block_size", 128) + config_alpha = config.get("stem_alpha", 1.0) + + B, H, Lq, head_dim = query_states.shape + _, _, Tk, _ = key_states.shape + + scaling = head_dim**-0.5 + Qb = (Lq + block_size - 1) // block_size + Kb = (Tk + block_size - 1) // block_size + + stride: int = config.get("stride", 8) + chunk_size: int = config.get("chunk_size", 2048) + norm: float = config.get("norm", 1.0) + + # --- 1. Block-level scoring ------------------------------------------- + block_logits = _compute_triton_block_logits( + query_states, + key_states, + value_states, + block_size=block_size, + stride=stride, + chunk_size=chunk_size, + norm=norm, + causal=True, + ) + + # --- 2. Per-layer sparsity schedule ----------------------------------- + mask_block = torch.zeros((B, H, Qb, Kb), device=query_states.device, dtype=torch.bool) + + alpha = config_alpha[layer_idx] if isinstance(config_alpha, (list, tuple)) else config_alpha + sched = generate_exact_k_schedule(Qb, alpha, layer_idx, query_states.device, num_heads=H) + growth = max(1.0, alpha) + + if sched.dim() == 1: + head_needed = torch.clamp(torch.ceil(sched[0].to(torch.float32) * growth), max=Kb) + needed_k = int(head_needed.item()) + budget = sched.view(1, 1, -1, 1) + else: + if sched.shape[0] != H: + raise ValueError("Per-head k_start configuration does not match number of heads.") + head_start = sched[:, 0].to(torch.float32) + head_needed = torch.clamp(torch.ceil(head_start * growth), max=Kb) + needed_k = int(head_needed.max().item()) + budget = sched.view(1, H, -1, 1) + + needed_k = max(0, min(needed_k, Kb)) + + # --- 3. 
Top-k block selection ----------------------------------------- + if needed_k > 0: + topk_vals, topk_idx = torch.topk(block_logits, needed_k, dim=-1) + rank = ( + torch.arange(needed_k, device=query_states.device) + .view(1, 1, 1, -1) + .expand(B, H, Qb, needed_k) + ) + keep = (rank < budget) & torch.isfinite(topk_vals) + if keep.any(): + mask_block.scatter_(3, topk_idx, keep) + + # Causal block mask. + q_range = torch.arange(Qb, device=query_states.device).view(1, 1, Qb, 1) + k_range = torch.arange(Kb, device=query_states.device).view(1, 1, 1, Kb) + causal_block_mask = k_range <= q_range + + # Always keep the first few blocks (sink tokens). + initial_blocks: int = int(config.get("initial_blocks", 4)) + if initial_blocks > 0: + mask_block |= (k_range < min(initial_blocks, Kb)) & causal_block_mask + + # Sliding-window blocks. + window_size: int = int(config.get("window_size", 4)) + if window_size > 0: + recent_start = torch.clamp(q_range - (window_size - 1), min=0) + mask_block |= (k_range >= recent_start) & causal_block_mask + + mask_block &= causal_block_mask + + # --- 4. Block-sparse attention ---------------------------------------- + if HAS_BLOCK_SPARSE_KERNEL: + q_kernel = query_states.transpose(1, 2).reshape(B * Lq, H, head_dim) + k_kernel = key_states.transpose(1, 2).reshape(B * Tk, H, head_dim) + v_kernel = value_states.transpose(1, 2).reshape(B * Tk, H, head_dim) + + q_cu = torch.arange( + 0, + (B + 1) * Lq, + step=Lq, + dtype=torch.int32, + device=query_states.device, + ) + k_cu = torch.arange( + 0, + (B + 1) * Tk, + step=Tk, + dtype=torch.int32, + device=query_states.device, + ) + head_mask_type = torch.ones(H, dtype=torch.int32, device=query_states.device) + + torch.cuda.synchronize() + attn_output = block_sparse_attn_func( + q_kernel, + k_kernel, + v_kernel, + q_cu, + k_cu, + head_mask_type, + None, + mask_block.contiguous(), + Lq, + Tk, + p_dropout=0.0, + deterministic=True, + is_causal=True, + ) + torch.cuda.synchronize() + return attn_output.view(B, Lq, H, head_dim).transpose(1, 2) + + # Fallback: pseudo-sparse attention (expand block mask to full mask). + mask_full = ( + mask_block.unsqueeze(-1) + .unsqueeze(-3) + .expand(-1, -1, -1, block_size, -1, block_size) + .reshape(B, H, Qb * block_size, Kb * block_size) + ) + mask_full = mask_full[..., :Lq, :Tk] + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling + if Lq > 1: + causal_bool = torch.ones((Lq, Tk), device=query_states.device, dtype=torch.bool).triu(1) + attn_weights = attn_weights.masked_fill(causal_bool[None, None, :, :], float("-inf")) + + attn_weights = attn_weights.masked_fill(~mask_full, float("-inf")) + probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + return torch.matmul(probs, value_states) diff --git a/angelslim/compressor/sparsity/stem/modules/__init__.py b/angelslim/compressor/sparsity/stem/modules/__init__.py new file mode 100644 index 00000000..5ff7c06b --- /dev/null +++ b/angelslim/compressor/sparsity/stem/modules/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Stem attention forward module.""" + +from .forward import attn_forward + +__all__ = ["attn_forward"] diff --git a/angelslim/compressor/sparsity/stem/modules/forward.py b/angelslim/compressor/sparsity/stem/modules/forward.py new file mode 100644 index 00000000..061ce8e3 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/modules/forward.py @@ -0,0 +1,225 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Stem-patched attention forward pass. + +This module provides the replacement ``forward`` method that is bound to each +attention layer by :func:`stem.patch.stem_patch`. During **prefill** +(``q_len > 1``) it delegates to the Stem sparse backend; during **decode** +(``q_len == 1``) it falls back to the model's original attention implementation +(eager, FlashAttention-2, SDPA, etc.). + +The code mirrors the structure of +``transformers.models.qwen3.modeling_qwen3.Qwen3Attention.forward`` +(Transformers >= 5.2) and should be kept in sync with upstream changes. +""" + +from __future__ import annotations + +from collections.abc import Callable + +import torch +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.processing_utils import Unpack + +from ..backends import stem_forward + +# --------------------------------------------------------------------------- +# Helper functions (identical to upstream Qwen3) +# --------------------------------------------------------------------------- + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate the last dimension by splitting and concatenating halves.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply Rotary Position Embedding (RoPE) to query and key tensors.""" + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """Repeat KV heads to match the number of query heads (GQA support). 
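+
+    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``.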
+ + ``(B, num_kv_heads, L, D)`` -> ``(B, num_attention_heads, L, D)`` + """ + if n_rep == 1: + return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# --------------------------------------------------------------------------- +# Fallback eager attention (used in decode phase, mirrors upstream) +# --------------------------------------------------------------------------- + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """Eager (non-sparse) scaled dot-product attention. + + Used as the **decode** fallback when ``q_len == 1`` and no specialised + attention implementation (e.g. FlashAttention-2) is configured. + Matches the upstream ``eager_attention_forward`` in Transformers >= 5.2. + """ + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def _assert_no_padding_mask_for_stem(attention_mask: torch.Tensor, k_len: int) -> None: + """Verify that the attention mask has no padding (required by Stem prefill). + + Raises + ------ + ValueError + If the mask is not 4-D or if the last query row contains ``-inf`` + entries (indicating padding tokens). + """ + if attention_mask.ndim != 4: + raise ValueError(f"attention_mask must be 4-D, got shape={tuple(attention_mask.shape)}") + last_row = attention_mask[:, :, -1, :k_len] + if not torch.isfinite(last_row).all(): + raise ValueError("Stem prefill requires no padding mask (last query row has -inf).") + + +# --------------------------------------------------------------------------- +# Patched attention forward +# --------------------------------------------------------------------------- + + +def attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Stem-patched attention forward — drop-in replacement for + ``Qwen3Attention.forward`` (Transformers >= 5.2). + + * **Prefill** (``q_len > 1``): delegates to :func:`stem_forward` which + computes block-sparse attention according to the configured backend. + * **Decode** (``q_len == 1``): uses the model's original attention + implementation (eager / FlashAttention-2 / SDPA / flex). 
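+
+    Assumes :func:`stem_patch` has already attached ``attn_forward_config``
+    (a plain dict) and a correct ``layer_idx`` to this module.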
+ """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # --- QKV projection & RoPE (identical to upstream) -------------------- + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # --- KV cache update (Transformers >= 5.2 style) ---------------------- + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + q_len = query_states.shape[2] + k_len = key_states.shape[2] + + # --- Prefill (Stem sparse attention) ---------------------------------- + if q_len > 1: + if attention_mask is not None: + _assert_no_padding_mask_for_stem(attention_mask, k_len) + + prefill_kwargs = { + "layer_idx": self.layer_idx, + "attn_forward_config": self.attn_forward_config, + } + backend = self.attn_forward_config.get("backend", "torch") + + # HPC kernels (both bf16 and fp8) handle GQA internally; + # only the pure-torch path needs explicit KV head repeat. + if backend == "hpc": + stem_key_states = key_states + stem_value_states = value_states + else: + stem_key_states = repeat_kv(key_states, self.num_key_value_groups) + stem_value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = stem_forward( + query_states, stem_key_states, stem_value_states, prefill_kwargs + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_weights = None + + # --- Decode (standard attention, mirrors upstream) --------------------- + else: + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/angelslim/compressor/sparsity/stem/ops/__init__.py b/angelslim/compressor/sparsity/stem/ops/__init__.py new file mode 100644 index 00000000..d7628048 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/ops/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +"""Stem ops — Triton kernels for block-logit scoring.""" diff --git a/angelslim/compressor/sparsity/stem/ops/stem_kernel.py b/angelslim/compressor/sparsity/stem/ops/stem_kernel.py new file mode 100644 index 00000000..c5325e7b --- /dev/null +++ b/angelslim/compressor/sparsity/stem/ops/stem_kernel.py @@ -0,0 +1,191 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Triton kernel: strided group GEMM with fused reshape for block-logit scoring. + +This kernel computes a strided dot product between query and key blocks, +producing a downsampled attention-score matrix used to estimate per-block +importance in the Stem scoring pipeline. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.jit +def flat_group_gemm_fuse_reshape_kernel( + Q, + K, + Out, + stride_qz, + stride_qh, + stride_qn, + stride_kz, + stride_kh, + stride_kn, + stride_oz, + stride_oh, + stride_on, + chunk_start, + chunk_end, + H: tl.constexpr, + STRIDE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + is_causal: tl.constexpr, +): + """Triton kernel: one tile of the strided Q·K^T group GEMM. + + Args: + Q: Query tensor pointer, shape ``(B, H, L_q, D)``. + K: Key tensor pointer, shape ``(B, H, L_kv, D)``. + Out: Output tensor pointer, shape ``(B, H, L_q // stride, L_kv // stride)``. + stride_qz: Stride of Q along the batch dimension. + stride_qh: Stride of Q along the head dimension. + stride_qn: Stride of Q along the sequence dimension. + stride_kz: Stride of K along the batch dimension. + stride_kh: Stride of K along the head dimension. + stride_kn: Stride of K along the sequence dimension. + stride_oz: Stride of Out along the batch dimension. + stride_oh: Stride of Out along the head dimension. + stride_on: Stride of Out along the row (downsampled query) dimension. + chunk_start: Logical chunk start boundary (downsampled coords) for causal masking. + chunk_end: Logical chunk end boundary (downsampled coords). + H: Number of attention heads (compile-time constant). + STRIDE: Striding factor for downsampling (compile-time constant). + HEAD_DIM: Per-head hidden dimension (compile-time constant). + BLOCK_M: Tile size along the query (M) axis (compile-time constant). + BLOCK_N: Tile size along the key (N) axis (compile-time constant). + is_causal: Whether to apply causal masking (compile-time constant). + """ + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + # Early exit for causal tiles that are entirely above the diagonal. 
+ if is_causal and chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = ( + Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn + ) + K_ptrs = ( + K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn + ) + + Q_ptrs = ( + Q_ptrs + + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + + tl.arange(0, HEAD_DIM)[None, :] + + stride_qn * (STRIDE - 1) + ) + K_ptrs = ( + K_ptrs + + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + + tl.arange(0, HEAD_DIM)[:, None] + ) + + # Accumulate strided dot products. + acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for _iter in range(STRIDE): + q = tl.load(Q_ptrs - _iter * stride_qn) + k = tl.load(K_ptrs + _iter * stride_kn) + acc += tl.dot(q, k) + + O_ptrs = ( + Out + + batch_id * stride_oz + + head_id * stride_oh + + block_m * BLOCK_M * stride_on + + block_n * BLOCK_N + ) + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + tl.store(O_ptrs, acc.to(Out.type.element_ty)) + + +def flat_group_gemm_fuse_reshape( + query_states: torch.Tensor, + key_states: torch.Tensor, + stride: int, + chunk_start: int, + chunk_end: int, + is_causal: bool = True, +) -> torch.Tensor: + """Launch the strided group-GEMM Triton kernel. + + Args: + query_states: Query tensor of shape ``(B, H, L_q, D)``. + key_states: Key tensor of shape ``(B, H, L_kv, D)``. + stride: Striding factor. + chunk_start: Logical chunk start boundary (in downsampled coordinates) + used for the causal early-exit check inside the kernel. + chunk_end: Logical chunk end boundary (in downsampled coordinates). + is_causal: Whether to apply causal masking. + + Returns: + Downsampled score matrix of shape ``(B, H, L_q // stride, L_kv // stride)``. + """ + B, H, Lq, D = query_states.shape + Lkv = key_states.shape[2] + + assert key_states.shape[0] == B + assert key_states.shape[1] == H + assert key_states.shape[3] == D + + BLOCK_M = 128 + BLOCK_N = 128 + assert ( + Lq % (stride * BLOCK_M) == 0 + ), f"q_len ({Lq}) must be divisible by stride*BLOCK_M ({stride * BLOCK_M})" + assert ( + Lkv % (stride * BLOCK_N) == 0 + ), f"kv_len ({Lkv}) must be divisible by stride*BLOCK_N ({stride * BLOCK_N})" + + output = torch.empty( + (B, H, Lq // stride, Lkv // stride), + dtype=query_states.dtype, + device=query_states.device, + ) + + grid = (Lq // stride // BLOCK_M, Lkv // stride // BLOCK_N, B * H) + flat_group_gemm_fuse_reshape_kernel[grid]( + query_states, + key_states, + output, + query_states.stride(0), + query_states.stride(1), + query_states.stride(2), + key_states.stride(0), + key_states.stride(1), + key_states.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + chunk_start, + chunk_end, + H, + stride, + D, + BLOCK_M, + BLOCK_N, + is_causal, + ) + + return output diff --git a/angelslim/compressor/sparsity/stem/patch.py b/angelslim/compressor/sparsity/stem/patch.py new file mode 100644 index 00000000..1ae1fe5b --- /dev/null +++ b/angelslim/compressor/sparsity/stem/patch.py @@ -0,0 +1,62 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Model-patching logic: replace the standard attention forward with Stem's.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .modules.forward import attn_forward + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from .stem_configuration import StemConfig + + +def stem_patch(model: "PreTrainedModel", config: "StemConfig") -> "PreTrainedModel": + """Replace each attention layer's ``forward`` with :func:`attn_forward`. + + Currently only **Qwen3** models are supported. + + Args: + model: A HuggingFace causal-LM model (e.g. ``Qwen3ForCausalLM``). + config: Stem runtime configuration. + + Returns: + The same *model* object, mutated in-place with Stem attention. + + Raises: + ValueError: If the model's ``model_type`` does not contain ``"qwen3"``. + """ + model_type = model.config.model_type.lower() + if "qwen3" not in model_type: + raise ValueError(f"Only Qwen3 is supported, got model_type={model_type!r}") + + AttentionClass = model.model.layers[0].self_attn.__class__ + + # Ensure every layer carries its own index (used by schedule functions). + for i, layer in enumerate(model.model.layers): + layer.self_attn.layer_idx = i + + def _apply_stem_forward(module: object) -> None: + """Bind the Stem ``attn_forward`` and config to each attention module.""" + if isinstance(module, AttentionClass): + module.attn_forward_config = config.attn_kwargs + module.forward = attn_forward.__get__(module, AttentionClass) + + model.apply(_apply_stem_forward) + return model diff --git a/angelslim/compressor/sparsity/stem/stem.py b/angelslim/compressor/sparsity/stem/stem.py new file mode 100644 index 00000000..2add3b04 --- /dev/null +++ b/angelslim/compressor/sparsity/stem/stem.py @@ -0,0 +1,49 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""High-level entry point for applying the Stem patch to a HuggingFace model.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .patch import stem_patch +from .stem_configuration import StemConfig + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +class StemInference: + """Callable object that patches a model to use Stem sparse attention. + + Usage:: + + stem = StemInference(attn_kwargs={"backend": "hpc", "hpc_dtype": "fp8"}) + model = stem(model) + + Args: + attn_kwargs: Forwarded to ``StemConfig``. See its docstring for valid keys. 
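+
+    The default (``attn_kwargs=None``) selects the pure-torch backend::
+
+        model = StemInference()(model)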
+    """
+
+    def __init__(self, attn_kwargs: dict | None = None) -> None:
+        self.config = StemConfig(attn_kwargs=attn_kwargs)
+
+    def __call__(self, model: "PreTrainedModel") -> "PreTrainedModel":
+        """Apply the Stem attention patch and return the modified model."""
+        return stem_patch(model, self.config)
+
+    def __repr__(self) -> str:
+        return f"StemInference(config={self.config!r})"
diff --git a/angelslim/compressor/sparsity/stem/stem_configuration.py b/angelslim/compressor/sparsity/stem/stem_configuration.py
new file mode 100644
index 00000000..2bf80013
--- /dev/null
+++ b/angelslim/compressor/sparsity/stem/stem_configuration.py
@@ -0,0 +1,61 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Configuration class for the Stem sparse attention module."""
+
+from __future__ import annotations
+
+# Supported option values.
+SUPPORTED_BACKENDS = {"torch", "hpc"}
+SUPPORTED_HPC_DTYPES = {"bf16", "fp8"}
+
+
+class StemConfig:
+    """Configuration container for Stem attention; treated as read-only after construction.
+
+    Args:
+        attn_kwargs: Dictionary of keyword arguments forwarded to the attention
+            backend. Recognised keys include:
+
+            - ``backend``: ``"torch"`` (default) or ``"hpc"``.
+            - ``hpc_dtype``: ``"bf16"`` (default) or ``"fp8"`` (only used when
+              ``backend="hpc"``).
+            - ``stem_alpha``, ``block_size``, ``stride``, ``chunk_size``, etc.
+
+    Raises:
+        ValueError: If *backend* or *hpc_dtype* is not in the supported set.
+    """
+
+    def __init__(self, attn_kwargs: dict | None = None) -> None:
+        self.attn_kwargs: dict = dict(attn_kwargs or {})
+        self.attn_kwargs.setdefault("backend", "torch")
+        self.attn_kwargs.setdefault("hpc_dtype", "bf16")
+
+        backend = self.attn_kwargs["backend"]
+        hpc_dtype = self.attn_kwargs["hpc_dtype"]
+
+        if backend not in SUPPORTED_BACKENDS:
+            raise ValueError(
+                f"Unsupported stem backend: {backend!r}. "
+                f"Choose from {sorted(SUPPORTED_BACKENDS)}."
+            )
+        if hpc_dtype not in SUPPORTED_HPC_DTYPES:
+            raise ValueError(
+                f"Unsupported hpc_dtype: {hpc_dtype!r}. "
+                f"Choose from {sorted(SUPPORTED_HPC_DTYPES)}."
+            )
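+
+    # Illustrative examples of the validation above:
+    #   StemConfig({"backend": "hpc", "hpc_dtype": "fp8"})  # valid
+    #   StemConfig({"backend": "cuda"})                     # raises ValueError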
+
+    def __repr__(self) -> str:
+        return f"StemConfig(attn_kwargs={self.attn_kwargs!r})"
diff --git a/docs/source/features/sparse_attention/index.md b/docs/source/features/sparse_attention/index.md
new file mode 100644
index 00000000..1352d628
--- /dev/null
+++ b/docs/source/features/sparse_attention/index.md
@@ -0,0 +1,10 @@
+# Sparse Attention
+
+Sparse Attention is AngelSlim's prefill acceleration module for long-context LLM inference. Its core goal is to dynamically skip unimportant attention blocks during inference, significantly reducing the computation and latency of the prefill stage.
+
+:::{toctree}
+:caption: Contents
+:maxdepth: 1
+
+stem
+:::
diff --git a/docs/source/features/sparse_attention/stem.md b/docs/source/features/sparse_attention/stem.md
new file mode 100644
index 00000000..edecc726
--- /dev/null
+++ b/docs/source/features/sparse_attention/stem.md
@@ -0,0 +1,157 @@
+# Stem: Rethinking Causal Information Flow in Sparse Attention
+
+**Stem** is AngelSlim's sparse attention algorithm for accelerating the **prefill** stage of long-context LLM inference. It estimates attention importance at block granularity and dynamically selects the top-k critical blocks for block-sparse attention, cutting prefill latency substantially while preserving generation quality.
+
+## 1. Motivation
+
+In long-context inference (e.g. 32K–128K tokens), computing full attention during prefill is the dominant bottleneck:
+
+- Compute grows **quadratically** with sequence length, straining both memory and latency
+- In practice, most attention blocks contribute very little to the final output, so much of the computation is redundant
+
+Stem's core idea: **first estimate each attention block's importance with low-cost block-level scoring, then run exact attention only on the important blocks**.
+
+## 2. How It Works
+
+Stem's prefill proceeds in three steps:
+
+### 2.1 Block-Level Scoring
+
+A **Triton-accelerated strided group GEMM** computes a downsampled Q·K^T score matrix, which is combined with a value-norm bonus term to estimate the importance of each key block for each query block:
+
+$$\text{score}(Q_i, K_j) = \frac{Q_i \cdot K_j^T}{\sqrt{d} \cdot s \cdot n} + \lambda \cdot \text{ReLU}(\bar{v}_j)$$
+
+where $s$ is the stride factor, $n$ is a normalization coefficient, and $\bar{v}_j$ is the standardized log value norm.
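+
+For intuition, here is a minimal PyTorch sketch of the downsampling performed by the strided group GEMM kernel (`flat_group_gemm_fuse_reshape`). It omits the $1/(\sqrt{d} \cdot s \cdot n)$ scaling, the value-norm bonus, and the causal early-exit; the function name is illustrative, not part of the actual API:
+
+```python
+import torch
+
+
+def strided_block_scores(q: torch.Tensor, k: torch.Tensor, stride: int) -> torch.Tensor:
+    """Reference for the strided Q.K^T downsampling (pre-normalization).
+
+    q: (B, H, L_q, D), k: (B, H, L_kv, D); both lengths divisible by `stride`.
+    Returns a score matrix of shape (B, H, L_q // stride, L_kv // stride).
+    """
+    B, H, Lq, D = q.shape
+    Lkv = k.shape[2]
+    # Group every `stride` consecutive tokens. The kernel pairs key offset i
+    # with query offset stride-1-i inside each group, hence the flip.
+    qg = q.view(B, H, Lq // stride, stride, D).flip(3).float()  # fp32, like the kernel's accumulator
+    kg = k.view(B, H, Lkv // stride, stride, D).float()
+    # Sum the per-offset dot products within each (query block, key block) pair.
+    return torch.einsum("bhmsd,bhnsd->bhmn", qg, kg)
+```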
+### 2.2 Top-k Schedule
+
+Each layer derives a per-block top-k budget from a preset keep-ratio and an alpha decay factor:
+
+- **First N layers (warmup)**: alpha=1.0, keeping more blocks to preserve complete low-level feature extraction
+- **Later layers**: alpha=0.7, sparsifying more aggressively for speed
+- **Initial blocks** (sink tokens) and **sliding window** blocks are additionally guaranteed to always be kept
+
+### 2.3 Block-Sparse Attention
+
+Sparse attention is executed according to the top-k mask:
+
+- If the `block-sparse-attn` library is installed, a true block-sparse kernel is used
+- Otherwise it automatically falls back to a pseudo-sparse implementation (expanding the mask and running dense attention)
+- The **HPC backend** supports bf16 dense prefill and fp8 block-sparse prefill (varlen and paged paths)
+
+**The decode stage is unaffected** and keeps the model's original attention implementation (FlashAttention-2 / eager / SDPA).
+
+## 3. Supported Scope
+
+| Dimension | Support |
+|------|---------|
+| **Backend** | `torch` (pure PyTorch + Triton), `hpc` (HPC C++ extension) |
+| **HPC precision** | bf16 (dense prefill), fp8 (block-sparse prefill, varlen / paged) |
+| **Sequence length** | No hard limit; 4K+ tokens recommended for visible speedups |
+
+## 4. Quick Start
+
+Make sure AngelSlim is installed (`pip install -e .` or `uv sync`), then run from the project root:
+
+### Dense baseline (no Stem patch)
+
+```bash
+python tools/run_stem.py \
+    --mode dense \
+    --model-path /path/to/Qwen3-8B \
+    --prompt-file prompt.txt \
+    --max-new-tokens 160
+```
+
+### Stem + HPC bf16
+
+```bash
+python tools/run_stem.py \
+    --mode stem \
+    --stem-backend hpc \
+    --hpc-dtype bf16 \
+    --model-path /path/to/Qwen3-8B \
+    --prompt-file prompt.txt \
+    --max-new-tokens 160
+```
+
+### Stem + HPC fp8
+
+```bash
+python tools/run_stem.py \
+    --mode stem \
+    --stem-backend hpc \
+    --hpc-dtype fp8 \
+    --model-path /path/to/Qwen3-8B \
+    --prompt-file prompt.txt \
+    --max-new-tokens 160
+```
+
+### Using a custom prompt
+
+```bash
+python tools/run_stem.py \
+    --mode stem \
+    --stem-backend hpc \
+    --hpc-dtype bf16 \
+    --model-path /path/to/Qwen3-8B \
+    --prompt-file my_long_document.txt \
+    --max-new-tokens 256
+```
+
+You can also launch through the wrapper script:
+
+```bash
+bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt stem
+```
+
+## 5. Parameters
+
+| Parameter | Default | Description |
+|------|--------|------|
+| `backend` | `"torch"` | Backend selection: `"torch"` or `"hpc"` |
+| `hpc_dtype` | `"bf16"` | HPC backend precision: `"bf16"` or `"fp8"` |
+| `hpc_fp8_path` | `"paged"` | FP8 execution path: `"varlen"` or `"paged"` |
+| `stem_alpha` | `1.0` | Per-layer alpha decay factor; pass a list for per-layer control |
+| `block_size` | `128` | Attention block size |
+| `stride` | `8` | Downsampling stride for the scoring stage |
+| `chunk_size` | `2048` | Chunk width for the scoring stage |
+| `norm` | `1.0` | Extra normalization coefficient for the scoring stage |
+| `initial_blocks` | `4` | Number of leading blocks always kept (sink tokens) |
+| `window_size` | `4` | Number of trailing blocks kept as the sliding window |
+
+## 6. Code Layout
+
+```
+angelslim/compressor/sparsity/
+├── __init__.py               # re-export StemInference
+└── stem/
+    ├── __init__.py           # package entry point
+    ├── stem.py               # StemInference class (main entry point)
+    ├── patch.py              # model patching logic
+    ├── stem_configuration.py # StemConfig configuration
+    ├── backends/
+    │   ├── dispatcher.py     # torch / hpc routing
+    │   ├── torch_impl.py     # PyTorch + Triton implementation
+    │   └── hpc_impl.py       # HPC C++ extension implementation
+    ├── modules/
+    │   └── forward.py        # patched attention forward
+    └── ops/
+        └── stem_kernel.py    # Triton kernel
+
+tools/run_stem.py             # inference entry point
+scripts/sparsity/run_stem.sh  # launch script
+```
+
+## 7. Python API
+
+```python
+from angelslim.compressor.sparsity import StemInference
+
+stem = StemInference(attn_kwargs={
+    "backend": "hpc",
+    "hpc_dtype": "fp8",
+    "stem_alpha": [1.0] * 5 + [0.7] * 31,  # 36-layer Qwen3-8B
+})
+model = stem(model)  # returns the same model object, patched in place
+```
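+
+As a quick sanity check that the patch took effect, you can inspect any attention layer: `stem_patch` binds the runtime options onto each module as `attn_forward_config` and assigns a per-layer index. A minimal sketch (the printed value depends on the `attn_kwargs` you passed):
+
+```python
+# Every patched attention module carries the Stem options.
+attn = model.model.layers[0].self_attn
+print(attn.attn_forward_config["backend"])  # e.g. "hpc"
+print(attn.layer_idx)                       # layer index set by stem_patch
+```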
diff --git a/docs/source/index.md b/docs/source/index.md
index be52d176..50fd915a 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -63,9 +63,10 @@ AngelSlim是腾讯自研的,致力于打造更易用、更全面和更高效
       - Tequila
       - Eagle3
       - SpecExit
-  - - **稀疏注意力**
+  - - **稀疏注意力**
-    - Minference(建设中)
+    - Stem
+    - Minference(建设中)
   * - **图/视频生文(VLM)**
   - - Hunyuan-VL
     - HunyuanOCR
@@ -129,8 +130,8 @@
 getting_started/quickstrat
 features/quantization/index
 features/speculative_decoding/index
+features/sparse_attention/index
 features/diffusion/index
-features/token_compressor/index.md
 :::
 
 % Additional capabilities
diff --git a/scripts/sparsity/run_stem.sh b/scripts/sparsity/run_stem.sh
new file mode 100755
index 00000000..904aa907
--- /dev/null
+++ b/scripts/sparsity/run_stem.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Stem sparse-attention launch script.
+#
+# Usage:
+#   bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt stem
+#   bash scripts/sparsity/run_stem.sh /path/to/Qwen3-8B prompt.txt dense
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)"
+
+MODEL_PATH="${1:-Qwen/Qwen3-8B}"
+PROMPT_FILE="${2:?Usage: $0 <model_path> <prompt_file> [stem|dense] [extra args...]}"
+MODE="${3:-stem}"
+# Drop the consumed positionals so "$@" only holds extra args
+# (a bare `shift 3` would leave $1..$2 in place when fewer than 3 args are given).
+if [ "$#" -ge 3 ]; then shift 3; else shift "$#"; fi
+
+cd "$ROOT_DIR"
+# NOTE: --model-name is fixed here; an extra --model-name in "$@" overrides it
+# (argparse keeps the last occurrence).
+exec python -u tools/run_stem.py \
+    --model-path "$MODEL_PATH" \
+    --model-name "Qwen/Qwen3-8B" \
+    --prompt-file "$PROMPT_FILE" \
+    --mode "$MODE" \
+    "$@"
diff --git a/tools/run_stem.py b/tools/run_stem.py
new file mode 100644
index 00000000..aecb4fc7
--- /dev/null
+++ b/tools/run_stem.py
@@ -0,0 +1,232 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Stem sparse attention inference script.
+
+Usage:
+    python tools/run_stem.py --mode stem --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+    python tools/run_stem.py --mode dense --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+    python tools/run_stem.py --mode stem --stem-backend hpc \
+        --hpc-dtype fp8 --model-path /path/to/Qwen3-8B --prompt-file prompt.txt
+"""
+
+import argparse
+import sys
+import time
+import traceback
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+from angelslim.compressor.sparsity.stem import StemInference
+
+DEFAULT_MODEL_PATH = "Qwen/Qwen3-8B"
+DEFAULT_MODEL_NAME = "Qwen/Qwen3-8B"
+DEFAULT_MAX_MODEL_LEN = 131072
+
+
+def build_stem_alpha_schedule(
+    num_layers: int,
+    warmup_layers: int = 5,
+    warmup_alpha: float = 1.0,
+    steady_alpha: float = 0.7,
+) -> list[float]:
+    if num_layers <= 0:
+        raise ValueError(f"num_layers must be positive, got {num_layers}")
+    if warmup_layers < 0 or warmup_layers > num_layers:
+        raise ValueError(f"warmup_layers must be in [0, {num_layers}], got {warmup_layers}")
+    return [warmup_alpha] * warmup_layers + [steady_alpha] * (num_layers - warmup_layers)
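+
+# Illustrative example (not executed): for a 36-layer model such as Qwen3-8B,
+# the defaults above yield
+#   build_stem_alpha_schedule(36) == [1.0] * 5 + [0.7] * 31
+# i.e. the full per-layer budget for the 5 warmup layers, then a 0.7 decay.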
+
+
+def build_prompt(tokenizer, raw_prompt: str) -> str:
+    messages = [{"role": "user", "content": raw_prompt}]
+    try:
+        return tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=False,
+        )
+    except TypeError:
+        # Older transformers/tokenizers do not accept enable_thinking.
+        return tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Stem sparse attention inference: Dense vs Stem on Qwen3.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="stem",
+        choices=["dense", "stem"],
+        help="Attention mode: 'dense' (no patch) or 'stem' (Stem sparse).",
+    )
+    parser.add_argument(
+        "--stem-backend",
+        type=str,
+        default="torch",
+        choices=["torch", "hpc"],
+        help="Stem backend: 'torch' (PyTorch + Triton) or 'hpc' (HPC C++ extension).",
+    )
+    parser.add_argument(
+        "--hpc-dtype",
+        type=str,
+        default="bf16",
+        choices=["bf16", "fp8"],
+        help="HPC backend precision.",
+    )
+    parser.add_argument(
+        "--hpc-fp8-path",
+        type=str,
+        default="paged",
+        choices=["paged", "varlen"],
+        help="HPC FP8 execution path.",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default=DEFAULT_MODEL_PATH,
+        help="Path to the model directory or HuggingFace model ID.",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=DEFAULT_MODEL_NAME,
+        help="Model name (used to determine rope_scaling config).",
+    )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        default=DEFAULT_MAX_MODEL_LEN,
+        help="Maximum position embeddings length.",
+    )
+    parser.add_argument(
+        "--prompt-file",
+        type=str,
+        required=True,
+        help="Path to a text file containing the input prompt.",
+    )
+    parser.add_argument(
+        "--max-new-tokens", type=int, default=256, help="Maximum number of new tokens to generate."
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    print(f"[Env] Python executable: {sys.executable}")
+    print(f"Loading tokenizer from {args.model_path} ...")
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    print("Loading model config ...")
+    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
+    if "qwen" in args.model_name.lower():
+        rope_theta = (
+            config.rope_scaling.get("rope_theta", 1000000) if config.rope_scaling else 1000000
+        )
+        config.rope_scaling = {
+            "rope_type": "yarn",
+            "rope_theta": rope_theta,
+            "factor": 4.0,
+            "original_max_position_embeddings": 32768,
+        }
+        config.max_position_embeddings = args.max_model_len
+
+    print("Loading model weights ...")
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_path,
+        config=config,
+        torch_dtype=torch.bfloat16,
+        device_map="cuda",
+        trust_remote_code=True,
+        attn_implementation="flash_attention_2",
+    )
+
+    if args.mode == "stem":
+        num_layers = int(model.config.num_hidden_layers)
+        stem_alpha = build_stem_alpha_schedule(
+            num_layers=num_layers,
+            warmup_layers=5,
+            warmup_alpha=1.0,
+            steady_alpha=0.7,
+        )
+        stem = StemInference(
+            attn_kwargs={
+                "backend": args.stem_backend,
+                "hpc_dtype": args.hpc_dtype,
+                "hpc_fp8_path": args.hpc_fp8_path,
+                "stem_alpha": stem_alpha,
+            },
+        )
+        model = stem(model)
+        msg = (
+            f"[Stem] Patch applied. backend={args.stem_backend}, "
+            f"hpc_dtype={args.hpc_dtype}, num_layers={num_layers}"
+        )
+        print(msg)
+    else:
+        print("[Dense] No Stem patch applied. Using standard flash_attention_2.")
+
+    with open(args.prompt_file, "r", encoding="utf-8") as f:
+        raw_prompt = f.read()
+    print(f"Loaded prompt from: {args.prompt_file}")
+
+    prompt = build_prompt(tokenizer, raw_prompt)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    input_len = int(inputs.input_ids.shape[1])
+    print(f"[Input Stats] token_length={input_len}")
+
+    print("Generating ...")
+    start = time.time()
+    try:
+        with torch.no_grad():
+            out = model.generate(
+                **inputs,
+                max_new_tokens=args.max_new_tokens,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id,
+            )
+        torch.cuda.synchronize()
+        end = time.time()
+
+        gen_ids = out[0][input_len:]
+        gen_text = tokenizer.decode(
+            gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+        print("=" * 80)
+        print(f"Mode: {args.mode}")
+        print(f"Time Taken: {end - start:.3f} s")
+        print(f"Generated Tokens: {len(gen_ids)}")
+        print("Model Output:")
+        print(gen_text.strip())
+        print("=" * 80)
+    except Exception as e:
+        print(f"[Error]: {e}")
+        traceback.print_exc()
+        raise
+
+
+if __name__ == "__main__":
+    main()