**File: `examples/diffusers/README.md`** (excerpt)
| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | |
| Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Support Matrix | View the support matrix to see quantization/cache diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | |
| Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | |
| Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |
| Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] |

By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance.

## Sparse Attention (Skip-Softmax)

Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration.
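
The runtime adjustment can be illustrated with a small sketch (the coefficients `a` and `b` below are hypothetical placeholders; real values are produced by calibration):

```python
import math

# Hypothetical calibrated coefficients; calibration produces the real (a, b).
a, b = 1e-4, 6.0

def scale_factor(target_sparsity):
    """Evaluate the calibrated exponential model a * exp(b * target_sparsity)."""
    return a * math.exp(b * target_sparsity)

# Changing the target only re-evaluates the model; no recalibration is needed.
for s in (0.3, 0.5, 0.7):
    print(f"target_sparsity={s:.1f} -> scale_factor={scale_factor(s):.3e}")
```

Because the model is a closed-form function of the target, sweeping sparsity levels at inference time costs nothing beyond this evaluation.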

### Getting Started

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa

# 1. Define config with calibration
config = {
"sparse_cfg": {
"calibration": {
"target_sparse_ratio": {"prefill": 0.5},
"threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3,
1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1,
8e-1, 9e-1, 9.9e-1],
},
"*.attn1": {
"method": "triton_skip_softmax",
"backend": "triton",
"is_causal": False,
"collect_stats": True,
"enable": True,
},
"*.attn2": {"enable": False},
"default": {"enable": False},
},
}

# 2. Provide a calibration forward loop
def forward_loop(model):
pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)

# 3. Sparsify + calibrate
mtsa.sparsify(transformer, config, forward_loop=forward_loop)

# 4. Generate as usual — sparsity is applied automatically
output = pipeline(prompt="a dog on the beach", ...)
```

### Example Scripts

#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py)

For the 14B model, the script automatically sparsifies both `transformer` and `transformer_2`.

```bash
# 5B model — calibrate + generate (4 prompts from OpenVid-1M, 151 frames, 40 steps)
python sparsity/wan22_skip_softmax.py \
--model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
--calibrate --target-sparsity 0.5 --calib-size 4 \
--prompt "A sunset over mountains" --output out.mp4

# 14B model (both transformers sparsified)
python sparsity/wan22_skip_softmax.py \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--calibrate --target-sparsity 0.5 --calib-size 4 \
--prompt "A sunset over mountains" --output out.mp4
```

## Cache Diffusion

Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.
---

**File: `examples/diffusers/sparsity/README.md`**
# Skip-Softmax Sparse Attention for Diffusion Models

Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
tiles whose attention scores are negligible during the FlashAttention computation,
reducing FLOPs without retraining.

Two modes are supported:
- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
kernel. No calibration needed. Good for quick testing and sweeps.
- **Calibrated threshold** — an exponential model
(`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
Triton calibration kernel, then the target sparsity can be adjusted at runtime
without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
for diffusion models where scale_factors span many orders of magnitude.
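
A minimal sketch of that log-space fit with NumPy, using synthetic (sparsity, scale_factor) pairs (the real fit runs inside the calibrator):

```python
import numpy as np

# Synthetic (sparsity, scale_factor) pairs spanning several orders of
# magnitude, as collected during calibration (values made up for illustration).
sparsity = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
noise = np.exp(np.random.default_rng(0).normal(0.0, 0.05, sparsity.size))
scale = 1e-4 * np.exp(6.0 * sparsity) * noise

# Fit log(scale) = log(a) + b * sparsity: linear least squares in log space,
# which minimizes relative error instead of letting the largest values dominate.
b, log_a = np.polyfit(sparsity, np.log(scale), 1)
a = np.exp(log_a)

predicted = a * np.exp(b * 0.5)  # scale_factor for a 50% sparsity target
```

A plain least-squares fit on the raw values would be dominated by the largest scale_factors, which is why log-space fitting is recommended for diffusion models.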

## Supported Models

| Model | Script | Notes |
|-------|--------|-------|
| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only |
| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) |
| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend |

## Quick Start

```bash
# Fixed raw threshold (no calibration, fast)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--raw-threshold -0.7 \
--prompt "A cat playing piano" --output out.mp4

# With calibration
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--calibrate --target-sparsity 0.5 \
--prompt "A cat playing piano" --output out.mp4

# Dense baseline (no sparsity, for comparison)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--baseline \
--prompt "A cat playing piano" --output baseline.mp4

# Report runtime sparsity (per-layer tile skip ratios)
python wan22_skip_softmax.py \
--model-path /path/to/Wan2.2-T2V-A14B-Diffusers \
--raw-threshold -0.7 --report-avg-sparsity \
--prompt "A cat playing piano" --output out.mp4
```

## Architecture

### Inference Path (Triton kernel with tile skipping)

```text
SparseAttentionModule.forward()
└─ triton_skip_softmax._triton_inference_context()
├─ Priority: raw_threshold > scale_factor (calibrated) > static threshold
├─ _set_triton_backends(raw_threshold=X) or (scale_factor=X)
├─ attention_backend("modelopt_triton")
└─ _diffusers_triton_attention() → attention()
└─ _attn_fwd kernel: skip tiles where tile_row_max < row_max + threshold
```
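
The skip rule in the `_attn_fwd` line above can be simulated in NumPy for a single query row (a simplification that ignores the kernel's online-softmax bookkeeping; all numbers are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_k, tile, d = 1024, 128, 64
logits = rng.normal(size=seq_k)   # pre-softmax attention scores for one query row
logits[:2 * tile] += 10.0         # two tiles carry almost all of the softmax mass
v = rng.normal(size=(seq_k, d))

row_max = logits.max()
threshold = -8.0  # negative offset in the same (log) domain as the scores

# Tile-skip rule: drop tiles whose max score is below row_max + threshold;
# every softmax weight in such a tile is bounded by exp(threshold) ~ 3e-4.
keep = np.ones(seq_k, dtype=bool)
for start in range(0, seq_k, tile):
    t = slice(start, start + tile)
    if logits[t].max() < row_max + threshold:
        keep[t] = False

w_sparse = np.exp(logits[keep] - row_max)
out_sparse = (w_sparse / w_sparse.sum()) @ v[keep]
w_dense = np.exp(logits - row_max)
out_dense = (w_dense / w_dense.sum()) @ v

tile_sparsity = 1.0 - keep.mean()            # fraction of KV positions skipped
max_abs_err = np.abs(out_sparse - out_dense).max()
```

In this toy setup 6 of the 8 tiles are skipped while the output stays numerically close to the dense result, which is the intuition behind skipping tiles whose row maxima are far below the global maximum.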

### Calibration Path (Triton calibration kernel)

```text
mtsa.sparsify(transformer, config, forward_loop)
├─ apply_mode() → replace attention with SparseAttentionModule
└─ calibrate()
├─ DynamicThresholdCalibrator._set_thresholds()
│ └─ method._threshold_trials = [1e-6, ..., 9.9e-1]
├─ forward_loop(model)
│ └─ SparseAttentionModule.forward()
│ └─ triton_skip_softmax._triton_calibration_context()
│ ├─ set_triton_skip_softmax_config(calibration_mode=True)
│ ├─ attention_backend("modelopt_triton")
│ └─ _diffusers_triton_attention() → attention_calibrate()
│ └─ _attn_fwd_calibrate kernel:
│ - Full attention (no skipping) for correct output
│ - Vectorized multi-threshold sparsity measurement
│ - Per-program output buffers (no atomic contention)
│ - Python-side reduction: sum across programs
├─ Fit: scale_factor = a * exp(b * sparsity)
│ └─ fit_logspace=True: fits in log space (minimizes relative error)
└─ Apply a, b to all modules
└─ Inference: threshold = scale_factor / seq_k
```
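
The vectorized multi-threshold measurement and the Python-side reduction can be sketched as follows (synthetic per-tile statistics; the real kernel measures these per program on the GPU):

```python
import numpy as np

rng = np.random.default_rng(0)
thresholds = np.array([1e-6, 1e-4, 1e-2, 1e-1, 5e-1])  # subset of the trial grid
n_programs, tiles_per_program = 4, 64

# Each "program" counts, for every trial threshold at once, how many of its
# tiles would be skipped, writing to its own buffer (no atomic contention).
buffers = np.zeros((n_programs, thresholds.size))
for p in range(n_programs):
    # Synthetic per-tile mass estimates in (0, 1].
    tile_mass = rng.uniform(0.0, 1.0, size=tiles_per_program)
    # Vectorized: one comparison per (tile, threshold) pair.
    buffers[p] = (tile_mass[:, None] < thresholds[None, :]).sum(axis=0)

# Python-side reduction: sum across programs, normalize to a sparsity ratio.
sparsity_per_threshold = buffers.sum(axis=0) / (n_programs * tiles_per_program)
```

Measuring all trial thresholds in one pass is what lets a single `forward_loop` produce the full (threshold, sparsity) curve used for the exponential fit.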

## Core Files

### Triton Kernels (`modelopt/torch/kernels/`)

| File | Role |
|------|------|
| `triton_fa.py` | `_attn_fwd`: forward kernel with optional tile skipping + sparsity measurement. `_attn_fwd_calibrate`: calibration kernel with vectorized multi-threshold testing and per-program buffers (zero atomic contention). `attention()` and `attention_calibrate()` Python APIs. |

### Sparse Attention Methods (`modelopt/torch/sparsity/attention_sparsity/methods/`)

| File | Role |
|------|------|
| `triton_skip_softmax.py` | Primary method for diffusion models. Calibration context → Triton calibration kernel. Inference context → Triton forward kernel. Supports `scale_factor` (calibrated), `raw_threshold` (direct), and static `skip_softmax_threshold`. |
| `flash_skip_softmax.py` | PyTorch-based method for HF LLMs (not used by diffusers/LTX). |
| `registry.py` | Base class `SparseAttentionMethod` with `calibration_params`, `target_sparse_ratio`, `set_calibration_mode()`. |

### Kernel Backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`)

| File | Role |
|------|------|
| `diffusers_triton_attention.py` | Registers `modelopt_triton` backend in diffusers. Handles calibration mode (→ `attention_calibrate`) and inference mode (→ `attention` with `scale_factor/seq_k` or `raw_threshold`). Runtime sparsity counter accumulation. |
| `ltx_triton_attention.py` | Patches `ltx_core.Attention` modules for Triton dispatch. Same calibration/inference modes. |
| `hf_triton_attention.py` | HuggingFace `attn_implementation="modelopt_triton"` backend for LLMs. |

### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`)

| File | Role |
|------|------|
| `calibrate.py` | Orchestrates calibration. Skips RULER dataset when user provides `forward_loop` (diffusion models). Applies fitted (a, b) to all modules. |
| `calibrator.py` | `DynamicThresholdCalibrator`: collects (scale_factor, sparsity) pairs via Triton calibration kernel, fits exponential model `scale_factor = a * exp(b * sparsity)`. Supports `fit_logspace=True` for log-space fitting (recommended for diffusion models). |

### Config & Conversion

| File | Role |
|------|------|
| `config.py` | `SparseAttentionAttributeConfig` with `skip_softmax_threshold`, `skip_softmax_raw_threshold`, calibration settings. `CalibrationConfig` with `fit_logspace` field. |
| `conversion.py` | `_register_diffusers_backends_if_needed()` auto-registers Triton backends on `sparsify()`. |
| `sparse_attention.py` | `SparseAttentionModule` wrapper — delegates to method's `get_sparse_context()`. |

## Threshold Modes

| Mode | How threshold reaches the kernel | Use case |
|------|----------------------------------|----------|
| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps |
| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation |
| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated |
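
The three paths can be written out as a single resolution function (an illustrative sketch, assuming the usual `sm_scale = 1/sqrt(head_dim)`):

```python
import math

def kernel_log2_threshold(sm_scale, seq_k, raw_threshold=None,
                          scale_factor=None, static_lambda=0.1):
    """Resolve the log2-space threshold handed to the Triton kernel,
    following the priority raw_threshold > scale_factor > static lambda."""
    if raw_threshold is not None:
        return raw_threshold                       # passed through unchanged
    if scale_factor is not None:
        threshold = scale_factor / seq_k           # adapts to the KV length
        return math.log2(threshold) * sm_scale
    return math.log2(static_lambda) * sm_scale     # static fallback

sm_scale = 1 / math.sqrt(128)
t_raw = kernel_log2_threshold(sm_scale, 4096, raw_threshold=-0.7)
t_cal = kernel_log2_threshold(sm_scale, 4096, scale_factor=2e-3)
t_static = kernel_log2_threshold(sm_scale, 4096)
```

Note how only the calibrated path divides by `seq_k`, which is what gives it automatic adaptation to sequence length.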

## Known Issues

- **14B dual transformer calibration**: Transformers are calibrated sequentially — `transformer_2`'s calibration runs while `transformer` is already sparsified, introducing asymmetric calibration conditions.
- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted.
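
A minimal sketch of that inference-time warning (the floor value and function name here are hypothetical):

```python
import warnings

def check_target_sparsity(target_sparsity, min_achievable=0.35):
    """Warn when the requested target is below the measured sparsity floor,
    where the fitted exponential model extrapolates instead of interpolating."""
    if target_sparsity < min_achievable:
        warnings.warn(
            f"target_sparsity={target_sparsity:.2f} is below the measured "
            f"floor of {min_achievable:.2f}; the calibrated model is "
            f"extrapolating and the achieved sparsity may not match."
        )
```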