Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to update anything in the example README or the changelog?

Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_CALIB,
SKIP_SOFTMAX_DEFAULT,
SPARSE24_TRITON,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor

Expand All @@ -43,6 +44,7 @@
SPARSE_ATTN_CFG_CHOICES = {
"skip_softmax": SKIP_SOFTMAX_DEFAULT,
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
"sparse24_triton": SPARSE24_TRITON,
}


Expand Down Expand Up @@ -144,12 +146,14 @@ def main(args):

print(f"Loading model: {args.pyt_ckpt_path}")

# Load model and tokenizer
# Note: attn_implementation="eager" is required for calibration to work properly
# (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
# Select attn_implementation based on sparse method:
# - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa)
# - sparse24_triton requires "modelopt_triton" (fused Triton kernel)
# No need to specify attn_implementation here — mtsa.sparsify() handles it
# automatically based on the sparse config (sets "modelopt_triton" for triton
# backend, keeps "eager" for pytorch backend).
model = AutoModelForCausalLM.from_pretrained(
args.pyt_ckpt_path,
attn_implementation="eager",
torch_dtype=torch.bfloat16,
)
Comment on lines +149 to 158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Before/after comparison uses different attention backends for flash_skip_softmax.

Before sparsify() the model runs with whatever attn_implementation was selected at load time (likely "sdpa"); after sparsify() validate_eager_attention forces "eager". Any output difference now conflates sparsity effects with the SDPA → eager backend switch. For the sparse24_triton path this is less of a concern, but the skip_softmax path should still load with a consistent backend for a meaningful comparison.

Consider documenting this limitation in the comment block at lines 149-154, or conditionally set attn_implementation="eager" when the config uses a pytorch backend:

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 149 - 158,
The before/after comparison is invalid because the model is loaded with the
default attention backend (e.g., "sdpa") via
AutoModelForCausalLM.from_pretrained but after mtsa.sparsify()
validate_eager_attention forces "eager", so differences mix backend changes with
sparsity effects; fix by either (1) explicitly passing
attn_implementation="eager" into AutoModelForCausalLM.from_pretrained when the
sparse config indicates a PyTorch backend/flash_skip_softmax path (detect via
the sparsity config or args), or (2) add a clear comment in the block around
AutoModelForCausalLM.from_pretrained / mtsa.sparsify() documenting this
limitation and that comparisons for flash_skip_softmax should load with
attn_implementation set to "eager" to ensure a fair baseline.

tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
Expand Down Expand Up @@ -246,8 +250,8 @@ def main(args):
"--backend",
type=str,
default="pytorch",
choices=["pytorch"],
help="Backend for sparse attention (default: pytorch). More backends coming soon.",
choices=["pytorch", "triton"],
help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
)

# Sequence length arguments
Expand Down
37 changes: 31 additions & 6 deletions modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
title="Backend implementation.",
description=(
"Backend to use for sparse attention computation. "
"Only 'pytorch' is supported, which uses softmax patching with F.softmax. "
"Requires model to be loaded with attn_implementation='eager'."
"'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). "
"'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')."
),
)

Expand All @@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
description=(
"Whether the model uses causal (autoregressive) attention. "
"If True, sparsity statistics are calculated over the lower triangle only. "
"Set to False for cross-attention models. "
"Defaults to True for decoder-only models like GPT, LLaMA, etc."
),
)

skip_diagonal_blocks: bool = ModeloptField(
default=True,
title="Skip diagonal blocks.",
description=(
"When True, keep diagonal tiles dense for 2:4 sparse attention. "
"Only used by sparse24_triton method. Defaults to True."
),
)

@field_validator("method")
@classmethod
def validate_method(cls, v):
Expand All @@ -104,11 +114,12 @@ def validate_method(cls, v):
@field_validator("backend")
@classmethod
def validate_backend(cls, v):
"""Validate backend is pytorch."""
if v != "pytorch":
"""Validate backend is pytorch or triton."""
if v not in ("pytorch", "triton"):
raise ValueError(
f"Invalid backend: {v}. Only 'pytorch' backend is supported. "
f"Model must be loaded with attn_implementation='eager'."
f"Invalid backend: {v}. Supported backends: 'pytorch' (requires "
f"attn_implementation='eager'), 'triton' (requires "
f"attn_implementation='modelopt_triton')."
)
return v

Expand Down Expand Up @@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
},
}

# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Comment says "prefill-only" but the kernel supports both prefill and decode.

The PR description explicitly states the unified Triton kernel supports both prefill (2D kernel) and decode (3D kernel) paths with paged KV cache. The comment at line 429 is inaccurate and should be corrected to avoid misleading users.

📝 Proposed fix
-# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
+# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
# 2:4 structured sparsity via Triton unified attention kernel (prefill + decode)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` at line 429, Update the
inaccurate comment string "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" to indicate the Triton kernel supports both prefill (2D) and
decode (3D) paths with the paged KV cache; locate the comment in the attention
sparsity config where "# 2:4 structured sparsity via Triton prefill kernel
(prefill-only)" appears and change it to something like "# 2:4 structured
sparsity via unified Triton kernel (supports prefill 2D and decode 3D with paged
KV cache)" so it correctly documents the kernel capabilities.

# Preset config: 2:4 structured sparsity on attention modules via the fused
# Triton kernel. Modules matched by "*attn*" are sparsified with dense
# diagonal tiles kept; everything else is left untouched.
SPARSE24_TRITON = {
    "sparse_cfg": {
        "*attn*": dict(
            method="sparse24_triton",
            backend="triton",
            skip_diagonal_blocks=True,
            enable=True,
        ),
        "default": dict(enable=False),
    },
}


__all__ = [
"SKIP_SOFTMAX_CALIB",
"SKIP_SOFTMAX_DEFAULT",
"SPARSE24_TRITON",
"CalibrationConfig",
"FlashSkipSoftmaxConfig",
"SparseAttentionAttributeConfig",
Expand Down
34 changes: 34 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,37 @@
from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules


def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None:
"""Register the Triton attention backend and set attn_implementation if needed.
When the config uses ``backend="triton"``, this function:
1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``.
2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the
model dispatches to the Triton kernel at forward time.
This is called automatically during ``mtsa.sparsify()`` so users never need
to manually call ``register_triton_attention()`` or set ``attn_implementation``.
"""
sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
needs_triton = any(
isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values()
)
if not needs_triton:
return

from .kernels import register_triton_attention

if register_triton_attention is not None:
register_triton_attention()

# Set attn_implementation on the model so HF dispatches to the Triton kernel.
# HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time,
# so this works even after the model is already loaded.
model_config = getattr(model, "config", None)
if model_config is not None:
model_config._attn_implementation = "modelopt_triton"


def is_attn_sparsified(model: nn.Module) -> bool:
"""Check if a model has sparse attention applied.
Expand Down Expand Up @@ -61,6 +92,9 @@ def convert_to_sparse_attention_model(
# Initialize the true module if necessary
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

# Register Triton attention backend and set attn_implementation if needed
_register_triton_backend_if_needed(model, config)

# Apply custom model plugins
register_custom_model_plugins_on_the_fly(model)

Expand Down
56 changes: 56 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton attention kernels for sparse attention optimization."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
context_attention_fwd = None
register_triton_attention = None
set_sparse24 = None
unified_attention = None

if torch.cuda.is_available():
with import_plugin(
"triton",
msg_if_missing=(
"Your device is potentially capable of using the triton attention "
"kernel. Try to install triton with `pip install triton`."
),
):
from .triton_unified_attention import context_attention_fwd as _context_attention_fwd
from .triton_unified_attention import unified_attention as _unified_attention

context_attention_fwd = _context_attention_fwd
unified_attention = _unified_attention
IS_AVAILABLE = True
with import_plugin("transformers"):
from .hf_triton_attention import register_triton_attention as _register_triton_attention
from .hf_triton_attention import set_sparse24 as _set_sparse24

register_triton_attention = _register_triton_attention
set_sparse24 = _set_sparse24
_register_triton_attention()

__all__ = [
"IS_AVAILABLE",
"context_attention_fwd",
"register_triton_attention",
"set_sparse24",
"unified_attention",
]
Loading