diff --git a/examples/llm_sparsity/attention_sparsity/hf_sa.py b/examples/llm_sparsity/attention_sparsity/hf_sa.py index 74c5e9a54..9a8c694fc 100644 --- a/examples/llm_sparsity/attention_sparsity/hf_sa.py +++ b/examples/llm_sparsity/attention_sparsity/hf_sa.py @@ -31,6 +31,7 @@ from modelopt.torch.sparsity.attention_sparsity.config import ( SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT, + SPARSE24_TRITON, ) from modelopt.torch.utils.memory_monitor import launch_memory_monitor @@ -43,6 +44,7 @@ SPARSE_ATTN_CFG_CHOICES = { "skip_softmax": SKIP_SOFTMAX_DEFAULT, "skip_softmax_calib": SKIP_SOFTMAX_CALIB, + "sparse24_triton": SPARSE24_TRITON, } @@ -144,12 +146,14 @@ def main(args): print(f"Loading model: {args.pyt_ckpt_path}") - # Load model and tokenizer - # Note: attn_implementation="eager" is required for calibration to work properly - # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection) + # Select attn_implementation based on sparse method: + # - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa) + # - sparse24_triton requires "modelopt_triton" (fused Triton kernel) + # No need to specify attn_implementation here — mtsa.sparsify() handles it + # automatically based on the sparse config (sets "modelopt_triton" for triton + # backend, keeps "eager" for pytorch backend). model = AutoModelForCausalLM.from_pretrained( args.pyt_ckpt_path, - attn_implementation="eager", torch_dtype=torch.bfloat16, ) tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) @@ -246,8 +250,8 @@ def main(args): "--backend", type=str, default="pytorch", - choices=["pytorch"], - help="Backend for sparse attention (default: pytorch). More backends coming soon.", + choices=["pytorch", "triton"], + help="Backend for sparse attention (default: pytorch). 
Use 'triton' with sparse24_triton.", ) # Sequence length arguments diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index d2d3b1078..e178594ca 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): title="Backend implementation.", description=( "Backend to use for sparse attention computation. " - "Only 'pytorch' is supported, which uses softmax patching with F.softmax. " - "Requires model to be loaded with attn_implementation='eager'." + "'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). " + "'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')." ), ) @@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): description=( "Whether the model uses causal (autoregressive) attention. " "If True, sparsity statistics are calculated over the lower triangle only. " + "Set to False for cross-attention models. " "Defaults to True for decoder-only models like GPT, LLaMA, etc." ), ) + skip_diagonal_blocks: bool = ModeloptField( + default=True, + title="Skip diagonal blocks.", + description=( + "When True, keep diagonal tiles dense for 2:4 sparse attention. " + "Only used by sparse24_triton method. Defaults to True." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -104,11 +114,12 @@ def validate_method(cls, v): @field_validator("backend") @classmethod def validate_backend(cls, v): - """Validate backend is pytorch.""" - if v != "pytorch": + """Validate backend is pytorch or triton.""" + if v not in ("pytorch", "triton"): raise ValueError( - f"Invalid backend: {v}. Only 'pytorch' backend is supported. " - f"Model must be loaded with attn_implementation='eager'." + f"Invalid backend: {v}. 
Supported backends: 'pytorch' (requires " + f"attn_implementation='eager'), 'triton' (requires " + f"attn_implementation='modelopt_triton')." ) return v @@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig): }, } +# 2:4 structured sparsity via Triton prefill kernel (prefill-only) +SPARSE24_TRITON = { + "sparse_cfg": { + "*attn*": { + "method": "sparse24_triton", + "backend": "triton", + "skip_diagonal_blocks": True, + "enable": True, + }, + "default": {"enable": False}, + }, +} + __all__ = [ "SKIP_SOFTMAX_CALIB", "SKIP_SOFTMAX_DEFAULT", + "SPARSE24_TRITON", "CalibrationConfig", "FlashSkipSoftmaxConfig", "SparseAttentionAttributeConfig", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 2155a13d0..26fb4e08a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -32,6 +32,37 @@ from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules +def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None: + """Register the Triton attention backend and set attn_implementation if needed. + + When the config uses ``backend="triton"``, this function: + 1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``. + 2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the + model dispatches to the Triton kernel at forward time. + + This is called automatically during ``mtsa.sparsify()`` so users never need + to manually call ``register_triton_attention()`` or set ``attn_implementation``. 
+ """ + sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {} + needs_triton = any( + isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values() + ) + if not needs_triton: + return + + from .kernels import register_triton_attention + + if register_triton_attention is not None: + register_triton_attention() + + # Set attn_implementation on the model so HF dispatches to the Triton kernel. + # HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time, + # so this works even after the model is already loaded. + model_config = getattr(model, "config", None) + if model_config is not None: + model_config._attn_implementation = "modelopt_triton" + + def is_attn_sparsified(model: nn.Module) -> bool: """Check if a model has sparse attention applied. @@ -61,6 +92,9 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + # Register Triton attention backend and set attn_implementation if needed + _register_triton_backend_if_needed(model, config) + # Apply custom model plugins register_custom_model_plugins_on_the_fly(model) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py new file mode 100644 index 000000000..bf134dd2e --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton attention kernels for sparse attention optimization.""" + +import torch + +from modelopt.torch.utils import import_plugin + +IS_AVAILABLE = False +context_attention_fwd = None +register_triton_attention = None +set_sparse24 = None +unified_attention = None + +if torch.cuda.is_available(): + with import_plugin( + "triton", + msg_if_missing=( + "Your device is potentially capable of using the triton attention " + "kernel. Try to install triton with `pip install triton`." + ), + ): + from .triton_unified_attention import context_attention_fwd as _context_attention_fwd + from .triton_unified_attention import unified_attention as _unified_attention + + context_attention_fwd = _context_attention_fwd + unified_attention = _unified_attention + IS_AVAILABLE = True + with import_plugin("transformers"): + from .hf_triton_attention import register_triton_attention as _register_triton_attention + from .hf_triton_attention import set_sparse24 as _set_sparse24 + + register_triton_attention = _register_triton_attention + set_sparse24 = _set_sparse24 + _register_triton_attention() + +__all__ = [ + "IS_AVAILABLE", + "context_attention_fwd", + "register_triton_attention", + "set_sparse24", + "unified_attention", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py new file mode 100644 index 000000000..fb574db05 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/hf_triton_attention.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hugging Face attention backend for the Triton unified attention kernel. + +Registers the Triton kernel as attn_implementation="modelopt_triton" so HF models +use it natively without patching forward. Both prefill and decode use the unified +Triton kernel. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import triton + +from modelopt.torch.sparsity.attention_sparsity.kernels.triton_unified_attention import ( + context_attention_fwd, + unified_attention, +) + +# Matches vLLM heuristics (vllm/v1/attention/backends/triton_attn.py) +_MIN_LAUNCH_GRID_SIZE_2D = 128 +_NUM_PAR_SOFTMAX_SEGMENTS = 16 + + +def _attention_mask_supported_for_triton(attention_mask: torch.Tensor) -> bool: + """Return True if mask shape is supported for packing (2D [batch, seq_len]).""" + return attention_mask.dim() == 2 and attention_mask.shape[0] > 0 and attention_mask.shape[1] > 0 + + +def _packed_token_indices( + seq_lens: torch.Tensor, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute vectorized (batch_idx, token_idx) for packing/unpacking variable-length sequences. + + Assumes valid tokens occupy positions ``0..seq_lens[b]-1`` in each batch + element (right-padded layout). 
This matches the HF convention where padding + tokens are appended after the valid content during prefill. + + Args: + seq_lens: [batch] number of valid tokens per sequence. + device: Target device. + + Returns: + (batch_indices, token_indices) each of shape [total_valid_tokens]. + """ + total = int(seq_lens.sum().item()) + cumsum = torch.zeros(seq_lens.shape[0] + 1, device=device, dtype=torch.long) + cumsum[1:] = torch.cumsum(seq_lens, dim=0) + flat_idx = torch.arange(total, device=device, dtype=torch.long) + batch_indices = torch.bucketize(flat_idx, cumsum[1:], right=True) + token_indices = flat_idx - cumsum[batch_indices] + return batch_indices, token_indices + + +def _derive_seq_lens_and_pack( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Derive b_seq_len and b_start_loc from 2D mask; pack q,k,v to contiguous [total, heads, dim]. + + attention_mask: [batch, seq_len], 1 = valid, 0 = pad. Assumes valid tokens are + at positions 0..n-1 (right-padded layout). The count of valid tokens per row + determines the packing lengths. + Returns: (q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len). 
+ """ + batch = query.shape[0] + device = query.device + # Valid length per batch: number of ones (or non-zero) in the mask per row + if attention_mask.dtype == torch.bool: + seq_lens = attention_mask.sum(dim=1).long() + else: + seq_lens = (attention_mask != 0).sum(dim=1).long() + seq_lens = seq_lens.to(device) + b_start_loc = torch.zeros(batch + 1, device=device, dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(seq_lens, dim=0) + b_start_loc = b_start_loc[:batch] + b_seq_len = seq_lens.to(torch.int32) + max_input_len = int(seq_lens.max().item()) + + # Vectorized packing: query [batch, heads, seq, dim] -> [total, heads, dim] + batch_indices, token_indices = _packed_token_indices(seq_lens, device) + q_packed = query[batch_indices, :, token_indices, :].contiguous() + k_packed = key[batch_indices, :, token_indices, :].contiguous() + v_packed = value[batch_indices, :, token_indices, :].contiguous() + return q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len + + +def _unpack_attn_output( + o_packed: torch.Tensor, + batch: int, + num_heads: int, + head_dim: int, + seq_len: int, + b_seq_len: torch.Tensor, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """Scatter packed output [total_tokens, num_heads, head_dim] to [batch, seq_len, num_heads, head_dim].""" + attn_output = torch.zeros(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype) + total = int(b_seq_len.sum().item()) + if total == 0: + return attn_output + batch_indices, token_indices = _packed_token_indices(b_seq_len.long(), device) + attn_output[batch_indices, token_indices] = o_packed + return attn_output + + +def _get_or_create_segm_buffers( + module: nn.Module, + num_heads: int, + num_kv_heads: int, + head_dim: int, + device: torch.device, +) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get or lazily create 3D segment-parallel buffers cached on the module. 
def _decode_attention(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
) -> torch.Tensor:
    """Run single-token decode attention through the ``unified_attention`` kernel.

    K/V are copied into a zero-padded cache with one block per sequence (block
    size = ``seq_k`` rounded up to a multiple of 32) and an identity block table.
    The 3D segment-parallel buffers cached on ``module`` are passed to the kernel
    only when ``batch`` does not exceed the cached threshold; otherwise ``None``
    is passed for the threshold and all three buffers.

    Args:
        module: The attention module; used to cache 3D segment buffers.
        query: [batch, num_heads, 1, head_dim].
        key: [batch, num_kv_heads, seq_k, head_dim].
        value: [batch, num_kv_heads, seq_k, head_dim].
        attention_mask: Optional 2D [batch, seq_k] mask; 1=valid, 0=pad. Any other
            mask shape is ignored and every sequence is treated as ``seq_k`` long.
        scaling: Softmax scale.

    Returns:
        attn_output: [batch, 1, num_heads, head_dim].
    """
    device, dtype = query.device, query.dtype
    batch, num_heads, head_dim = query.shape[0], query.shape[1], query.shape[3]
    num_kv_heads, seq_k = key.shape[1], key.shape[2]

    # Drop the length-1 query axis: [batch, heads, 1, dim] -> [batch, heads, dim]
    q = query.squeeze(2).contiguous()

    # One cache block per sequence, sized to the next multiple of 32 >= seq_k.
    block_size = -(-seq_k // 32) * 32
    cache_shape = (batch, block_size, num_kv_heads, head_dim)
    k_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
    v_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
    # [batch, kv_heads, seq_k, dim] -> [batch, seq_k, kv_heads, dim]
    k_cache[:, :seq_k] = key.transpose(1, 2)
    v_cache[:, :seq_k] = value.transpose(1, 2)

    # Per-sequence KV length: count of non-zero mask entries, or seq_k without a usable mask.
    mask_usable = attention_mask is not None and _attention_mask_supported_for_triton(
        attention_mask
    )
    if mask_usable:
        seqused_k = (attention_mask != 0).sum(dim=1).to(torch.int32).to(device)
    else:
        seqused_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32)

    # Exactly one query token per sequence, and sequence i lives in cache block i.
    cu_seqlens_q = torch.arange(batch + 1, device=device, dtype=torch.int32)
    block_table = torch.arange(batch, device=device, dtype=torch.int32).unsqueeze(1)

    # 3D segment-parallel kernel for small batches (matches vLLM heuristic)
    seq_threshold_3D, segm_output, segm_max, segm_expsum = _get_or_create_segm_buffers(
        module, num_heads, num_kv_heads, head_dim, device
    )
    if batch > seq_threshold_3D:
        seq_threshold_3D = segm_output = segm_max = segm_expsum = None

    out = torch.empty_like(q)
    unified_attention(
        q=q,
        k=k_cache,
        v=v_cache,
        out=out,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=1,
        seqused_k=seqused_k,
        max_seqlen_k=block_size,
        softmax_scale=scaling,
        causal=True,
        window_size=(-1, -1),
        block_table=block_table,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=_NUM_PAR_SOFTMAX_SEGMENTS,
        softmax_segm_output=segm_output,
        softmax_segm_max=segm_max,
        softmax_segm_expsum=segm_expsum,
    )

    # Restore the query axis: [batch, heads, dim] -> [batch, 1, heads, dim]
    return out.unsqueeze(1)
+ """ + batch, num_heads, seq_len, head_dim = query.shape + seq_k = key.shape[2] + is_cross_attention = seq_len != seq_k + + # Decode: one query token per sequence, full context in K/V + if seq_len <= 1: + attn_output = _decode_attention(module, query, key, value, attention_mask, scaling) + return (attn_output, None) + + device = query.device + num_kv_heads = key.shape[1] + is_causal = not is_cross_attention + apply_sparse24 = kwargs.get("apply_sparse24", getattr(module, "_apply_sparse24", False)) + skip_diagonal_blocks = kwargs.get( + "skip_diagonal_blocks", getattr(module, "_skip_diagonal_blocks", True) + ) + + use_packed = attention_mask is not None and _attention_mask_supported_for_triton(attention_mask) + if use_packed: + q_packed, k_packed, v_packed, b_start_loc, b_seq_len, max_input_len = ( + _derive_seq_lens_and_pack(query, key, value, attention_mask) + ) + o_packed = torch.empty_like(q_packed) + context_attention_fwd( + q_packed, + k_packed, + v_packed, + o_packed, + b_start_loc=b_start_loc, + b_seq_len=b_seq_len, + max_input_len=max_input_len, + is_causal=is_causal, + softmax_scale=scaling, + apply_sparse24=apply_sparse24, + skip_diagonal_blocks=skip_diagonal_blocks, + ) + attn_output = _unpack_attn_output( + o_packed, + batch, + num_heads, + head_dim, + seq_len, + b_seq_len, + query.dtype, + device, + ) + return (attn_output, None) + if attention_mask is not None: + raise ValueError( + f"Unsupported attention_mask format for modelopt_triton: " + f"dim={attention_mask.dim()}, shape={attention_mask.shape}. " + f"Only 2D [batch, seq_len] masks are supported." 
def register_triton_attention() -> bool:
    """Register the Triton backend with HF AttentionInterface.

    Call after importing this module so that attn_implementation="modelopt_triton"
    is available when loading models.

    Registration is best-effort and never raises. A missing (or too old)
    ``transformers`` is treated as "optional dependency not available" and
    reported only through the return value; any other failure additionally
    emits a ``RuntimeWarning`` so the cause is not lost.

    Returns:
        True if registration succeeded, False otherwise.
    """
    import warnings

    try:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

        ALL_ATTENTION_FUNCTIONS.register("modelopt_triton", triton_attention_forward)
    except ImportError:
        # transformers is absent or does not expose ALL_ATTENTION_FUNCTIONS.
        return False
    except Exception as exc:
        # Keep the best-effort contract, but surface the unexpected reason.
        warnings.warn(
            f"Could not register the 'modelopt_triton' attention backend: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True
+ + Prefer using ``mtsa.sparsify(model, SPARSE24_TRITON)`` from + ``modelopt.torch.sparsity.attention_sparsity`` for config-driven setup, + pattern-based layer selection, and consistency with other sparse methods. + This helper remains for backward compatibility and one-off scripting. + + The Triton backend reads ``getattr(module, '_apply_sparse24', False)`` and + ``getattr(module, '_skip_diagonal_blocks', True)`` when kwargs don't provide them. + + Limitations: + - **Prefill-only sparsity:** 2:4 sparsity is applied during prefill only; + decode uses the unified kernel without sparsity. + - **Fixed 50% sparsity:** 2:4 keeps top 2 of every 4 attention scores; + no threshold tuning or calibration. + - **Mutually exclusive with flash_skip_softmax:** sparse24 requires + ``attn_implementation="modelopt_triton"``; flash_skip_softmax requires + ``attn_implementation="eager"``. They cannot be combined in one model. + + Args: + model: Hugging Face model (e.g. LlamaForCausalLM). + apply_sparse24: Whether to apply 2:4 sparsity to attention scores. + skip_diagonal_blocks: If True, keep diagonal tiles dense (local attention). + """ + for _, module in model.named_modules(): + # Match only actual attention modules (have o_proj + head_dim), not their children + # like q_proj, k_proj, v_proj, rotary_emb, etc. 
+ if hasattr(module, "o_proj") and hasattr(module, "head_dim"): + setattr(module, "_apply_sparse24", apply_sparse24) + setattr(module, "_skip_diagonal_blocks", skip_diagonal_blocks) + + +__all__ = [ + "register_triton_attention", + "set_sparse24", + "triton_attention_forward", +] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py new file mode 100644 index 000000000..37ee64c73 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/triton_unified_attention.py @@ -0,0 +1,885 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from triton_unified_attention.py from +# https://github.com/vllm-project/vllm/blob/v0.15.0/vllm/v1/attention/ops/triton_unified_attention.py +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# 2:4 structured sparsity helpers (from vLLM flash_attn_triton_sparse24)
@triton.jit
def _sparse24_noabs_ops(x0, x1, x2, x3):
    """Compute 2:4 sparsity mask: for every 4 values, determine which 2 are largest.

    The comparison is on the raw values (no absolute value is taken). Each mask
    ``m_i`` is set when ``x_i`` wins at least two of its three pairwise
    comparisons. A tie counts as a win for the later element: ``x_i > x_j`` is
    used when ``i < j`` and its complement ("not greater") when ``i > j``.

    Returns:
        ``(x0, x1, x2, x3, m0, m1, m2, m3)``: the inputs unchanged, followed by
        the keep-mask for each of them.
    """
    # Pairwise "earlier element is strictly greater" comparisons.
    (a1, a2, a3, a4, a5, a6) = (
        x0 > x1,
        x0 > x2,
        x0 > x3,
        x1 > x2,
        x1 > x3,
        x2 > x3,
    )
    # Use (x == 0) instead of ~x to avoid interpreter bug with __invert__ on bool tensors
    na1 = a1 == 0
    na2 = a2 == 0
    na3 = a3 == 0
    na4 = a4 == 0
    na5 = a5 == 0
    na6 = a6 == 0
    # Keep x_i when it wins at least two of its three pairwise comparisons.
    m0 = a2 & a3 | a1 & a2 | a1 & a3
    m1 = na1 & a5 | a4 & a5 | na1 & a4
    m2 = na2 & na4 | na2 & a6 | na4 & a6
    m3 = na3 & na5 | na3 & na6 | na5 & na6
    return x0, x1, x2, x3, m0, m1, m2, m3
@triton.jit
def find_seq_idx(
    query_start_len_ptr,
    target_idx,
    num_seqs,
    BLOCK_Q: tl.constexpr,
    use_q_block_mode: tl.constexpr,
):
    """Binary-search ``query_start_len_ptr[0:num_seqs]`` for the entry covering ``target_idx``.

    For each probed position ``mid`` the loaded value ``val`` is compared to
    ``target_idx`` either directly, or — when ``use_q_block_mode`` is set — as
    ``val // BLOCK_Q + mid``.

    Returns:
        ``left - 1`` after the search, i.e. the last probed index whose compared
        value is ``<= target_idx``.
        NOTE(review): assumes the compared values are non-decreasing in the
        index — confirm against the caller.
    """
    left: tl.int32 = 0
    right = num_seqs
    while left < right:
        mid = (left + right) // 2
        val = tl.load(query_start_len_ptr + mid)
        # In q-block mode the start offset is converted to a q-block index.
        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val

        if mid_val <= target_idx:
            left = mid + 1
        else:
            right = mid

    # ``left`` is the first index whose compared value exceeds target_idx.
    return left - 1
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + seq_len = tl.load(seq_lens_ptr + seq_idx) + context_len = seq_len - cur_batch_query_len + + if CAUSAL: + # Causal: only attend up to the query position + max_seq_prefix_len = ( + context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + 1 + ) + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + else: + # Non-causal (cross-attention): attend to all K/V positions + max_seq_prefix_len = seq_len + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + tile_start = 0 + tile_end = num_tiles + if CAUSAL and SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // 
num_queries_per_kv, + cur_batch_query_len - 1, + ) + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) + + for j in range(tile_start, tile_end): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) + + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + K = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) + V = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) + + if CAUSAL: + query_abs_pos = context_len + query_pos[:, None] + seq_mask = seq_offset[None, :] <= query_abs_pos + if SLIDING_WINDOW > 0: + seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW) + else: + seq_mask = tile_mask[None, :] + + S = tl.zeros([BLOCK_M, TILE_SIZE], dtype=tl.float32) + S += scale * tl.dot(Q, K) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")) + + if APPLY_SPARSE24: + if CAUSAL and SKIP_DIAGONAL_BLOCKS: + tile_key_start = j * TILE_SIZE + tile_key_end = tile_key_start + TILE_SIZE + query_abs_start = context_len + q_block_local_idx * BLOCK_Q + query_abs_end = query_abs_start + BLOCK_Q + is_diagonal = (tile_key_start < query_abs_end) & (tile_key_end > query_abs_start) + if not is_diagonal: + S = _apply_sparse24_to_qk_tile(S, BLOCK_M, TILE_SIZE, 
float("-inf")) + else: + S = _apply_sparse24_to_qk_tile(S, BLOCK_M, TILE_SIZE, float("-inf")) + + m_j = tl.maximum(M, tl.max(S, axis=1)) + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + P = tl.exp(S - m_j[:, None]) + l_j = tl.sum(P, axis=1) + alpha = tl.exp(M - m_j) + acc = acc * alpha[:, None] + L = L * alpha + l_j + M = m_j + + if CAUSAL and SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * BLOCK_Q + V = tl.where((context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0) + acc += tl.dot(P.to(V.dtype), V) + + acc = acc / L[:, None] + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + segm_max_ptr, + segm_expsum_ptr, + query_ptr, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr, + seq_lens_ptr, + scale, + num_query_heads: tl.constexpr, + num_queries_per_kv: tl.constexpr, + block_table_stride: tl.int64, + query_stride_0: tl.int64, + query_stride_1: tl.int64, + BLOCK_SIZE: tl.constexpr, + TILE_SIZE: tl.constexpr, + HEAD_SIZE: tl.constexpr, + HEAD_SIZE_PADDED: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + stride_k_cache_0: tl.int64, + stride_k_cache_1: tl.int64, + stride_k_cache_2: tl.int64, + stride_k_cache_3: tl.constexpr, + stride_v_cache_0: tl.int64, + stride_v_cache_1: tl.int64, + stride_v_cache_2: tl.int64, + stride_v_cache_3: tl.constexpr, + query_start_len_ptr, + BLOCK_Q: tl.constexpr, + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, + NUM_SEGMENTS_PER_SEQ: tl.constexpr, + CAUSAL: tl.constexpr, +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True) + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q 
+ seq_idx + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + seq_len = tl.load(seq_lens_ptr + seq_idx) + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + context_len = seq_len - cur_batch_query_len + + if CAUSAL: + max_seq_prefix_len = ( + context_len + q_block_local_idx * BLOCK_Q + (BLOCK_M - 1) // num_queries_per_kv + 1 + ) + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + else: + max_seq_prefix_len = seq_len + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + tile_start = 0 + tile_end = 
num_tiles + if CAUSAL and SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * BLOCK_Q + qpos_hi = tl.minimum( + qpos_lo + (BLOCK_M - 1) // num_queries_per_kv, + cur_batch_query_len - 1, + ) + first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE) + tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles) + + for j in range( + max(segm_idx * tiles_per_segment, tile_start), + min((segm_idx + 1) * tiles_per_segment, tile_end), + ): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len + + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) + + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + K = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) + V = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) + + if CAUSAL: + query_abs_pos = context_len + query_pos[:, None] + seq_mask = seq_offset[None, :] <= query_abs_pos + if SLIDING_WINDOW > 0: + seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW) + else: + seq_mask = tile_mask[None, :] + + S = tl.zeros([BLOCK_M, TILE_SIZE], dtype=tl.float32) + S += scale * tl.dot(Q, K) + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")) + + m_j = tl.maximum(M, tl.max(S, axis=1)) + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + P = tl.exp(S - m_j[:, None]) + l_j = tl.sum(P, axis=1) + alpha = tl.exp(M - m_j) + acc = acc * 
alpha[:, None] + L = L * alpha + l_j + M = m_j + + if CAUSAL and SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * BLOCK_Q + V = tl.where((context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, + segm_output_ptr, + segm_max_ptr, + segm_expsum_ptr, + seq_lens_ptr, + num_seqs, + num_query_heads: tl.constexpr, + output_stride_0: tl.int64, + output_stride_1: tl.int64, + block_table_stride: tl.int64, + TILE_SIZE: tl.constexpr, + HEAD_SIZE: tl.constexpr, + HEAD_SIZE_PADDED: tl.constexpr, + query_start_len_ptr, + BLOCK_Q: tl.constexpr, + NUM_SEGMENTS_PER_SEQ: tl.constexpr, +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False) + seq_len = tl.load(seq_lens_ptr + seq_idx) + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) + + 
segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) + overall_max = tl.max(segm_max) + + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + segm_output_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def _get_tile_size( + head_size: int, + sliding_window: int, + element_size: int, + is_prefill: bool, +) -> int: + """Select tile size. 
Must be power of 2.""" + if sliding_window == 1024 and head_size in (128, 256): + return 32 + if is_prefill: + return 32 + return 16 if element_size >= 2 else 32 + + +def unified_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + seqused_k: torch.Tensor, + max_seqlen_k: int, + softmax_scale: float, + causal: bool, + window_size: tuple[int, int], + block_table: torch.Tensor, + seq_threshold_3D: int | None = None, + num_par_softmax_segments: int | None = None, + softmax_segm_output: torch.Tensor | None = None, + softmax_segm_max: torch.Tensor | None = None, + softmax_segm_expsum: torch.Tensor | None = None, + apply_sparse24: bool = False, + skip_diagonal_blocks: bool = True, +) -> None: + """Unified attention over paged KV cache (prefill and decode). + + Args: + q: [num_tokens, num_query_heads, head_size] + k: [num_blocks, block_size, num_kv_heads, head_size] (paged K cache) + v: [num_blocks, block_size, num_kv_heads, head_size] (paged V cache) + out: [num_tokens, num_query_heads, head_size] + cu_seqlens_q: [num_seqs + 1] cumulative query token counts + max_seqlen_q: max query length + seqused_k: [num_seqs] total sequence length per batch (context + query) + max_seqlen_k: max sequence length + softmax_scale: attention scale (e.g. 1/sqrt(head_size)) + causal: True for causal self-attention, False for cross-attention + window_size: (q_window, k_window), -1 means disabled; only used when causal=True + block_table: [num_seqs, max_blocks_per_seq] + seq_threshold_3D: if set with 3D buffers, use 3D kernel when num_seqs <= this + num_par_softmax_segments: number of segments for 3D kernel + softmax_segm_output, softmax_segm_max, softmax_segm_expsum: 3D kernel buffers + apply_sparse24: If True, apply 2:4 structured sparsity to attention scores. + Only applied during prefill (max_seqlen_q > 1); automatically disabled + during decode. 
The 3D kernel path also ignores this flag (a warning is + emitted). TILE_SIZE must be divisible by 4. + skip_diagonal_blocks: If True, keep diagonal tiles dense (local attention + preserved) when sparse24 is active. + """ + block_size = v.shape[1] + num_seqs = seqused_k.shape[0] + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + BLOCK_Q = BLOCK_M // num_queries_per_kv + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + head_size_padded = max(triton.next_power_of_2(head_size), 16) + sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0 + TILE_SIZE_PREFILL = _get_tile_size( + head_size, sliding_window_val, q.element_size(), is_prefill=True + ) + TILE_SIZE_DECODE = _get_tile_size( + head_size, sliding_window_val, q.element_size(), is_prefill=False + ) + + if apply_sparse24: + assert TILE_SIZE_PREFILL % 4 == 0, ( + f"sparse24 requires TILE_SIZE divisible by 4, got TILE_SIZE_PREFILL={TILE_SIZE_PREFILL}" + ) + + use_3d = ( + seq_threshold_3D is not None + and num_par_softmax_segments is not None + and softmax_segm_output is not None + and softmax_segm_max is not None + and softmax_segm_expsum is not None + and max_seqlen_q <= 1 + and num_seqs <= seq_threshold_3D + ) + + # Sparse24 is only meaningful during prefill (max_seqlen_q > 1). + # During decode (max_seqlen_q <= 1), disable it regardless of the caller's flag. + # The 3D kernel (decode-only) therefore never sees sparse24 enabled. 
+ is_prefill = max_seqlen_q > 1 + effective_sparse24 = apply_sparse24 and is_prefill + + if not use_3d: + kernel_unified_attention_2d[(total_num_q_blocks, num_kv_heads)]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + scale=softmax_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=head_size_padded, + SLIDING_WINDOW=sliding_window_val, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + APPLY_SPARSE24=effective_sparse24, + SKIP_DIAGONAL_BLOCKS=skip_diagonal_blocks, + CAUSAL=causal, + ) + else: + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, num_par_softmax_segments)]( + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + scale=softmax_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=head_size_padded, + SLIDING_WINDOW=sliding_window_val, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + 
stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, + CAUSAL=causal, + ) + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=head_size_padded, + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, + ) + + +def context_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + sliding_window_q: int | None = None, + sliding_window_k: int | None = None, + apply_sparse24: bool = False, + skip_diagonal_blocks: bool = True, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, +) -> None: + """Prefill attention over contiguous Q/K/V (packed format). + + Converts contiguous tensors to paged format and calls unified_attention. + For causal self-attention, Q and K/V share the same sequence lengths + (``b_seq_len``). For cross-attention (``is_causal=False``), K/V may have + different lengths specified via ``b_seq_len_k``. + + Note: + When ``apply_sparse24=True``, TILE_SIZE must be divisible by 4 (the 2:4 + sparsity reshape requires ``N // 4``). Current tile sizes (16, 32) satisfy + this. 
Causal-masked elements participate in the 2:4 top-2 selection, which + may waste sparsity slots near the diagonal. + """ + if q.dim() != 3 or k.dim() != 3 or v.dim() != 3 or o.dim() != 3: + raise ValueError( + "q, k, v, o must be rank-3 [total_tokens, num_heads, head_dim]; " + f"got q.dim()={q.dim()}, k.dim()={k.dim()}, v.dim()={v.dim()}, o.dim()={o.dim()}." + ) + head_dim = q.shape[2] + if k.shape[2] != head_dim or v.shape[2] != head_dim or o.shape[2] != head_dim: + raise ValueError( + "q, k, v, o must have same head_dim (shape[2]); " + f"got {q.shape[2]}, {k.shape[2]}, {v.shape[2]}, {o.shape[2]}." + ) + if o.shape[0] != q.shape[0] or o.shape[1] != q.shape[1]: + raise ValueError(f"o must match q shape; got o={o.shape}, q={q.shape}.") + num_kv_heads = k.shape[1] + if num_kv_heads <= 0: + raise ValueError(f"k.shape[1] (num_kv_heads) must be positive; got {num_kv_heads}.") + if q.shape[1] % num_kv_heads != 0: + raise ValueError( + f"num_heads (q.shape[1]) must be divisible by num_kv_heads (k.shape[1]); " + f"got {q.shape[1]} and {num_kv_heads}." + ) + + # For causal self-attention, Q and K/V share lengths. + # For cross-attention, K/V lengths come from separate parameters. + if b_seq_len_k is None: + # Self-attention: Q and K/V have same total tokens and lengths + total_q = q.shape[0] + if k.shape[0] != total_q or v.shape[0] != total_q: + raise ValueError( + "For causal self-attention, q, k, v must have same shape[0]; " + f"got {q.shape[0]}, {k.shape[0]}, {v.shape[0]}. " + "For cross-attention, pass b_seq_len_k and b_start_loc_k." 
+ ) + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + max_input_len_k = max_input_len + + batch = b_seq_len.shape[0] + if b_start_loc_k is None: + b_start_loc_k = torch.zeros(batch + 1, device=q.device, dtype=torch.int32) + b_start_loc_k[1:] = torch.cumsum(b_seq_len_k.to(torch.int64), dim=0) + b_start_loc_k = b_start_loc_k[:batch] + if max_input_len_k is None: + max_input_len_k = int(b_seq_len_k.max().item()) + + device = q.device + dtype = q.dtype + block_size = ((max_input_len_k + 31) // 32) * 32 + + k_cache = torch.zeros((batch, block_size, num_kv_heads, head_dim), device=device, dtype=dtype) + v_cache = torch.zeros((batch, block_size, num_kv_heads, head_dim), device=device, dtype=dtype) + for i in range(batch): + start = int(b_start_loc_k[i].item()) + length = int(b_seq_len_k[i].item()) + if length > 0: + k_cache[i, :length, :, :] = k[start : start + length] + v_cache[i, :length, :, :] = v[start : start + length] + + block_table = torch.arange(batch, device=device, dtype=torch.int32).unsqueeze(1) + + cu_seqlens_q = torch.zeros(batch + 1, device=device, dtype=torch.int32) + cu_seqlens_q[1:] = torch.cumsum(b_seq_len.to(torch.int64), dim=0) + seqused_k = b_seq_len_k.to(torch.int32) + + scale = 1.0 / (head_dim**0.5) if softmax_scale is None else softmax_scale + sw_q = sliding_window_q if sliding_window_q is not None else -1 + sw_k = sliding_window_k if sliding_window_k is not None else -1 + window_size = (sw_q, sw_k) + + unified_attention( + q=q, + k=k_cache, + v=v_cache, + out=o, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_input_len, + seqused_k=seqused_k, + max_seqlen_k=block_size, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + block_table=block_table, + apply_sparse24=apply_sparse24, + skip_diagonal_blocks=skip_diagonal_blocks, + ) + + +__all__ = ["context_attention_fwd", "unified_attention"] diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/__init__.py 
def get_sparse_context(self, module: torch.nn.Module):
    """Build the context manager that swaps ``F.softmax`` for a sparsity-aware wrapper.

    While the returned context is active, every softmax call first runs
    ``calculate_sparsity`` on the scores and records the resulting statistics
    on ``module._last_stats``. Outside calibration mode the scores are also
    passed through ``apply_sparsity`` before the softmax that was installed
    when this method was called is evaluated.

    Args:
        module: Attention module on which the latest statistics are stored.
    """
    dense_softmax = F.softmax

    def sparse_softmax(input, dim=-1, *args, **kwargs):
        mask, layer_stats = self.calculate_sparsity(input)
        module._last_stats = layer_stats
        # Calibration only observes the scores; inference also masks them.
        scores = input if self._calibration_mode else self.apply_sparsity(input, mask)
        return dense_softmax(scores, dim, *args, **kwargs)

    return replace_function(torch.nn.functional, "softmax", sparse_softmax)
def get_sparse_context(self, module: torch.nn.Module):
    """Return the context manager that switches this method's sparsity on for a forward pass.

    The base implementation only raises; every concrete method supplies its
    own mechanism:
    - Softmax-patching methods replace F.softmax during the forward pass.
    - Kernel-fused methods set flags on ``module`` that the kernel reads.

    Args:
        module: The SparseAttentionModule wrapping the attention layer.

    Raises:
        NotImplementedError: Always, unless overridden by the subclass.
    """
    subclass_name = type(self).__name__
    raise NotImplementedError(f"{subclass_name} must implement get_sparse_context()")
import contextlib
from typing import Any

import torch

from . import SparseAttentionMethod, register_sparse_method


def _sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
    """Compute 2:4 mask: for every 4 elements along the last dim, keep the 2 largest.

    Args:
        scores: Tensor of shape [..., N] with N divisible by 4.

    Returns:
        Boolean mask of same shape; True where the element is kept (top-2 of 4).
    """
    *prefix, n = scores.shape
    assert n % 4 == 0, "2:4 sparsity requires last dim divisible by 4"
    grouped = scores.reshape(*prefix, n // 4, 4)
    # topk(2) along dim=-1 yields the positions of the two largest entries of each group.
    _, top2_idx = torch.topk(grouped, k=2, dim=-1, largest=True, sorted=False)
    mask = torch.zeros_like(grouped, dtype=torch.bool)
    mask.scatter_(-1, top2_idx, True)
    return mask.reshape(*prefix, n)


@register_sparse_method("sparse24_triton")
class Sparse24Triton(SparseAttentionMethod):
    """2:4 structured sparse attention for the Triton prefill kernel.

    When backend is "triton", sparsity is applied inside the kernel; this method
    provides the config interface and optional PyTorch-side diagnostics (e.g.
    calculate_sparsity for stats). No calibration; pattern is fixed (top-2 of every 4).
    """

    def __init__(self, method_config: dict | None = None):
        """Initialize 2:4 Triton sparse attention method.

        Args:
            method_config: Configuration dict. Uses skip_diagonal_blocks, is_causal
                and backend; other keys (e.g. threshold, br, bc) are not read.
        """
        super().__init__()
        config = method_config or {}
        self.skip_diagonal_blocks = config.get("skip_diagonal_blocks", True)
        self.is_causal = config.get("is_causal", True)
        self.backend = config.get("backend", "triton")

    def _infer_phase(self, attention_scores: torch.Tensor) -> str:
        """Infer phase from attention scores shape (a single query row means decode)."""
        return "decode" if attention_scores.shape[2] == 1 else "prefill"

    def calculate_sparsity(
        self,
        attention_scores: torch.Tensor,
    ) -> tuple[torch.Tensor, dict]:
        """Calculate 2:4 sparsity mask and statistics (PyTorch reference).

        Used for diagnostics when collect_stats is enabled. The actual sparsity
        during forward with backend="triton" is applied inside the Triton kernel.

        Args:
            attention_scores: [batch, heads, seq_q, seq_k]

        Returns:
            (sparse_mask, stats_dict). ``total_blocks`` counts the groups of four
            keys (after padding seq_k to a multiple of 4) and ``sparse_blocks`` is
            half of that count, matching the reported sparsity of 0.5.
        """
        assert attention_scores.dim() == 4, (
            f"Expected 4D attention scores, got shape {attention_scores.shape}"
        )
        batch, num_heads, seq_q, seq_k = attention_scores.shape
        phase = self._infer_phase(attention_scores)

        # Pad seq_k to multiple of 4 for 2:4 grouping
        pad = (4 - seq_k % 4) % 4
        if pad > 0:
            scores_padded = torch.nn.functional.pad(
                attention_scores, (0, pad), value=torch.finfo(attention_scores.dtype).min
            )
        else:
            scores_padded = attention_scores

        mask_padded = _sparse24_mask_along_last_dim(scores_padded)
        if pad > 0:
            sparse_mask = mask_padded[..., :seq_k].contiguous()
        else:
            sparse_mask = mask_padded

        # 2:4 keeps 2 of 4 -> half of the entries are dropped.
        sparsity = 0.5
        total_groups = (seq_k + pad) // 4 * seq_q * num_heads * batch
        stats = {
            "sparsity": sparsity,
            "phase": phase,
            "total_blocks": total_groups,
            # Halve the final count. The former ``0.5 * (seq_k + pad) // 4 * ...``
            # floored before scaling (``*`` and ``//`` associate left to right) and
            # reported 0 for a 4-wide key dimension.
            "sparse_blocks": total_groups // 2,
            "sample_length": seq_k,
        }
        return sparse_mask, stats

    def apply_sparsity(
        self,
        attention_scores: torch.Tensor,
        sparse_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply 2:4 sparsity mask to attention scores.

        Args:
            attention_scores: [batch, heads, seq_q, seq_k]
            sparse_mask: Optional pre-computed mask. If None, computes via calculate_sparsity.

        Returns:
            Masked scores (same shape); masked positions set to dtype min.
        """
        if sparse_mask is None:
            sparse_mask, _ = self.calculate_sparsity(attention_scores)
        mask_value = torch.finfo(attention_scores.dtype).min
        return attention_scores.masked_fill(~sparse_mask, mask_value)

    @contextlib.contextmanager
    def get_sparse_context(self, module: torch.nn.Module):
        """Set _apply_sparse24 and _skip_diagonal_blocks on module for the Triton kernel.

        On exit ``_apply_sparse24`` is put back to the value it had on entry
        (False when it was unset), so a nested use on the same module does not
        switch sparsity off for the enclosing context.
        """
        previous_apply = getattr(module, "_apply_sparse24", False)
        module._apply_sparse24 = True
        # Diagonal skip only applies to causal self-attention; for cross-attention
        # there is no diagonal relationship between Q and K positions.
        module._skip_diagonal_blocks = self.skip_diagonal_blocks and self.is_causal
        try:
            yield
        finally:
            module._apply_sparse24 = previous_apply

    def get_threshold_info(self) -> dict[str, Any]:
        """Return fixed 2:4 pattern info (no tunable threshold)."""
        return {"type": "fixed", "value": "2:4 structured"}

    @property
    def name(self) -> str:
        """Method identifier."""
        return "sparse24_triton"
@@ -93,10 +96,12 @@ def register_sparse_attention_on_the_fly(model: nn.Module) -> bool: SparseAttentionRegistry.register({module_type: type_name})(_GenericSparseAttention) attention_types.add(module_type) registered_count += 1 - print(f"Registered {type_name} for sparse attention optimization") + logger.info("Registered %s for sparse attention optimization", type_name) if registered_count > 0: - print(f"Dynamically registered {registered_count} attention module types for sparsity") + logger.info( + "Dynamically registered %d attention module types for sparsity", registered_count + ) return registered_count > 0 @@ -124,10 +129,12 @@ def _is_supported_model(model: nn.Module) -> bool: def validate_eager_attention(model: nn.Module) -> None: - """Validate and enforce eager attention for HuggingFace models. + """Validate attention implementation for HuggingFace models. - Sparse attention requires attn_implementation='eager' because it - patches torch.nn.functional.softmax, which is only called in eager mode. + For softmax-patching methods (e.g. flash_skip_softmax) the model must use + attn_implementation='eager'. For the Triton 2:4 kernel (sparse24_triton) + the model must use attn_implementation='modelopt_triton'. We only force + eager when the current implementation is neither eager nor modelopt_triton. Args: model: Model to validate @@ -136,10 +143,10 @@ def validate_eager_attention(model: nn.Module) -> None: return attn_impl = getattr(model.config, "_attn_implementation", None) - if attn_impl and attn_impl != "eager": + if attn_impl and attn_impl not in ("eager", "modelopt_triton"): warnings.warn( - f"Sparse attention requires attn_implementation='eager', but model uses '{attn_impl}'. " - "Forcing eager attention implementation." + f"Sparse attention expects attn_implementation='eager' or 'modelopt_triton', " + f"but model uses '{attn_impl}'. Forcing eager attention implementation." 
) model.config._attn_implementation = "eager" diff --git a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py index 281e11e7d..17afeccde 100644 --- a/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py +++ b/modelopt/torch/sparsity/attention_sparsity/sparse_attention.py @@ -17,11 +17,7 @@ from typing import Any -import torch -import torch.nn.functional as F - from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls -from modelopt.torch.quantization.utils import replace_function from .config import SparseAttentionAttributeConfig from .methods import get_sparse_method @@ -32,28 +28,23 @@ class SparseAttentionModule(DynamicModule): """Generic sparse attention module wrapper for applying sparsity to attention layers. This module wraps existing attention implementations to add sparse attention - capabilities by patching torch.nn.functional.softmax. + capabilities. The activation mechanism is delegated to the configured method + via ``method.get_sparse_context(module)``, so each method defines how it + integrates with the forward pass (e.g. softmax patching, kernel flags). Forward Flow: ------------- 1. Check if sparse attention is enabled (pass-through if disabled) - 2. Create softmax patch context with sparse_softmax function - 3. Apply sparse attention by patching F.softmax: - - Patches torch.nn.functional.softmax with sparse_softmax - - sparse_softmax applies method's sparsity logic before softmax - 4. Forward through original attention with sparsity applied - - Requirements: - ------------- - - Model must be loaded with attn_implementation="eager" for proper softmax interception - - Only PyTorch backend is supported (patches F.softmax) + 2. Obtain method-specific context via ``_sparse_method_instance.get_sparse_context(self)`` + 3. Run the original forward inside the context + 4. 
Collect statistics if stats manager is enabled Attributes: ----------- _enabled: bool Whether sparse attention is enabled _method: str - The sparse attention method to use (e.g., "flash_skip_softmax") + The sparse attention method to use (e.g., "flash_skip_softmax", "sparse24_triton") _method_config: dict Configuration dictionary for the sparse method (threshold, br, bc, etc.) _sparse_method_instance: SparseAttentionMethod @@ -190,32 +181,12 @@ def forward(self, *args, **kwargs): return result def _get_sparse_context(self): - """Get the softmax patch context for applying sparse attention.""" - return self._create_softmax_patch_context() - - def _create_softmax_patch_context(self): - """Create context manager for patching softmax function.""" - return replace_function(torch.nn.functional, "softmax", self._create_sparse_softmax()) - - def _create_sparse_softmax(self): - """Create sparse softmax function for current method.""" - original_softmax = F.softmax + """Get the context manager for applying sparse attention. - def sparse_softmax(input, dim=-1, *args, **kwargs): - # Calculate sparsity mask and collect statistics - sparse_mask, stats = self._sparse_method_instance.calculate_sparsity(input) - - # Store stats for collection - self._last_stats = stats - - # Only apply sparsity mask after calibration (not during calibration) - # During calibration, we measure sparsity without modifying the output - if not self._sparse_method_instance._calibration_mode: - input = self._sparse_method_instance.apply_sparsity(input, sparse_mask) - - return original_softmax(input, dim, *args, **kwargs) - - return sparse_softmax + Delegates to the method instance so each method defines its own + activation mechanism (softmax patching, kernel flags, etc.). 
+ """ + return self._sparse_method_instance.get_sparse_context(self) # Create registry for sparse attention modules diff --git a/pyproject.toml b/pyproject.toml index bffa547b6..3324dcecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ extend-ignore = [ "E501", ] # Ignore missing docstrings or line length for Jupyter notebooks "modelopt/torch/quantization/triton/*" = ["N803", "N806", "E731"] # triton style +"modelopt/torch/sparsity/attention_sparsity/kernels/*" = ["N803", "N806"] # triton kernel style "examples/deepseek/ds_kernel.py" = ["N803", "N806", "E731"] # triton style [tool.ruff.lint.pycodestyle] diff --git a/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py new file mode 100644 index 000000000..d13c83714 --- /dev/null +++ b/tests/gpu/torch/sparsity/attention_sparsity/test_triton_unified_attention.py @@ -0,0 +1,419 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""GPU tests for Triton unified attention kernel.

Covers three areas:

* dense attention from ``context_attention_fwd`` / ``unified_attention``
  compared against ``torch.nn.functional.scaled_dot_product_attention``;
* 2:4 sparse attention compared against a slow pure-Python reference;
* end-to-end use through a tiny HuggingFace Llama model.
"""

import pytest
import torch
import torch.nn.functional as F

# Silence warning categories for every test in this module.
pytestmark = [
    pytest.mark.filterwarnings("ignore::UserWarning"),
    pytest.mark.filterwarnings("ignore::RuntimeWarning"),
    pytest.mark.filterwarnings("ignore::DeprecationWarning"),
]

from modelopt.torch.sparsity.attention_sparsity.kernels import (
    IS_AVAILABLE as TRITON_KERNEL_AVAILABLE,
)

# The kernel entry points are only imported when the package reports them as
# available; the test classes below are skipped otherwise.
if TRITON_KERNEL_AVAILABLE:
    from modelopt.torch.sparsity.attention_sparsity.kernels import (
        context_attention_fwd,
        unified_attention,
    )


def _sdpa_reference(q, k, v, b_start_loc, b_seq_len):
    """SDPA causal reference. Supports GQA. Returns [total_tokens, num_heads, dim].

    Args:
        q: Packed queries, indexed as [token, head, dim].
        k: Packed keys, indexed as [token, kv_head, dim].
        v: Packed values, indexed as [token, kv_head, dim].
        b_start_loc: Per-sequence start offset into the packed token axis.
        b_seq_len: Per-sequence length.
    """
    batch = b_seq_len.shape[0]
    num_q, num_kv = q.shape[1], k.shape[1]
    parts = []
    for b in range(batch):
        s, n = int(b_start_loc[b].item()), int(b_seq_len[b].item())
        # Slice one sequence and rearrange to [1, heads, seq, dim] for SDPA.
        qb = q[s : s + n].unsqueeze(0).permute(0, 2, 1, 3)
        kb = k[s : s + n].unsqueeze(0).permute(0, 2, 1, 3)
        vb = v[s : s + n].unsqueeze(0).permute(0, 2, 1, 3)
        if num_q != num_kv:
            # GQA: repeat each KV head so the head counts match.
            r = num_q // num_kv
            kb = kb.repeat_interleave(r, dim=1)
            vb = vb.repeat_interleave(r, dim=1)
        ob = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
        # Back to [seq, heads, dim] so sequences can be concatenated again.
        parts.append(ob.permute(0, 2, 1, 3).squeeze(0))
    return torch.cat(parts, dim=0)


def _sparse24_top2(x0, x1, x2, x3):
    """Top-2-of-4 mask (same logic as Triton _sparse24_noabs_ops).

    Each returned flag is True when that element wins at least two of its
    three pairwise comparisons. A comparison against a lower-indexed element
    is expressed as ``not (lower > higher)``, so a tie counts as a win for the
    higher index.
    """
    a1, a2, a3 = x0 > x1, x0 > x2, x0 > x3
    a4, a5, a6 = x1 > x2, x1 > x3, x2 > x3
    m0 = (a2 and a3) or (a1 and a2) or (a1 and a3)
    m1 = (not a1 and a5) or (a4 and a5) or (not a1 and a4)
    m2 = (not a2 and not a4) or (not a2 and a6) or (not a4 and a6)
    m3 = (not a3 and not a5) or (not a3 and not a6) or (not a5 and not a6)
    return m0, m1, m2, m3


def _attention_sparse24_ref(q, k, v, scale, bq, ts, skip_diag=True):
    """Reference attention with 2:4 sparsity + diagonal skip. [seq, dim] -> [seq, dim].

    Args:
        q: Queries for one head, [seq, dim].
        k: Keys for one head, [seq, dim].
        v: Values for one head, [seq, dim].
        scale: Multiplier applied to ``q @ k.T``.
        bq: Number of query rows per query block.
        ts: Number of key columns per tile.
        skip_diag: When True, tiles whose key range overlaps the query block's
            row range are left dense.
    """
    n = q.shape[0]
    scores = scale * (q @ k.T)
    # Causal mask: strictly-upper-triangular entries become -inf.
    scores.masked_fill_(
        torch.triu(torch.ones(n, n, device=scores.device, dtype=torch.bool), 1), float("-inf")
    )
    nqb = (n + bq - 1) // bq  # number of query blocks (ceil division)
    ntiles = (n + ts - 1) // ts  # number of key tiles (ceil division)
    for qb in range(nqb):
        qs, qe = qb * bq, min((qb + 1) * bq, n)
        for t in range(ntiles):
            ks, ke = t * ts, min((t + 1) * ts, n)
            # Tile overlaps this query block's index range -> keep it dense.
            if skip_diag and ks < qe and ke > qs:
                continue
            for row in range(qs, qe):
                # Only complete groups of 4 inside the tile are pruned.
                for g in range((ke - ks) // 4):
                    c = ks + g * 4
                    vals = [scores[row, c + i].item() for i in range(4)]
                    mask = _sparse24_top2(*vals)
                    for i in range(4):
                        if not mask[i]:
                            scores[row, c + i] = float("-inf")
    # Softmax is computed in float32, then cast back to the input dtype.
    return F.softmax(scores.float(), dim=-1).to(q.dtype) @ v


@pytest.fixture(scope="module")
def tiny_llama_dir(tmp_path_factory):
    """Tiny Llama: 2 layers, 64 hidden, 4 q-heads, 2 kv-heads, head_dim=16."""
    # Imported lazily inside the fixture rather than at module import time.
    from _test_utils.torch.transformers_models import create_tiny_llama_dir

    return create_tiny_llama_dir(
        tmp_path_factory.mktemp("tiny_llama"),
        with_tokenizer=True,
        num_hidden_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
    )


@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
class TestUnifiedAttentionVsSdpa:
    """Triton unified attention matches PyTorch SDPA for prefill and decode."""

    @pytest.mark.parametrize(
        ("dtype", "num_heads", "num_kv_heads", "head_dim", "tol"),
        [
            (torch.float32, 2, 2, 32, 1e-2),
            (torch.float16, 4, 2, 64, 2e-2),
        ],
        ids=["fp32_mha", "fp16_gqa"],
    )
    def test_prefill_matches_sdpa(self, dtype, num_heads, num_kv_heads, head_dim, tol):
        """Prefill via context_attention_fwd matches SDPA (variable-length batch)."""
        seq_lens = [8, 12]
        total = sum(seq_lens)
        scale = 1.0 / (head_dim**0.5)

        torch.manual_seed(123)
        # Two sequences packed back-to-back along the token axis.
        q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=dtype)
        k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype)
        v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=dtype)
        locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32)
        lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32)

        # Output buffer handed to the kernel (presumably filled in place).
        o = torch.empty_like(q)
        context_attention_fwd(
            q,
            k,
            v,
            o,
            b_start_loc=locs,
            b_seq_len=lens,
            max_input_len=max(seq_lens),
            is_causal=True,
            softmax_scale=scale,
        )
        torch.testing.assert_close(o, _sdpa_reference(q, k, v, locs, lens), rtol=tol, atol=tol)

    def test_cross_attention_matches_sdpa(self):
        """Non-causal cross-attention: different Q and K/V lengths, matches SDPA."""
        seq_q, seq_k = 6, 10
        num_heads, num_kv_heads, head_dim = 4, 2, 32
        scale = 1.0 / (head_dim**0.5)

        torch.manual_seed(501)
        q = torch.randn(seq_q, num_heads, head_dim, device="cuda", dtype=torch.float32)
        k = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        v = torch.randn(seq_k, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)

        o = torch.empty_like(q)
        # Separate *_k arguments describe the key/value side, whose length
        # differs from the query side here.
        context_attention_fwd(
            q,
            k,
            v,
            o,
            b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32),
            b_seq_len=torch.tensor([seq_q], device="cuda", dtype=torch.int32),
            max_input_len=seq_q,
            is_causal=False,
            softmax_scale=scale,
            b_start_loc_k=torch.tensor([0], device="cuda", dtype=torch.int32),
            b_seq_len_k=torch.tensor([seq_k], device="cuda", dtype=torch.int32),
            max_input_len_k=seq_k,
        )

        # Reference: SDPA non-causal
        q_ref = q.unsqueeze(0).permute(0, 2, 1, 3)  # [1, heads, seq_q, dim]
        k_ref = k.unsqueeze(0).permute(0, 2, 1, 3)
        v_ref = v.unsqueeze(0).permute(0, 2, 1, 3)
        # GQA: expand KV heads to match the number of query heads.
        k_ref = k_ref.repeat_interleave(num_heads // num_kv_heads, dim=1)
        v_ref = v_ref.repeat_interleave(num_heads // num_kv_heads, dim=1)
        o_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=False)
        o_ref = o_ref.permute(0, 2, 1, 3).squeeze(0)

        torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)

    def test_decode_matches_sdpa(self):
        """Decode with GQA paged KV cache matches per-sample SDPA."""
        batch, ctx_lens = 2, [4, 8]
        num_heads, num_kv_heads, head_dim = 4, 2, 32
        # Longest context plus one, rounded up to a multiple of 32.
        block_size = ((max(ctx_lens) + 1 + 31) // 32) * 32
        scale = 1.0 / (head_dim**0.5)

        torch.manual_seed(103)
        q_dec = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float32)
        kc = torch.randn(
            batch, block_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float32
        )
        vc = torch.randn(
            batch, block_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float32
        )
        # Zero every cache slot past the first ctx_len + 1 positions.
        for i, cl in enumerate(ctx_lens):
            kc[i, cl + 1 :] = 0
            vc[i, cl + 1 :] = 0

        # One cache block per sample: sample i maps to block i.
        bt = torch.arange(batch, device="cuda", dtype=torch.int32).unsqueeze(1)
        # Cumulative query lengths 0, 1, ..., batch: one query token per sample.
        cu = torch.arange(batch + 1, device="cuda", dtype=torch.int32)
        sk = torch.tensor([c + 1 for c in ctx_lens], device="cuda", dtype=torch.int32)
        out = torch.empty_like(q_dec)

        unified_attention(
            q=q_dec,
            k=kc,
            v=vc,
            out=out,
            cu_seqlens_q=cu,
            max_seqlen_q=1,
            seqused_k=sk,
            max_seqlen_k=block_size,
            softmax_scale=scale,
            causal=True,
            window_size=(-1, -1),
            block_table=bt,
        )

        for i in range(batch):
            sl = ctx_lens[i] + 1
            qb = q_dec[i : i + 1].unsqueeze(2)
            kb = kc[i, :sl].unsqueeze(0).permute(0, 2, 1, 3)
            vb = vc[i, :sl].unsqueeze(0).permute(0, 2, 1, 3)
            kb = kb.repeat_interleave(num_heads // num_kv_heads, dim=1)
            vb = vb.repeat_interleave(num_heads // num_kv_heads, dim=1)
            # The reference is non-causal: the single query row attends to all
            # sl cached positions that were sliced out above.
            ref = F.scaled_dot_product_attention(qb, kb, vb, is_causal=False).squeeze(2)
            torch.testing.assert_close(out[i : i + 1], ref, rtol=1e-2, atol=1e-2)

    def test_prefill_decode_consistency(self):
        """Last token of prefill matches decode output for the same sequence."""
        seq_len = 8
        num_heads, num_kv_heads, head_dim = 2, 2, 32
        # seq_len rounded up to a multiple of 16.
        block_size = ((seq_len + 15) // 16) * 16
        scale = 1.0 / (head_dim**0.5)

        torch.manual_seed(104)
        q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
        k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)

        # Prefill
        o_pf = torch.empty_like(q)
        context_attention_fwd(
            q,
            k,
            v,
            o_pf,
            b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32),
            b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32),
            max_input_len=seq_len,
            is_causal=True,
            softmax_scale=scale,
        )

        # Decode (last token as query, full KV in cache)
        kc = torch.zeros(1, block_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        vc = torch.zeros(1, block_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        kc[0, :seq_len] = k
        vc[0, :seq_len] = v
        o_dec = torch.empty_like(q[:1])
        unified_attention(
            q=q[-1:],
            k=kc,
            v=vc,
            out=o_dec,
            cu_seqlens_q=torch.tensor([0, 1], device="cuda", dtype=torch.int32),
            max_seqlen_q=1,
            seqused_k=torch.tensor([seq_len], device="cuda", dtype=torch.int32),
            max_seqlen_k=block_size,
            softmax_scale=scale,
            causal=True,
            window_size=(-1, -1),
            block_table=torch.zeros(1, 1, device="cuda", dtype=torch.int32),
        )

        torch.testing.assert_close(o_pf[-1:], o_dec, rtol=1e-2, atol=1e-2)


@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
class TestSparse24Attention:
    """2:4 sparse attention applied inside the Triton kernel."""

    def test_sparse24_output_differs_from_dense(self):
        """Sparse24 enabled produces different (but valid) output vs dense."""
        # total is the sum of seq_lens (48 + 64).
        seq_lens, total = [48, 64], 112
        num_heads, num_kv_heads, head_dim = 2, 2, 32
        scale = 1.0 / (head_dim**0.5)

        torch.manual_seed(789)
        q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.float32)
        k = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        v = torch.randn(total, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        locs = torch.tensor([0, seq_lens[0]], device="cuda", dtype=torch.int32)
        lens = torch.tensor(seq_lens, device="cuda", dtype=torch.int32)

        # Arguments shared by the dense and the sparse call.
        kw = {
            "b_start_loc": locs,
            "b_seq_len": lens,
            "max_input_len": max(seq_lens),
            "is_causal": True,
            "softmax_scale": scale,
        }

        o_dense = torch.empty_like(q)
        context_attention_fwd(q, k, v, o_dense, apply_sparse24=False, **kw)
        o_sparse = torch.empty_like(q)
        context_attention_fwd(
            q, k, v, o_sparse, apply_sparse24=True, skip_diagonal_blocks=True, **kw
        )

        assert not torch.equal(o_dense, o_sparse), "Sparse should differ from dense"
        assert not torch.isnan(o_sparse).any() and not torch.isinf(o_sparse).any()

    def test_sparse24_matches_reference(self):
        """Sparse24 with GQA (4 q-heads, 2 kv-heads) matches Python reference."""
        seq_len = 32
        num_heads, num_kv_heads, head_dim = 4, 2, 32
        nqkv = num_heads // num_kv_heads
        scale = 1.0 / (head_dim**0.5)
        # Query-block and key-tile sizes handed to the Python reference;
        # presumably chosen to mirror the kernel's tiling — TODO confirm.
        bq, ts = 16 // nqkv, 32

        torch.manual_seed(303)
        q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float32)
        k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)
        v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float32)

        o_tri = torch.empty_like(q)
        context_attention_fwd(
            q,
            k,
            v,
            o_tri,
            b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32),
            b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32),
            max_input_len=seq_len,
            is_causal=True,
            softmax_scale=scale,
            apply_sparse24=True,
            skip_diagonal_blocks=True,
        )

        o_ref = torch.empty_like(q)
        for h in range(num_heads):
            # Query head h reads KV head h // nqkv (GQA head grouping).
            o_ref[:, h] = _attention_sparse24_ref(
                q[:, h],
                k[:, h // nqkv],
                v[:, h // nqkv],
                scale,
                bq,
                ts,
            )

        torch.testing.assert_close(o_tri, o_ref, rtol=1e-2, atol=1e-2)


@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
class TestSparseAttentionIntegration:
    """HF model + mtsa.sparsify integration."""

    def test_triton_forward_and_generate(self, tiny_llama_dir):
        """modelopt_triton attention: prefill logits valid, generate produces tokens."""
        pytest.importorskip("transformers")
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model = AutoModelForCausalLM.from_pretrained(
            tiny_llama_dir,
            attn_implementation="modelopt_triton",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        model.eval()
        tok = AutoTokenizer.from_pretrained(tiny_llama_dir)
        # Fall back to EOS as the pad token when the tokenizer defines none.
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id

        ids = tok("The capital of France is", return_tensors="pt").input_ids.to("cuda")
        with torch.no_grad():
            logits = model(input_ids=ids).logits
        assert not torch.isnan(logits).any() and not torch.isinf(logits).any()

        with torch.no_grad():
            out = model.generate(
                ids, max_new_tokens=5, do_sample=False, pad_token_id=tok.pad_token_id
            )
        # Exactly five tokens are expected on top of the prompt length.
        assert out.shape[1] == ids.shape[1] + 5

    def test_sparsify_sparse24_produces_valid_output(self, tiny_llama_dir):
        """mtsa.sparsify(model, SPARSE24_TRITON) forward produces valid logits."""
        pytest.importorskip("transformers")
        from transformers import AutoModelForCausalLM, AutoTokenizer

        import modelopt.torch.sparsity.attention_sparsity as mtsa
        from modelopt.torch.sparsity.attention_sparsity.config import SPARSE24_TRITON

        # No attn_implementation is passed here; the model is loaded with the
        # library default and then handed to mtsa.sparsify.
        model = AutoModelForCausalLM.from_pretrained(
            tiny_llama_dir,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        model = mtsa.sparsify(model, SPARSE24_TRITON)
        model.eval()

        tok = AutoTokenizer.from_pretrained(tiny_llama_dir)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        ids = tok("Hello world", return_tensors="pt").input_ids.to("cuda")

        with torch.no_grad():
            logits = model(input_ids=ids).logits
        assert not torch.isnan(logits).any() and not torch.isinf(logits).any()