diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md index 6af226752d..a64ea2b1b4 100644 --- a/examples/diffusers/README.md +++ b/examples/diffusers/README.md @@ -13,6 +13,7 @@ Cache Diffusion is a technique that reuses cached outputs from previous diffusio | Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | | Getting Started | Learn how to optimize your models using quantization/cache diffusion to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | | Support Matrix | View the support matrix to see quantization/cahce diffusion compatibility and feature availability across different models | \[[Link](#support-matrix)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | +| Sparse Attention (Skip-Softmax) | Skip-softmax sparse attention for diffusion models | \[[Link](#sparse-attention-skip-softmax)\] | | | Cache Diffusion | Caching technique to accelerate inference without compromising quality | \[[Link](#cache-diffusion)\] | | | Post Training Quantization (PTQ) | Example scripts on how to run PTQ on diffusion models | \[[Link](#post-training-quantization-ptq)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | | Quantization Aware Training (QAT) | Example scripts on how to run QAT on diffusion models | \[[Link](#quantization-aware-training-qat)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | @@ -276,6 +277,67 @@ mto.restore(pipe.unet, your_quantized_ckpt) By following these steps, your PEFT LoRA model should be efficiently quantized using ModelOpt, ready for deployment while maximizing performance. 
+## Sparse Attention (Skip-Softmax) + +Skip-softmax sparse attention skips KV tiles whose attention scores are negligible during the softmax computation, reducing FLOPs without retraining. An exponential model (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once, then the target sparsity can be adjusted at runtime without recalibration. + +### Getting Started + +```python +import modelopt.torch.sparsity.attention_sparsity as mtsa + +# 1. Define config with calibration +config = { + "sparse_cfg": { + "calibration": { + "target_sparse_ratio": {"prefill": 0.5}, + "threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, + 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1, + 8e-1, 9e-1, 9.9e-1], + }, + "*.attn1": { + "method": "triton_skip_softmax", + "backend": "triton", + "is_causal": False, + "collect_stats": True, + "enable": True, + }, + "*.attn2": {"enable": False}, + "default": {"enable": False}, + }, +} + +# 2. Provide a calibration forward loop +def forward_loop(model): + pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...) + +# 3. Sparsify + calibrate +mtsa.sparsify(transformer, config, forward_loop=forward_loop) + +# 4. Generate as usual — sparsity is applied automatically +output = pipeline(prompt="a dog on the beach", ...) +``` + +### Example Scripts + +#### Wan 2.2 [Script](./sparsity/wan22_skip_softmax.py) + +The 14B model automatically sparsifies both `transformer` and `transformer_2`. 
+
+```bash
+# 5B model — calibrate + generate (4 prompts from OpenVid-1M, 151 frames, 40 steps)
+python sparsity/wan22_skip_softmax.py \
+    --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
+    --calibrate --target-sparsity 0.5 --calib-size 4 \
+    --prompt "A sunset over mountains" --output out.mp4
+
+# 14B model (both transformers sparsified)
+python sparsity/wan22_skip_softmax.py \
+    --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+    --calibrate --target-sparsity 0.5 --calib-size 4 \
+    --prompt "A sunset over mountains" --output out.mp4
+```
+
 ## Cache Diffusion
 
 Cache Diffusion methods, such as [DeepCache](https://arxiv.org/abs/2312.00858), [Block Caching](https://arxiv.org/abs/2312.03209) and [T-Gate](https://arxiv.org/abs/2404.02747), optimize performance by reusing cached outputs from previous steps instead of recalculating them. This **training-free** caching approach is compatible with a variety of models, like **DiT** and **UNet**, enabling considerable acceleration without compromising quality.
diff --git a/examples/diffusers/sparsity/README.md b/examples/diffusers/sparsity/README.md
new file mode 100644
index 0000000000..8e6c69112b
--- /dev/null
+++ b/examples/diffusers/sparsity/README.md
@@ -0,0 +1,141 @@
+# Skip-Softmax Sparse Attention for Diffusion Models
+
+Skip-softmax sparse attention (BLASST, <https://arxiv.org/pdf/2512.12087>) skips KV
+tiles whose attention scores are negligible during the FlashAttention computation,
+reducing FLOPs without retraining.
+
+Two modes are supported:
+- **Fixed raw threshold** — pass a log2-space threshold directly to the Triton
+  kernel. No calibration needed. Good for quick testing and sweeps.
+- **Calibrated threshold** — an exponential model
+  (`scale_factor = a * exp(b * target_sparsity)`) is calibrated once via the
+  Triton calibration kernel, then the target sparsity can be adjusted at runtime
+  without recalibration. Log-space fitting (`fit_logspace=True`) is recommended
+  for diffusion models where scale_factors span many orders of magnitude.
+ +## Supported Models + +| Model | Script | Notes | +|-------|--------|-------| +| WAN 2.2 5B | `wan22_skip_softmax.py` | Single transformer, self-attention only | +| WAN 2.2 14B | `wan22_skip_softmax.py` | Dual transformer (auto-detected) | +| LTX-2 | (coming soon) | Via `ltx_triton_attention.py` backend | + +## Quick Start + +```bash +# Fixed raw threshold (no calibration, fast) +python wan22_skip_softmax.py \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --raw-threshold -0.7 \ + --prompt "A cat playing piano" --output out.mp4 + +# With calibration +python wan22_skip_softmax.py \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --calibrate --target-sparsity 0.5 \ + --prompt "A cat playing piano" --output out.mp4 + +# Dense baseline (no sparsity, for comparison) +python wan22_skip_softmax.py \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --baseline \ + --prompt "A cat playing piano" --output baseline.mp4 + +# Report runtime sparsity (per-layer tile skip ratios) +python wan22_skip_softmax.py \ + --model-path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --raw-threshold -0.7 --report-avg-sparsity \ + --prompt "A cat playing piano" --output out.mp4 +``` + +## Architecture + +### Inference Path (Triton kernel with tile skipping) + +```text +SparseAttentionModule.forward() + └─ triton_skip_softmax._triton_inference_context() + ├─ Priority: raw_threshold > scale_factor (calibrated) > static threshold + ├─ _set_triton_backends(raw_threshold=X) or (scale_factor=X) + ├─ attention_backend("modelopt_triton") + └─ _diffusers_triton_attention() → attention() + └─ _attn_fwd kernel: skip tiles where tile_row_max < row_max + threshold +``` + +### Calibration Path (Triton calibration kernel) + +```text +mtsa.sparsify(transformer, config, forward_loop) + ├─ apply_mode() → replace attention with SparseAttentionModule + └─ calibrate() + ├─ DynamicThresholdCalibrator._set_thresholds() + │ └─ method._threshold_trials = [1e-6, ..., 9.9e-1] + ├─ forward_loop(model) + │ └─ 
SparseAttentionModule.forward() + │ └─ triton_skip_softmax._triton_calibration_context() + │ ├─ set_triton_skip_softmax_config(calibration_mode=True) + │ ├─ attention_backend("modelopt_triton") + │ └─ _diffusers_triton_attention() → attention_calibrate() + │ └─ _attn_fwd_calibrate kernel: + │ - Full attention (no skipping) for correct output + │ - Vectorized multi-threshold sparsity measurement + │ - Per-program output buffers (no atomic contention) + │ - Python-side reduction: sum across programs + ├─ Fit: scale_factor = a * exp(b * sparsity) + │ └─ fit_logspace=True: fits in log space (minimizes relative error) + └─ Apply a, b to all modules + └─ Inference: threshold = scale_factor / seq_k +``` + +## Core Files + +### Triton Kernels (`modelopt/torch/kernels/`) + +| File | Role | +|------|------| +| `triton_fa.py` | `_attn_fwd`: forward kernel with optional tile skipping + sparsity measurement. `_attn_fwd_calibrate`: calibration kernel with vectorized multi-threshold testing and per-program buffers (zero atomic contention). `attention()` and `attention_calibrate()` Python APIs. | + +### Sparse Attention Methods (`modelopt/torch/sparsity/attention_sparsity/methods/`) + +| File | Role | +|------|------| +| `triton_skip_softmax.py` | Primary method for diffusion models. Calibration context → Triton calibration kernel. Inference context → Triton forward kernel. Supports `scale_factor` (calibrated), `raw_threshold` (direct), and static `skip_softmax_threshold`. | +| `flash_skip_softmax.py` | PyTorch-based method for HF LLMs (not used by diffusers/LTX). | +| `registry.py` | Base class `SparseAttentionMethod` with `calibration_params`, `target_sparse_ratio`, `set_calibration_mode()`. | + +### Kernel Backends (`modelopt/torch/sparsity/attention_sparsity/kernels/`) + +| File | Role | +|------|------| +| `diffusers_triton_attention.py` | Registers `modelopt_triton` backend in diffusers. 
Handles calibration mode (→ `attention_calibrate`) and inference mode (→ `attention` with `scale_factor/seq_k` or `raw_threshold`). Runtime sparsity counter accumulation. | +| `ltx_triton_attention.py` | Patches `ltx_core.Attention` modules for Triton dispatch. Same calibration/inference modes. | +| `hf_triton_attention.py` | HuggingFace `attn_implementation="modelopt_triton"` backend for LLMs. | + +### Calibration (`modelopt/torch/sparsity/attention_sparsity/calibration/`) + +| File | Role | +|------|------| +| `calibrate.py` | Orchestrates calibration. Skips RULER dataset when user provides `forward_loop` (diffusion models). Applies fitted (a, b) to all modules. | +| `calibrator.py` | `DynamicThresholdCalibrator`: collects (scale_factor, sparsity) pairs via Triton calibration kernel, fits exponential model `scale_factor = a * exp(b * sparsity)`. Supports `fit_logspace=True` for log-space fitting (recommended for diffusion models). | + +### Config & Conversion + +| File | Role | +|------|------| +| `config.py` | `SparseAttentionAttributeConfig` with `skip_softmax_threshold`, `skip_softmax_raw_threshold`, calibration settings. `CalibrationConfig` with `fit_logspace` field. | +| `conversion.py` | `_register_diffusers_backends_if_needed()` auto-registers Triton backends on `sparsify()`. | +| `sparse_attention.py` | `SparseAttentionModule` wrapper — delegates to method's `get_sparse_context()`. 
| + +## Threshold Modes + +| Mode | How threshold reaches the kernel | Use case | +|------|----------------------------------|----------| +| **Raw threshold** (`--raw-threshold -0.7`) | Passed directly as `skip_threshold_log2` — no conversion | Quick testing, sweeps | +| **Calibrated** (`--calibrate --target-sparsity 0.5`) | `scale_factor = a * exp(b * target)`, then backend computes `threshold = scale_factor / seq_k`, then kernel converts `log2(threshold) * sm_scale` | Production use with automatic seqlen adaptation | +| **Static lambda** (default `skip_softmax_threshold=0.1`) | `log2(lambda) * sm_scale` | Fallback when neither raw nor calibrated | + +## Known Issues + +- **14B dual transformer calibration**: Transformers are calibrated sequentially — transformer_2's calibration runs while transformer_1 is already sparsified, introducing asymmetric calibration conditions. +- **Minimum achievable sparsity**: Even the strictest threshold may yield 30-40% sparsity on diffusion models (many tiles are inherently negligible). Targets below this floor cause extrapolation; an inference-time warning is emitted. diff --git a/examples/diffusers/sparsity/wan22_skip_softmax.py b/examples/diffusers/sparsity/wan22_skip_softmax.py new file mode 100644 index 0000000000..9824ce7251 --- /dev/null +++ b/examples/diffusers/sparsity/wan22_skip_softmax.py @@ -0,0 +1,477 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Wan 2.2 inference with skip-softmax sparse attention.
+
+This example applies skip-softmax sparse attention to the Wan 2.2 video
+generation model (text-to-video). Four modes are supported:
+
+1. **Baseline** — pass ``--baseline`` for dense inference (default diffusers backend).
+2. **Triton baseline** — pass ``--triton-baseline`` for dense Triton FA kernel
+   (no skip-softmax, same kernel as sparse runs for apples-to-apples comparison).
+3. **Fixed raw threshold** — pass ``--raw-threshold`` to supply a log2-space
+   threshold directly to the Triton kernel. No calibration data is needed.
+4. **Calibrated threshold** — pass ``--calibrate`` to run exponential-model
+   calibration (``scale_factor = a * exp(b * target_sparsity)``).
+
+During calibration, ``triton_skip_softmax`` with the Triton calibration kernel
+collects sparsity statistics across multiple threshold trials. The fitted
+exponential model then allows runtime control of the target sparsity ratio
+without recalibration.
+
+The Wan 2.2 5B model has 40 transformer blocks with self-attention (attn1)
+and cross-attention (attn2). Only self-attention is sparsified.
+ +Usage:: + + # Baseline (dense, no sparsity) + python wan22_skip_softmax.py --baseline --prompt "A cat playing piano" \\ + --output baseline.mp4 + + # Fixed raw threshold (no calibration needed) + python wan22_skip_softmax.py --raw-threshold -5.0 --report-avg-sparsity \\ + --prompt "A cat playing piano" --output out.mp4 + + # With calibration + python wan22_skip_softmax.py --calibrate --target-sparsity 0.25 \\ + --report-avg-sparsity --prompt "A cat playing piano" --output out.mp4 +""" + +import argparse +import os + +import torch +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.utils import export_to_video + +import modelopt.torch.sparsity.attention_sparsity as mtsa +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule + +DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-TI2V-5B-Diffusers") + +# fmt: off +# ruff: noqa: RUF001 +DEFAULT_NEGATIVE_PROMPT = ( # Official Wan 2.2 negative prompt (Chinese) + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部," + "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面," + "杂乱的背景,三条腿,背景人很多,倒着走" +) +# fmt: on + +# Default threshold trials for calibration +DEFAULT_THRESHOLD_TRIALS = [ + 1e-12, + 1e-10, + 1e-8, + 1e-6, + 5e-6, + 1e-5, + 5e-5, + 1e-4, + 5e-4, + 1e-3, + 5e-3, + 1e-2, + 2e-2, + 5e-2, + 1e-1, + 2e-1, + 3e-1, + 5e-1, + 7e-1, + 8e-1, + 9e-1, + 9.9e-1, +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Wan 2.2 video generation with skip-softmax sparse attention" + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Text prompt for generation (optional, skips generation if not set)", + ) + parser.add_argument("--output", type=str, default="output.mp4", help="Output video path") + parser.add_argument( + "--model-path", type=str, default=DEFAULT_MODEL_PATH, help="Wan 2.2 model path or HF ID" + ) + parser.add_argument( + "--num-frames", type=int, 
default=81, help="Number of frames (must be 4k+1)" + ) + parser.add_argument("--height", type=int, default=480, help="Video height") + parser.add_argument("--width", type=int, default=832, help="Video width") + parser.add_argument("--num-steps", type=int, default=40, help="Number of inference steps") + parser.add_argument( + "--guidance-scale", type=float, default=4.0, help="Classifier-free guidance scale" + ) + parser.add_argument( + "--guidance-scale-2", + type=float, + default=3.0, + help="Second guidance scale for 14B dual-transformer model (ignored by 5B)", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=DEFAULT_NEGATIVE_PROMPT, + help="Negative prompt", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + # Sparse attention options + parser.add_argument( + "--baseline", + action="store_true", + help="Run dense inference with default diffusers backend (no sparsity)", + ) + parser.add_argument( + "--triton-baseline", + action="store_true", + help="Run dense inference with Triton FA kernel (no skip-softmax, " + "apples-to-apples comparison with sparse runs)", + ) + parser.add_argument( + "--raw-threshold", + type=float, + default=None, + help="Raw skip_threshold_log2 value passed directly to the Triton kernel. " + "Negative values (e.g., -5.0 means tile must be within 5 units of running max). " + "Bypasses calibration and lambda conversion. 
Typical range: -1 to -30.", + ) + parser.add_argument( + "--skip-first-last", + type=int, + default=2, + help="Number of first/last transformer layers to keep dense (default: 2)", + ) + parser.add_argument( + "--report-avg-sparsity", + action="store_true", + help="Report per-layer and overall average tile sparsity after generation", + ) + + # Calibration options + parser.add_argument( + "--calibrate", + action="store_true", + help="Calibrate threshold via exponential model (recommended)", + ) + parser.add_argument( + "--target-sparsity", + type=float, + default=0.5, + help="Target sparsity ratio for calibration (0.0-1.0)", + ) + parser.add_argument( + "--calib-steps", + type=int, + default=40, + help="Inference steps for calibration", + ) + parser.add_argument( + "--calib-frames", + type=int, + default=151, + help="Number of frames for calibration", + ) + parser.add_argument( + "--calib-size", + type=int, + default=4, + help="Number of calibration prompts from OpenVid-1M dataset", + ) + return parser.parse_args() + + +def build_pipeline(model_path: str) -> WanPipeline: + """Build the Wan 2.2 text-to-video pipeline.""" + vae = AutoencoderKLWan.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.float32) + pipe = WanPipeline.from_pretrained(model_path, vae=vae, torch_dtype=torch.bfloat16) + pipe.to("cuda") + return pipe + + +def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict: + """Build sparse attention config from CLI args. + + Two modes: + - **Raw threshold**: ``--raw-threshold`` sets ``skip_softmax_raw_threshold`` + directly on the Triton kernel — no calibration needed. + - **Calibrated**: ``--calibrate`` collects multi-threshold sparsity statistics + via the Triton calibration kernel, then fits an exponential model: + ``scale_factor = a * exp(b * sparsity)``. 
+ """ + attn_cfg: dict = { + "method": "triton_skip_softmax", + "skip_softmax_threshold": 0.0 if args.triton_baseline else 0.1, + "backend": "triton", + "is_causal": False, # Diffusion = bidirectional attention + "collect_stats": True, + "enable": True, + } + + # Raw threshold bypasses calibration and lambda conversion + if args.raw_threshold is not None: + attn_cfg["skip_softmax_raw_threshold"] = args.raw_threshold + + sparse_cfg: dict = { + "*.attn1*": attn_cfg, # Self-attention only + "*.attn2*": {"enable": False}, # Text cross-attention + "default": {"enable": False}, + } + + # Keep first/last N layers dense for quality + for i in range(args.skip_first_last): + sparse_cfg[f"*blocks.{i}.attn*"] = {"enable": False} + sparse_cfg[f"*blocks.{num_blocks - 1 - i}.attn*"] = {"enable": False} + + config: dict = {"sparse_cfg": sparse_cfg} + + # Add calibration config only when calibrating (not with raw threshold) + if args.calibrate and args.raw_threshold is None: + sparse_cfg["calibration"] = { + "target_sparse_ratio": {"prefill": args.target_sparsity}, + "samples": 1, + "threshold_trials": DEFAULT_THRESHOLD_TRIALS, + "fit_logspace": True, + } + + return config + + +def load_calib_prompts(calib_size: int) -> list[str]: + """Load calibration prompts from OpenVid-1M dataset.""" + from datasets import load_dataset + + dataset = load_dataset("nkp37/OpenVid-1M", split="train") + prompts = list(dataset["caption"][:calib_size]) + print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M") + return prompts + + +def build_calibration_forward_loop( + pipe: WanPipeline, + calib_size: int = 4, + num_steps: int = 40, + num_frames: int = 151, + height: int = 480, + width: int = 832, + seed: int = 42, + guidance_scale: float = 4.0, + guidance_scale_2: float | None = 3.0, + negative_prompt: str = "", +): + """Build a forward loop for exponential model calibration. + + Uses prompts from OpenVid-1M dataset (same as quantization examples). 
+ Each prompt is run individually (batch_size=1). + """ + calib_prompts = load_calib_prompts(calib_size) + + def forward_loop(model): + for i, prompt in enumerate(calib_prompts): + print(f"Calibration [{i + 1}/{len(calib_prompts)}]: {prompt[:60]}...") + kw: dict = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "height": height, + "width": width, + "num_inference_steps": num_steps, + "guidance_scale": guidance_scale, + "generator": torch.Generator(device="cuda").manual_seed(seed), + } + if guidance_scale_2 is not None: + kw["guidance_scale_2"] = guidance_scale_2 + pipe(**kw) + + return forward_loop + + +def enable_sparsity_measurement(model: torch.nn.Module) -> None: + """Enable runtime sparsity measurement on all sparse attention modules.""" + for _name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule) and module.is_enabled: + method = module._sparse_method_instance + if hasattr(method, "enable_measure_sparsity"): + method.reset_sparsity_counters() + method.enable_measure_sparsity(True) + + +def print_sparsity_summary(model: torch.nn.Module) -> None: + """Print per-module sparsity statistics including runtime kernel counters.""" + enabled, disabled = [], [] + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule): + if module.is_enabled: + enabled.append((name, module)) + else: + disabled.append(name) + + print(f"\nSparse attention: {len(enabled)} enabled, {len(disabled)} disabled") + for name, module in enabled: + info = module.get_threshold_info() + print(f" {name}: {info}") + + +def print_runtime_sparsity(model: torch.nn.Module) -> None: + """Print runtime tile sparsity measured via kernel atomic counters.""" + total_all = 0 + skipped_all = 0 + per_module: list[tuple[str, int, int]] = [] + + for name, module in model.named_modules(): + if isinstance(module, SparseAttentionModule) and module.is_enabled: + method = module._sparse_method_instance + if 
hasattr(method, "get_sparsity_counters"): + total, skipped = method.get_sparsity_counters() + if total > 0: + per_module.append((name, total, skipped)) + total_all += total + skipped_all += skipped + + if total_all == 0: + print("\nNo runtime sparsity data collected.") + return + + print("\n" + "=" * 70) + print("Runtime tile sparsity (measured via kernel atomic counters)") + print("=" * 70) + for name, total, skipped in per_module: + ratio = skipped / total + print(f" {name}: {skipped:,}/{total:,} tiles skipped ({ratio:.1%})") + ratio_all = skipped_all / total_all + print("-" * 70) + print(f" Overall: {skipped_all:,}/{total_all:,} tiles skipped ({ratio_all:.1%})") + print("=" * 70) + + +def _get_num_blocks(transformer: torch.nn.Module) -> int: + """Count transformer blocks by looking for *.blocks.N.* submodules.""" + max_idx = -1 + for name, _ in transformer.named_modules(): + parts = name.split(".") + for i, part in enumerate(parts): + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + max_idx = max(max_idx, int(parts[i + 1])) + return max_idx + 1 + + +def main() -> None: + args = parse_args() + + # ---- Build pipeline ---- + print(f"Loading Wan 2.2 from {args.model_path}...") + pipe = build_pipeline(args.model_path) + + # ---- Collect transformers ---- + # Wan 2.2 5B has one transformer; 14B has two (transformer + transformer_2) + transformers = [] + if pipe.transformer is not None: + transformers.append(("transformer", pipe.transformer)) + if getattr(pipe, "transformer_2", None) is not None: + transformers.append(("transformer_2", pipe.transformer_2)) + + # ---- Sparsify (unless baseline) ---- + if args.baseline: + print("Baseline mode: running dense inference (default diffusers backend)") + elif args.triton_baseline: + print("Triton baseline: dense Triton FA kernel (no skip-softmax)") + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + print(f"Applying Triton backend to {name} ({num_blocks} blocks)...") 
+ config = build_sparse_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config, forward_loop=None) + else: + # Build calibration forward loop if needed + forward_loop = None + if args.raw_threshold is not None: + print(f"Using fixed raw threshold: {args.raw_threshold} (skipping calibration)") + if args.calibrate: + print("Warning: --calibrate is ignored when --raw-threshold is set") + elif args.calibrate: + forward_loop = build_calibration_forward_loop( + pipe, + calib_size=args.calib_size, + num_steps=args.calib_steps, + num_frames=args.calib_frames, + height=args.height, + width=args.width, + seed=args.seed, + guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_2, + negative_prompt=args.negative_prompt, + ) + else: + print( + "Warning: neither --baseline, --raw-threshold, nor --calibrate specified; " + "using default static threshold" + ) + + for name, transformer in transformers: + num_blocks = _get_num_blocks(transformer) + print(f"Applying skip-softmax to {name} ({num_blocks} blocks)...") + config = build_sparse_config(args, num_blocks=num_blocks) + mtsa.sparsify(transformer, config, forward_loop=forward_loop) + + # ---- Free calibration memory before inference ---- + if not args.baseline and not args.triton_baseline and forward_loop is not None: + import gc + + gc.collect() + torch.cuda.empty_cache() + print("Cleared CUDA cache after calibration") + + # ---- Generate (optional) ---- + if args.prompt: + # Enable runtime sparsity measurement before generation + if args.report_avg_sparsity and not args.baseline: + for _name, transformer in transformers: + enable_sparsity_measurement(transformer) + + print(f"Generating: {args.prompt[:80]}...") + pipe_kwargs: dict = { + "prompt": args.prompt, + "negative_prompt": args.negative_prompt, + "num_frames": args.num_frames, + "height": args.height, + "width": args.width, + "num_inference_steps": args.num_steps, + "guidance_scale": args.guidance_scale, + "generator": 
torch.Generator(device="cuda").manual_seed(args.seed), + } + if args.guidance_scale_2 is not None: + pipe_kwargs["guidance_scale_2"] = args.guidance_scale_2 + output = pipe(**pipe_kwargs) + + export_to_video(output.frames[0], args.output, fps=16) + print(f"Saved to {args.output}") + + # ---- Print stats ---- + if not args.baseline: + for name, transformer in transformers: + print(f"\n{name}:") + print_sparsity_summary(transformer) + if args.report_avg_sparsity: + print_runtime_sparsity(transformer) + + +if __name__ == "__main__": + main() diff --git a/modelopt/torch/kernels/__init__.py b/modelopt/torch/kernels/__init__.py index 24d27a1ba2..fa07b06e20 100644 --- a/modelopt/torch/kernels/__init__.py +++ b/modelopt/torch/kernels/__init__.py @@ -21,6 +21,7 @@ IS_AVAILABLE = False attention = None +attention_calibrate = None register_triton_attention = None if torch.cuda.is_available(): @@ -32,8 +33,10 @@ ), ): from .triton_fa import attention as _attention + from .triton_fa import attention_calibrate as _attention_calibrate attention = _attention + attention_calibrate = _attention_calibrate IS_AVAILABLE = True from .hf_triton_attention import register_triton_attention as _register_triton_attention @@ -42,5 +45,6 @@ __all__ = [ "IS_AVAILABLE", "attention", + "attention_calibrate", "register_triton_attention", ] diff --git a/modelopt/torch/kernels/triton_fa.py b/modelopt/torch/kernels/triton_fa.py index 8d3b11f1af..8044383889 100644 --- a/modelopt/torch/kernels/triton_fa.py +++ b/modelopt/torch/kernels/triton_fa.py @@ -252,6 +252,9 @@ def _attn_fwd( DENSE_WINDOW_SIZE: tl.constexpr = 64, # Tokens near diagonal kept dense (absolute, BLOCK_N-independent) APPLY_SKIP_SOFTMAX: tl.constexpr = False, # Skip KV tiles with negligible scores SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0, # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores + Sparsity_total=None, # Optional int64 scalar for counting total tiles (atomic) + Sparsity_skipped=None, # Optional int64 scalar for 
counting skipped tiles (atomic) + MEASURE_SPARSITY: tl.constexpr = False, # When True, count total/skipped tiles via atomic adds ): # --- Grid: (batch, num_q_heads, num_q_tiles) --- # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128 @@ -347,6 +350,12 @@ def _attn_fwd( # Per-tile: skip entire tile only if ALL rows are negligible skip_tile = tl.min(can_skip.to(tl.int32)) == 1 + # Optional runtime sparsity measurement via atomic counters + if MEASURE_SPARSITY: + tl.atomic_add(Sparsity_total, 1) # count every tile + if skip_tile: + tl.atomic_add(Sparsity_skipped, 1) # count skipped tiles + if not skip_tile: m_new = tl.maximum(row_max, tile_row_max) p = tl.math.exp2(scores - m_new[:, None]) @@ -385,7 +394,9 @@ def _attn_fwd( row_max = m_new # --- Final normalization: output = acc / row_sum --- - acc = acc / row_sum[:, None] + # Clamp denominator to avoid 0/0 NaN when skip-softmax skips all KV tiles. + # Safe because acc is also 0 in that case (never accumulated), so 0/eps = 0. + acc = acc / tl.maximum(row_sum[:, None], 1e-6) # Save LSE for backward pass (log2-space: lse = max + log2(sum)) if STORE_LSE: @@ -768,6 +779,8 @@ def forward( num_sink_tokens, dense_window_size, skip_softmax_threshold, + skip_softmax_raw_threshold, + measure_sparsity, ): HEAD_DIM = q.shape[2] num_q_heads = q.shape[1] @@ -788,20 +801,36 @@ def forward( # Triton tiles must be powers of 2; pad head dim BLOCK_D = triton.next_power_of_2(HEAD_DIM) - # Skip-softmax: convert threshold to scaled log2 space for the kernel. - # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks - # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space - # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we - # pre-scale: threshold_scaled = log2(lambda) * sm_scale. - apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0 - if apply_skip: + # Skip-softmax threshold in scaled log2 space for the kernel. + # Two modes: + # 1. 
raw_threshold: passed directly as skip_threshold_log2 (for testing) + # 2. lambda threshold: converted via log2(lambda) * sm_scale + if skip_softmax_raw_threshold is not None: + apply_skip = True + skip_threshold_log2 = skip_softmax_raw_threshold + elif skip_softmax_threshold is not None and skip_softmax_threshold > 0.0: + apply_skip = True + # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks + # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space + # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we + # pre-scale: threshold_scaled = log2(lambda) * sm_scale. skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale else: + apply_skip = False skip_threshold_log2 = 0.0 o = torch.empty_like(q) lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32) + # Optional runtime sparsity counters (single int64 scalars for atomic adds) + do_measure = measure_sparsity and apply_skip + if do_measure: + sparsity_total = torch.zeros(1, dtype=torch.int64, device=q.device) + sparsity_skipped = torch.zeros(1, dtype=torch.int64, device=q.device) + else: + sparsity_total = None + sparsity_skipped = None + # Grid: (batch, q_heads, q_tiles). Uses a function because BLOCK_M is autotuned. 
def grid(META): return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"])) @@ -839,9 +868,17 @@ def grid(META): DENSE_WINDOW_SIZE=dense_window_size, APPLY_SKIP_SOFTMAX=apply_skip, SKIP_THRESHOLD_LOG2=skip_threshold_log2, + Sparsity_total=sparsity_total, + Sparsity_skipped=sparsity_skipped, + MEASURE_SPARSITY=do_measure, # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune ) + # Store sparsity counters on the output tensor for retrieval by callers + if do_measure: + o._sparsity_total = sparsity_total.item() + o._sparsity_skipped = sparsity_skipped.item() + ctx.save_for_backward(q, k, v, o, lse, b_start_loc, b_seq_len, b_start_loc_k, b_seq_len_k) ctx.max_input_len = max_input_len ctx.max_input_len_k = max_input_len_k @@ -985,6 +1022,8 @@ def backward(ctx, grad_output): None, None, None, + None, + None, ) @@ -1006,6 +1045,8 @@ def attention( num_sink_tokens: int = 0, dense_window_size: int = 64, skip_softmax_threshold: float | None = None, + skip_softmax_raw_threshold: float | None = None, + measure_sparsity: bool = False, ) -> torch.Tensor: """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax. @@ -1037,6 +1078,16 @@ def attention( softmax contribution is negligible. Tiles are skipped entirely (no softmax, V load, or BMM2). The threshold is applied on unscaled scores. Set to ``None`` or ``0`` to disable. + skip_softmax_raw_threshold: Raw ``skip_threshold_log2`` value passed + directly to the kernel without conversion. The kernel skips tiles + where ``tile_row_max < row_max + raw_threshold``. Typical values + are negative (e.g., ``-5.0`` means tiles must be within 5 units of + the running max in the kernel's scaled score space). Takes + precedence over ``skip_softmax_threshold`` when both are set. + measure_sparsity: When True and skip-softmax is active, count total + and skipped tiles via atomic counters. 
The counts are stored as + ``_sparsity_total`` and ``_sparsity_skipped`` attributes on the + returned output tensor. Returns: Output tensor [total_q_tokens, num_q_heads, head_dim]. @@ -1059,7 +1110,271 @@ def attention( num_sink_tokens, dense_window_size, skip_softmax_threshold, + skip_softmax_raw_threshold, + measure_sparsity, + ) + + +# --------------------------------------------------------------------------- +# Calibration kernel: collect multi-threshold skip-softmax sparsity stats +# --------------------------------------------------------------------------- +@triton.jit +def _attn_fwd_calibrate( + Q, + K, + V, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + Out, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + Threshold_trials, # [NUM_THRESHOLDS] float32 — pre-scaled to log2 space + Per_program_totals, # [num_programs * NUM_THRESHOLDS] int32 — per-program tile counts + Per_program_skipped, # [num_programs * NUM_THRESHOLDS] int32 — per-program skip counts + kv_group_num: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_THRESHOLDS: tl.constexpr, + PADDED_THRESHOLDS: tl.constexpr, # next_power_of_2(NUM_THRESHOLDS) for tl.arange +): + """Forward kernel with multi-threshold sparsity measurement. + + Computes full attention (no skipping) while counting how many KV tiles + would be skipped at each threshold. Each program writes its local counts + to ``Per_program_totals`` and ``Per_program_skipped``; the Python wrapper + sums across programs afterward. This avoids global atomic contention. 
+ """ + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + tile_q = tl.program_id(2) + kv_head_idx = head_idx // kv_group_num + + seq_len_q = tl.load(b_seq_len + batch_idx) + seq_len_kv = tl.load(b_seq_len_k + batch_idx) + q_offset = tl.load(b_start_loc + batch_idx) + kv_offset = tl.load(b_start_loc_k + batch_idx) + + if tile_q * BLOCK_M >= seq_len_q: + return + + q_pos = tile_q * BLOCK_M + tl.arange(0, BLOCK_M) + kv_pos = tl.arange(0, BLOCK_N) + dim_pos = tl.arange(0, BLOCK_D) + d_mask = dim_pos < HEAD_DIM + + q_ptrs = (q_offset + q_pos[:, None]) * stride_qbs + head_idx * stride_qh + dim_pos[None, :] + q = tl.load(Q + q_ptrs, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :], other=0.0) + + k_base = K + kv_head_idx * stride_kh + v_base = V + kv_head_idx * stride_vh + + row_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + row_sum = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) + + # Pre-load all thresholds once (vectorized, stays in registers). + # tl.arange requires power-of-2 size, so use PADDED_THRESHOLDS with masking. + thresh_offs = tl.arange(0, PADDED_THRESHOLDS) + thresh_mask = thresh_offs < NUM_THRESHOLDS + thresholds = tl.load(Threshold_trials + thresh_offs, mask=thresh_mask, other=float("inf")) + + # Per-program local counters: avoid global atomic contention in inner loop. + # Each program accumulates locally, then writes once to Per_program buffers. 
+ local_skipped = tl.zeros([PADDED_THRESHOLDS], dtype=tl.int32) + num_tiles = 0 + + kv_bound = seq_len_kv if not IS_CAUSAL else tl.minimum((tile_q + 1) * BLOCK_M, seq_len_kv) + + for kv_start in range(0, kv_bound, BLOCK_N): + kv_start = tl.multiple_of(kv_start, BLOCK_N) + + k_offs = (kv_offset + kv_start + kv_pos[None, :]) * stride_kbs + dim_pos[:, None] + k = tl.load( + k_base + k_offs, + mask=((kv_start + kv_pos[None, :]) < seq_len_kv) & d_mask[:, None], + other=0.0, + ) + + scores = tl.dot(q, k) * qk_scale + scores = _apply_mask(scores, q_pos, kv_pos, seq_len_q, seq_len_kv, kv_start, IS_CAUSAL) + + tile_row_max = tl.max(scores, 1) + + # --- Vectorized multi-threshold sparsity measurement --- + # A tile is skipped iff ALL Q rows satisfy: tile_row_max < row_max + thresh. + # Equivalently: max(tile_row_max - row_max) < thresh (worst-case row + # must still be below threshold for the tile to be skippable). + max_gap = tl.max(tile_row_max - row_max) # scalar + skip_mask = (max_gap < thresholds).to(tl.int32) # [PADDED_THRESHOLDS] + local_skipped += skip_mask + num_tiles += 1 + + # --- Always compute full attention (no skipping) --- + m_new = tl.maximum(row_max, tile_row_max) + p = tl.math.exp2(scores - m_new[:, None]) + l_new = tl.sum(p, 1) + correction = tl.math.exp2(row_max - m_new) + row_sum = row_sum * correction + l_new + acc = acc * correction[:, None] + + v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :] + v = tl.load( + v_base + v_offs, + mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :], + other=0.0, + ) + acc = tl.dot(p.to(v.dtype), v, acc) + row_max = m_new + + # --- Write per-program counters (no atomics, just stores) --- + # Compute unique flat program index for this (batch, head, q_tile) + num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # NOTE(review): this is batch 0's tile count, NOT an upper bound — if another batch has a longer sequence, prog_idx collides across batches and programs overwrite each other's counter rows; tl.num_programs(2) is the safe stride. Confirm against callers. + num_heads = tl.num_programs(1) + prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q + base = prog_idx * 
NUM_THRESHOLDS + tl.store( + Per_program_totals + base + thresh_offs, + tl.full([PADDED_THRESHOLDS], num_tiles, dtype=tl.int32), + mask=thresh_mask, ) + tl.store( + Per_program_skipped + base + thresh_offs, + local_skipped, + mask=thresh_mask, + ) + + acc = acc / tl.maximum(row_sum[:, None], 1e-6) + o_ptrs = (q_offset + q_pos[:, None]) * stride_obs + head_idx * stride_oh + dim_pos[None, :] + tl.store(Out + o_ptrs, acc, mask=(q_pos[:, None] < seq_len_q) & d_mask[None, :]) + + +def attention_calibrate( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + max_input_len: int, + is_causal: bool = True, + softmax_scale: float | None = None, + b_start_loc_k: torch.Tensor | None = None, + b_seq_len_k: torch.Tensor | None = None, + max_input_len_k: int | None = None, + *, + threshold_trials: list[float] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Flash attention with multi-threshold skip-softmax sparsity measurement. + + Computes full attention (identical output to dense attention) while + measuring how many KV tiles would be skipped at each threshold in + ``threshold_trials``. No autograd — forward only. + + Args: + q, k, v, b_start_loc, b_seq_len, max_input_len, is_causal, + softmax_scale, b_start_loc_k, b_seq_len_k, max_input_len_k: + Same as :func:`attention`. + threshold_trials: List of threshold values to measure sparsity for. + Each value is converted to log2-scaled space for the kernel. + + Returns: + Tuple of (output, sparsity_counters): + - output: ``[total_q_tokens, num_q_heads, head_dim]`` + - sparsity_counters: ``[num_thresholds, 2]`` int64 tensor where + ``[:, 0]`` = total tile evaluations, ``[:, 1]`` = skipped tiles. + Sparsity per threshold = ``counters[:, 1] / counters[:, 0]``. 
+ """ + if threshold_trials is None or len(threshold_trials) == 0: + raise ValueError("threshold_trials must be a non-empty list") + + HEAD_DIM = q.shape[2] + num_q_heads = q.shape[1] + num_kv_heads = k.shape[1] + kv_group_num = num_q_heads // num_kv_heads + batch = b_seq_len.shape[0] + sm_scale = 1.0 / (HEAD_DIM**0.5) if softmax_scale is None else softmax_scale + qk_scale = sm_scale * LOG2E + BLOCK_D = triton.next_power_of_2(HEAD_DIM) + BLOCK_M = 128 + BLOCK_N = 64 + + if b_seq_len_k is None: + b_seq_len_k = b_seq_len + b_start_loc_k = b_start_loc + + num_thresholds = len(threshold_trials) + + # Convert thresholds to log2-scaled space: log2(lambda) * sm_scale + threshold_tensor = torch.tensor( + [math.log2(t) * sm_scale for t in threshold_trials], + dtype=torch.float32, + device=q.device, + ) + + o = torch.empty_like(q) + + num_q_tiles = triton.cdiv(max_input_len, BLOCK_M) + grid = (batch, num_q_heads, num_q_tiles) + num_programs = batch * num_q_heads * num_q_tiles + + # Per-program output buffers (no atomics needed — each program writes its own row) + per_program_totals = torch.zeros( + num_programs * num_thresholds, dtype=torch.int32, device=q.device + ) + per_program_skipped = torch.zeros( + num_programs * num_thresholds, dtype=torch.int32, device=q.device + ) + + _attn_fwd_calibrate[grid]( + q, + k, + v, + qk_scale, + b_start_loc, + b_seq_len, + b_start_loc_k, + b_seq_len_k, + o, + q.stride(0), + q.stride(1), + k.stride(0), + k.stride(1), + v.stride(0), + v.stride(1), + o.stride(0), + o.stride(1), + threshold_tensor, + per_program_totals, + per_program_skipped, + kv_group_num=kv_group_num, + BLOCK_M=BLOCK_M, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_CAUSAL=is_causal, + HEAD_DIM=HEAD_DIM, + NUM_THRESHOLDS=num_thresholds, + PADDED_THRESHOLDS=triton.next_power_of_2(num_thresholds), + num_warps=4, + num_stages=1, + ) + + # Reduce across programs: sum per-program counts → [num_thresholds] + totals = per_program_totals.view(num_programs, 
num_thresholds).sum(dim=0).to(torch.int64) + skipped = per_program_skipped.view(num_programs, num_thresholds).sum(dim=0).to(torch.int64) + sparsity_counters = torch.stack([totals, skipped], dim=1) # [num_thresholds, 2] + + return o, sparsity_counters -__all__ = ["attention"] +__all__ = ["attention", "attention_calibrate"] diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py index dbc4d5bc27..f63feac69e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py @@ -21,7 +21,6 @@ import torch import torch.nn as nn -from transformers import AutoTokenizer from modelopt.torch.utils import get_module_device @@ -32,8 +31,10 @@ from .ruler_dataset import RulerDatasetBuilder -def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer": +def _load_tokenizer(tokenizer_name_or_path: str): """Load tokenizer and ensure pad_token is set.""" + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) if not tokenizer.pad_token: tokenizer.pad_token = tokenizer.eos_token @@ -255,11 +256,14 @@ def calibrate_sparse_attention( print(f"Calibrating {len(sparse_modules)} sparse attention modules together...") - # Extract tokenizer and build calibration data if needed - tokenizer = _extract_tokenizer_from_model(model) + # Extract tokenizer and build calibration data only if no forward_loop is provided. + # When the user supplies their own forward_loop (e.g. for diffusion models), + # RULER dataset generation is skipped entirely. 
+ tokenizer = None calibration_data = None - if calibrate_prefill or calibrate_decode: + if forward_loop is None and (calibrate_prefill or calibrate_decode): + tokenizer = _extract_tokenizer_from_model(model) builder = RulerDatasetBuilder( samples=calib_config.samples, max_seqlen=calib_config.max_seqlen, @@ -280,14 +284,19 @@ def calibrate_sparse_attention( print("PREFILL PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: + if forward_loop is None and calibration_data is None: raise RuntimeError("calibration_data must be built before prefill") - prefill_forward_loop = forward_loop or create_calibration_forward_loop( - calibration_data, tokenizer, chunk_size=calib_config.chunk_size - ) + if forward_loop is not None: + prefill_forward_loop = forward_loop + else: + assert calibration_data is not None and tokenizer is not None + prefill_forward_loop = create_calibration_forward_loop( + calibration_data, tokenizer, chunk_size=calib_config.chunk_size + ) prefill_calibrator = DynamicThresholdCalibrator( threshold_trials=calib_config.threshold_trials, + fit_logspace=calib_config.fit_logspace, ) prefill_result = prefill_calibrator.calibrate(model, prefill_forward_loop, phase="prefill") @@ -302,14 +311,15 @@ def calibrate_sparse_attention( print("DECODE PHASE CALIBRATION") print("=" * 60) - if calibration_data is None: - raise RuntimeError("calibration_data must be built before decode") + if calibration_data is None or tokenizer is None: + raise RuntimeError("calibration_data and tokenizer must be built before decode") decode_forward_loop = create_decode_calibration_forward_loop( calibration_data, tokenizer, num_decode_tokens=calib_config.num_decode_tokens ) decode_calibrator = DynamicThresholdCalibrator( threshold_trials=calib_config.threshold_trials, + fit_logspace=calib_config.fit_logspace, ) decode_result = decode_calibrator.calibrate(model, decode_forward_loop, phase="decode") @@ -323,15 +333,20 @@ def calibrate_sparse_attention( warnings.warn("No 
calibration produced valid results") return {} - # Extract a and b for each phase + # Extract a, b, and observed sparsity range for each phase calibration_params: dict[str, dict[str, float]] = {} for phase in ["prefill", "decode"]: if phase in calibration_results: result = calibration_results[phase] - calibration_params[phase] = { + params: dict[str, float] = { "a": result["a"], "b": result["b"], } + if "min_observed_sparsity" in result: + params["min_observed_sparsity"] = result["min_observed_sparsity"] + if "max_observed_sparsity" in result: + params["max_observed_sparsity"] = result["max_observed_sparsity"] + calibration_params[phase] = params # Apply calibration params to all modules print("\n" + "=" * 60) @@ -341,7 +356,7 @@ def calibrate_sparse_attention( for phase, params in calibration_params.items(): result = calibration_results[phase] print(f" {phase}:") - print(f" Model: scale_factor = {params['a']:.6f} * exp({params['b']:.4f} * sparsity)") + print(f" Model: scale_factor = {params['a']:.6e} * exp({params['b']:.4f} * sparsity)") print(f" R-squared: {result['r_squared']:.6f}") for module_name, module in sparse_modules: diff --git a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py index 6821206937..d3ed330325 100644 --- a/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py +++ b/modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py @@ -55,12 +55,16 @@ class DynamicThresholdCalibrator: def __init__( self, threshold_trials: list[float] | None = None, + fit_logspace: bool = False, ): """Initialize dynamic threshold calibrator. Args: threshold_trials: List of thresholds to try during calibration. Should span a range that achieves sparsities from ~10% to ~95%. + fit_logspace: If True, fit the exponential model in log space + (minimizes relative error). Recommended for diffusion models + where scale_factors span many orders of magnitude. 
""" # Default threshold trials if not provided self.threshold_trials = threshold_trials or [ @@ -85,6 +89,7 @@ def __init__( 9.5e-1, 9.9e-1, ] + self.fit_logspace = fit_logspace def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dict[str, Any]: """Calibrate a and b parameters for Exponential model. @@ -167,6 +172,8 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic # Filter out extreme sparsities (must be in (10%, 90%)) # Extreme values are unreliable for fitting valid_mask = (sparsities >= 0.10) & (sparsities <= 0.90) + if self.fit_logspace: + valid_mask &= scale_factors > 0 # log requires positive values scale_factors = scale_factors[valid_mask] sparsities = sparsities[valid_mask] @@ -176,47 +183,81 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic ) return {} - # Define Exponential model: sf = a * exp(b * S) - def exponential(sparsity, a, b): - return a * np.exp(b * sparsity) + # Record observed sparsity range for feasibility checks at inference + min_observed_sparsity = float(np.min(sparsities)) + max_observed_sparsity = float(np.max(sparsities)) - # Fit the model try: - popt, pcov = curve_fit( - exponential, - sparsities, - scale_factors, - p0=[1.0, 5.0], # Initial guess - bounds=([0.0, 0.0], [np.inf, 20.0]), # Bounds for a and b - maxfev=10000, - ) - a, b = popt + if self.fit_logspace: + # Log-space fit: minimizes relative error. Recommended for + # diffusion models where scale_factors span many orders of + # magnitude (e.g. 0.06 to 57,000) — a linear-space fit would + # be dominated by the largest values. 
+ log_scale_factors = np.log(scale_factors) + + def log_exponential(sparsity, log_a, b): + return log_a + b * sparsity + + popt, pcov = curve_fit( + log_exponential, + sparsities, + log_scale_factors, + p0=[0.0, 10.0], + maxfev=10000, + ) + log_a, b = popt + a = np.exp(log_a) + + # R-squared in log space (where the fit was performed) + pred = log_exponential(sparsities, log_a, b) + ss_res = np.sum((log_scale_factors - pred) ** 2) + ss_tot = np.sum((log_scale_factors - np.mean(log_scale_factors)) ** 2) + else: + # Linear-space fit (default): minimizes absolute error. + + def exponential(sparsity, a, b): + return a * np.exp(b * sparsity) + + popt, pcov = curve_fit( + exponential, + sparsities, + scale_factors, + p0=[1.0, 5.0], + bounds=([0.0, 0.0], [np.inf, 20.0]), + maxfev=10000, + ) + a, b = popt + + pred = exponential(sparsities, a, b) + ss_res = np.sum((scale_factors - pred) ** 2) + ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2) except Exception as e: warnings.warn(f"Curve fitting failed: {e}") return {} - # Calculate R-squared and RMSE - pred_scale_factors = exponential(sparsities, a, b) - ss_res = np.sum((scale_factors - pred_scale_factors) ** 2) - ss_tot = np.sum((scale_factors - np.mean(scale_factors)) ** 2) r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 - rmse = np.sqrt(np.mean((scale_factors - pred_scale_factors) ** 2)) - print(f"\n{phase.capitalize()} Calibration Results (Exponential Model):") + fit_label = "log-space" if self.fit_logspace else "linear-space" + print(f"\n{phase.capitalize()} Calibration Results (Exponential Model, {fit_label} fit):") print(" Model: scale_factor = a * exp(b * sparsity)") - print(f" Fitted a: {a:.6f}") + print(f" Fitted a: {a:.6e}") print(f" Fitted b: {b:.4f}") print(f" R-squared: {r_squared:.6f}") - print(f" RMSE: {rmse:.2f}") + print( + f" Observed sparsity range: [{min_observed_sparsity:.1%}, {max_observed_sparsity:.1%}]" + ) print(f" Data points used: {int(np.sum(valid_mask))} / 
{len(all_data_points)}") # Show scale_factor for various target sparsities print("\nScale factors for different target sparsities:") - print(f" {'Target':<10} {'Scale Factor':<15}") - print(f" {'-' * 10} {'-' * 15}") - for target in [0.5, 0.7, 0.8, 0.9, 0.95]: + print(f" {'Target':<10} {'Scale Factor':<15} {'Note':<20}") + print(f" {'-' * 10} {'-' * 15} {'-' * 20}") + for target in [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]: sf = a * np.exp(b * target) - print(f" {target:<10.0%} {sf:<15.2f}") + note = "" + if target < min_observed_sparsity or target > max_observed_sparsity: + note = "(extrapolation)" + print(f" {target:<10.0%} {sf:<15.4f} {note:<20}") # Print calibration data summary by threshold print("\nCalibration data summary (per threshold):") @@ -239,10 +280,11 @@ def exponential(sparsity, a, b): "a": float(a), "b": float(b), "r_squared": float(r_squared), - "rmse": float(rmse), "num_data_points": int(np.sum(valid_mask)), "total_samples": len(all_data_points), "calibration_type": "exponential", + "min_observed_sparsity": min_observed_sparsity, + "max_observed_sparsity": max_observed_sparsity, } def _enable_calibration_mode(self, modules: list[nn.Module]): @@ -333,6 +375,16 @@ def _extract_calibration_stats( return aggregated_stats def _set_thresholds(self, modules: list[nn.Module], thresholds: list[float]): - """Set thresholds list on sparse attention modules.""" + """Set thresholds list on sparse attention modules. + + Supports both flash_skip_softmax (sets ``thresholds`` attribute) and + triton_skip_softmax (sets ``_threshold_trials`` attribute). 
+ """ for module in modules: - module._sparse_method_instance.thresholds = thresholds + method = module._sparse_method_instance + if hasattr(method, "_threshold_trials"): + # triton_skip_softmax: calibration uses Triton calibration kernel + method._threshold_trials = thresholds + else: + # flash_skip_softmax: calibration uses F.softmax patching + method.thresholds = thresholds diff --git a/modelopt/torch/sparsity/attention_sparsity/config.py b/modelopt/torch/sparsity/attention_sparsity/config.py index fa415b322b..eed50b87af 100644 --- a/modelopt/torch/sparsity/attention_sparsity/config.py +++ b/modelopt/torch/sparsity/attention_sparsity/config.py @@ -139,6 +139,17 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig): ), ) + skip_softmax_raw_threshold: float | None = ModeloptField( + default=None, + title="Raw skip-softmax threshold (skip_threshold_log2).", + description=( + "Raw value passed directly to the Triton kernel as skip_threshold_log2. " + "The kernel skips tiles where tile_row_max < row_max + raw_threshold. " + "Typical values are negative (e.g., -5.0). Takes precedence over " + "skip_softmax_threshold and calibration when set." + ), + ) + @field_validator("method") @classmethod def validate_method(cls, v): @@ -326,6 +337,15 @@ class CalibrationConfig(ModeloptBaseConfig): ), ) + fit_logspace: bool = ModeloptField( + default=False, + title="Fit in log space", + description=( + "If True, fit the exponential model in log space (minimizes relative error). " + "Recommended for diffusion models where scale_factors span many orders of magnitude." 
+ ), + ) + cache_dir: str | None = ModeloptField( default=None, title="Cache directory", diff --git a/modelopt/torch/sparsity/attention_sparsity/conversion.py b/modelopt/torch/sparsity/attention_sparsity/conversion.py index 6ba238e77c..cc92819850 100644 --- a/modelopt/torch/sparsity/attention_sparsity/conversion.py +++ b/modelopt/torch/sparsity/attention_sparsity/conversion.py @@ -115,6 +115,37 @@ def is_attn_sparsified(model: nn.Module) -> bool: return any(isinstance(module, SparseAttentionModule) for module in model.modules()) +def _register_diffusers_backends_if_needed(model: nn.Module) -> None: + """Register diffusers/LTX Triton attention backends if the model needs them. + + Called before plugin registration so that the backends are available + when ``SparseAttentionModule.forward()`` activates the skip-softmax context. + """ + import contextlib + + # Register the diffusers Triton backend if the model is a diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + from .kernels import register_diffusers_triton_attention + + if register_diffusers_triton_attention is not None: + register_diffusers_triton_attention() + except (ImportError, Exception): + pass + + # Patch ltx_core Attention modules if present (independent of diffusers) + try: + from .kernels import register_ltx_triton_attention + except (ImportError, RuntimeError): + return + + if register_ltx_triton_attention is not None: + with contextlib.suppress(Exception): + register_ltx_triton_attention(model) + + def convert_to_sparse_attention_model( model: ModelLikeModule, config: SparseAttentionConfig ) -> ConvertReturnType: @@ -130,6 +161,9 @@ def convert_to_sparse_attention_model( # Initialize the true module if necessary model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + # Register diffusers backends for diffusion models + _register_diffusers_backends_if_needed(model) + # Set the correct attn_implementation 
for the chosen backend _set_attn_implementation(model, config) @@ -484,6 +518,8 @@ def print_sparse_attention_summary(model: nn.Module): # Group by (method, threshold) groups: dict[tuple[str, str], int] = {} for _, module in sparse_modules: + if not module.is_enabled: + continue method = getattr(module, "_method", "unknown") threshold = _format_threshold(module.get_threshold_info()) groups[(method, threshold)] = groups.get((method, threshold), 0) + 1 diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py index dee1bc472a..0cc4a202f5 100644 --- a/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py @@ -13,12 +13,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-exports from modelopt.torch.kernels for backward compatibility.""" +"""Kernel integrations for sparse attention: Triton FA and diffusers/LTX backends.""" + +import contextlib +import threading from modelopt.torch.kernels import IS_AVAILABLE, attention, register_triton_attention +# --------------------------------------------------------------------------- +# Optional backend registrations (depend on diffusers / ltx_core) +# --------------------------------------------------------------------------- +register_diffusers_triton_attention = None +register_ltx_triton_attention = None + +# Suppress ImportError (missing package) and RuntimeError (triton without GPU driver) +with contextlib.suppress(ImportError, RuntimeError): + from .diffusers_triton_attention import register_diffusers_triton_attention + +with contextlib.suppress(ImportError, RuntimeError): + from .ltx_triton_attention import register_ltx_triton_attention + +# --------------------------------------------------------------------------- +# Thread-local flag for flash_skip_softmax's eager-attention context +# 
--------------------------------------------------------------------------- +_thread_local = threading.local() + + +def set_skip_softmax_context(active: bool) -> None: + """Set whether skip-softmax softmax patching is active (thread-local).""" + _thread_local.skip_softmax_active = active + + +def get_skip_softmax_context() -> bool: + """Return whether skip-softmax softmax patching is active.""" + return getattr(_thread_local, "skip_softmax_active", False) + + __all__ = [ "IS_AVAILABLE", "attention", + "get_skip_softmax_context", + "register_diffusers_triton_attention", + "register_ltx_triton_attention", "register_triton_attention", + "set_skip_softmax_context", ] diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py new file mode 100644 index 0000000000..2923447cf0 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Triton flash attention backend for diffusers models. + +Registers a ``modelopt_triton`` backend in diffusers' ``_AttentionBackendRegistry`` +that converts the diffusers [B, S, H, D] layout to the Triton FA kernel's varlen +[total_tokens, H, D] format. 
+ +Two modes: +- **Inference**: Calls ``attention()`` with skip-softmax tile skipping. +- **Calibration**: Calls ``attention_calibrate()`` to collect multi-threshold + sparsity statistics without skipping any tiles. +""" + +import inspect +import math +import threading + +import torch +from diffusers.models.attention_dispatch import ( + AttentionBackendName, + _AttentionBackendRegistry, + attention_backend, +) + +from modelopt.torch.kernels import attention, attention_calibrate + +_BACKEND_NAME = "modelopt_triton" +_BACKEND_REGISTERED = False + +# Thread-local storage for per-forward skip-softmax configuration. +_thread_local = threading.local() + + +def set_triton_skip_softmax_config( + threshold: float | None = None, + calibration_mode: bool = False, + threshold_trials: list[float] | None = None, + scale_factor: float | None = None, + raw_threshold: float | None = None, + measure_sparsity: bool = False, +) -> None: + """Set thread-local skip-softmax config for the next Triton attention call. + + Args: + threshold: Skip-softmax threshold for inference mode (static). + calibration_mode: If True, use the calibration kernel to collect + multi-threshold sparsity stats instead of skipping tiles. + threshold_trials: List of thresholds to measure sparsity for + (only used when calibration_mode=True). + scale_factor: Calibrated scale factor for dynamic threshold computation. + When set, the actual threshold is computed as ``scale_factor / seq_k`` + at attention call time, adapting to the actual sequence length. + raw_threshold: Raw ``skip_threshold_log2`` value passed directly to the + kernel without conversion. Takes precedence over other thresholds. + measure_sparsity: If True, count total and skipped tiles during + inference via atomic counters in the forward kernel. 
+ """ + _thread_local.skip_threshold = threshold + _thread_local.calibration_mode = calibration_mode + _thread_local.threshold_trials = threshold_trials + _thread_local.scale_factor = scale_factor + _thread_local.raw_threshold = raw_threshold + _thread_local.measure_sparsity = measure_sparsity + # Accumulated counters across all attention calls in one forward pass + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + # Accumulated runtime sparsity counters (total_tiles, skipped_tiles) + _thread_local.sparsity_total = 0 + _thread_local.sparsity_skipped = 0 + + +def clear_triton_skip_softmax_config() -> None: + """Clear thread-local skip-softmax config.""" + _thread_local.skip_threshold = None + _thread_local.calibration_mode = False + _thread_local.threshold_trials = None + _thread_local.scale_factor = None + _thread_local.raw_threshold = None + _thread_local.measure_sparsity = False + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + _thread_local.sparsity_total = 0 + _thread_local.sparsity_skipped = 0 + + +def get_calibration_counters() -> "torch.Tensor | None": + """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" + return getattr(_thread_local, "calibration_counters", None) + + +def get_calibration_seq_k() -> int | None: + """Return KV sequence length observed during calibration, or None.""" + return getattr(_thread_local, "calibration_seq_k", None) + + +def get_sparsity_counters() -> tuple[int, int]: + """Return accumulated runtime sparsity counters ``(total_tiles, skipped_tiles)``.""" + return ( + getattr(_thread_local, "sparsity_total", 0), + getattr(_thread_local, "sparsity_skipped", 0), + ) + + +# --------------------------------------------------------------------------- +# Triton attention implementation for diffusers layout +# --------------------------------------------------------------------------- + + +def _diffusers_triton_attention( + query: torch.Tensor, 
+ key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float | None = None, + enable_gqa: bool = False, +) -> torch.Tensor: + """Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``.""" + batch, seq_q, num_heads_q, head_dim = query.shape + seq_k = key.shape[1] + device = query.device + + # Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D] + q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous() + k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous() + v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous() + + # Build varlen metadata + b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q + b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32) + + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + kw: dict = { + "b_start_loc": b_start_loc_q, + "b_seq_len": b_seq_len_q, + "max_input_len": seq_q, + "is_causal": is_causal, + "softmax_scale": scale, + } + + if seq_q != seq_k: + b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k + b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32) + kw["b_start_loc_k"] = b_start_loc_k + kw["b_seq_len_k"] = b_seq_len_k + kw["max_input_len_k"] = seq_k + + # --- Calibration mode: collect multi-threshold stats --- + calib_mode = getattr(_thread_local, "calibration_mode", False) + if calib_mode: + trials = getattr(_thread_local, "threshold_trials", None) + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) + + # Accumulate counters across all attention calls in this forward pass + prev = getattr(_thread_local, "calibration_counters", None) + if prev is None: + _thread_local.calibration_counters = counters + else: + _thread_local.calibration_counters = prev + counters + + # Store actual KV sequence length 
for calibration stats + _thread_local.calibration_seq_k = seq_k + + return o.view(batch, seq_q, num_heads_q, head_dim) + + # --- Inference mode: skip-softmax with raw, dynamic, or static threshold --- + raw_thresh = getattr(_thread_local, "raw_threshold", None) + if raw_thresh is not None: + # Raw threshold: passed directly to kernel as skip_threshold_log2 + kw["skip_softmax_raw_threshold"] = raw_thresh + else: + scale_factor = getattr(_thread_local, "scale_factor", None) + if scale_factor is not None and scale_factor > 0.0: + # Dynamic threshold: adapt to actual sequence length + kw["skip_softmax_threshold"] = scale_factor / seq_k + else: + threshold = getattr(_thread_local, "skip_threshold", None) + if threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold + + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" + do_measure = getattr(_thread_local, "measure_sparsity", False) + if do_measure: + kw["measure_sparsity"] = True + o = attention(q, k, v, **kw) + + # Accumulate runtime sparsity counters from the kernel output + if do_measure and hasattr(o, "_sparsity_total"): + prev_total = getattr(_thread_local, "sparsity_total", 0) + prev_skipped = getattr(_thread_local, "sparsity_skipped", 0) + _thread_local.sparsity_total = prev_total + o._sparsity_total + _thread_local.sparsity_skipped = prev_skipped + o._sparsity_skipped + + return o.view(batch, seq_q, num_heads_q, head_dim) + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def register_diffusers_triton_attention() -> None: + """Register ``modelopt_triton`` backend in diffusers. + + Safe to call multiple times; registration happens only once. 
+ """ + global _BACKEND_REGISTERED + if _BACKEND_REGISTERED: + return + + new_member = str.__new__(AttentionBackendName, _BACKEND_NAME) + new_member._name_ = "MODELOPT_TRITON" + new_member._value_ = _BACKEND_NAME + AttentionBackendName._member_map_["MODELOPT_TRITON"] = new_member + AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member + + _AttentionBackendRegistry._backends[new_member] = _diffusers_triton_attention + _AttentionBackendRegistry._constraints[new_member] = [] + _AttentionBackendRegistry._supported_arg_names[new_member] = set( + inspect.signature(_diffusers_triton_attention).parameters.keys() + ) + + _BACKEND_REGISTERED = True + + +def get_triton_attention_backend(): + """Return a context manager that activates the modelopt_triton backend.""" + if not _BACKEND_REGISTERED: + raise RuntimeError( + "modelopt_triton backend not registered. " + "Call register_diffusers_triton_attention() first." + ) + return attention_backend(_BACKEND_NAME) diff --git a/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py new file mode 100644 index 0000000000..a68a2512a1 --- /dev/null +++ b/modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Triton flash attention wrapper for LTX-2 (ltx_core) skip-softmax sparse attention. + +Two modes: +- **Inference**: ``attention()`` with skip-softmax tile skipping. +- **Calibration**: ``attention_calibrate()`` to collect multi-threshold stats. +""" + +import math +import threading + +import torch +from ltx_core.model.transformer.attention import Attention + +from modelopt.torch.kernels import attention, attention_calibrate + +# Thread-local storage for skip-softmax configuration +_thread_local = threading.local() + + +def set_ltx_triton_context( + active: bool, + threshold: float | None = None, + calibration_mode: bool = False, + threshold_trials: list[float] | None = None, + scale_factor: float | None = None, + raw_threshold: float | None = None, + **kwargs, +) -> None: + """Set thread-local Triton config for LTX-2 attention.""" + _thread_local.active = active + _thread_local.threshold = threshold + _thread_local.calibration_mode = calibration_mode + _thread_local.threshold_trials = threshold_trials + _thread_local.scale_factor = scale_factor + _thread_local.raw_threshold = raw_threshold + if not calibration_mode: + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + + +def clear_ltx_triton_context() -> None: + """Clear thread-local Triton config.""" + _thread_local.active = False + _thread_local.threshold = None + _thread_local.calibration_mode = False + _thread_local.threshold_trials = None + _thread_local.scale_factor = None + _thread_local.raw_threshold = None + _thread_local.calibration_counters = None + _thread_local.calibration_seq_k = None + + +def _get_ltx_triton_context() -> tuple[bool, float | None, float | None]: + """Return (active, threshold, scale_factor).""" + return ( + getattr(_thread_local, "active", False), + getattr(_thread_local, "threshold", None), + getattr(_thread_local, "scale_factor", None), + ) + + +def get_calibration_counters() -> "torch.Tensor | None": + """Return accumulated calibration 
counters ``[num_thresholds, 2]`` or None.""" + return getattr(_thread_local, "calibration_counters", None) + + +def get_calibration_seq_k() -> int | None: + """Return KV sequence length observed during calibration, or None.""" + return getattr(_thread_local, "calibration_seq_k", None) + + +def _ltx_triton_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + threshold: float | None = None, +) -> torch.Tensor: + """Triton FA attention on LTX-2 layout ``[B, T, H*D]``.""" + b, seq_q, dim_total = q.shape + dim_head = dim_total // heads + seq_k = k.shape[1] + device = q.device + + q_flat = q.view(b, seq_q, heads, dim_head).reshape(b * seq_q, heads, dim_head).contiguous() + k_flat = k.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous() + v_flat = v.view(b, seq_k, heads, dim_head).reshape(b * seq_k, heads, dim_head).contiguous() + + b_start_loc_q = torch.arange(b, device=device, dtype=torch.int32) * seq_q + b_seq_len_q = torch.full((b,), seq_q, device=device, dtype=torch.int32) + + scale = 1.0 / math.sqrt(dim_head) + + kw: dict = { + "b_start_loc": b_start_loc_q, + "b_seq_len": b_seq_len_q, + "max_input_len": seq_q, + "is_causal": False, + "softmax_scale": scale, + } + + if seq_q != seq_k: + b_start_loc_k = torch.arange(b, device=device, dtype=torch.int32) * seq_k + b_seq_len_k = torch.full((b,), seq_k, device=device, dtype=torch.int32) + kw["b_start_loc_k"] = b_start_loc_k + kw["b_seq_len_k"] = b_seq_len_k + kw["max_input_len_k"] = seq_k + + # --- Calibration mode --- + calib_mode = getattr(_thread_local, "calibration_mode", False) + if calib_mode: + trials = getattr(_thread_local, "threshold_trials", None) + if trials and attention_calibrate is not None: + o, counters = attention_calibrate(q_flat, k_flat, v_flat, **kw, threshold_trials=trials) + + prev = getattr(_thread_local, "calibration_counters", None) + if prev is None: + _thread_local.calibration_counters = counters + 
else: + _thread_local.calibration_counters = prev + counters + + # Store actual KV sequence length for calibration stats + _thread_local.calibration_seq_k = seq_k + + return o.view(b, seq_q, heads * dim_head) + + # --- Inference mode: raw, dynamic, or static threshold --- + raw_thresh = getattr(_thread_local, "raw_threshold", None) + scale_factor = getattr(_thread_local, "scale_factor", None) + if raw_thresh is not None: + kw["skip_softmax_raw_threshold"] = raw_thresh + elif scale_factor is not None and scale_factor > 0.0: + kw["skip_softmax_threshold"] = scale_factor / seq_k + elif threshold is not None and threshold > 0.0: + kw["skip_softmax_threshold"] = threshold + + assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)" + o = attention(q_flat, k_flat, v_flat, **kw) + return o.view(b, seq_q, heads * dim_head) + + +class _TritonLTXAttentionWrapper: + """Wraps ltx_core attention_function for Triton dispatch.""" + + def __init__(self, original_fn): + self._original_fn = original_fn + + def __call__(self, q, k, v, heads, mask=None): + active, threshold, _scale_factor = _get_ltx_triton_context() + if active: + return _ltx_triton_attention(q, k, v, heads, mask, threshold) + return self._original_fn(q, k, v, heads, mask) + + +def register_ltx_triton_attention(model: torch.nn.Module) -> None: + """Patch all ``ltx_core.Attention`` modules for Triton dispatch.""" + for module in model.modules(): + if isinstance(module, Attention): + fn = module.attention_function + if not isinstance(fn, _TritonLTXAttentionWrapper): + module.attention_function = _TritonLTXAttentionWrapper(fn) diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py index 2501b58f65..aab399292a 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py @@ -20,6 
+20,7 @@ """ import math +from contextlib import ExitStack from typing import Any import numpy as np @@ -369,7 +370,11 @@ def get_threshold_info(self) -> dict[str, Any]: } def get_sparse_context(self, module: torch.nn.Module): - """Return a context manager that patches F.softmax with sparse masking.""" + """Return a context manager that patches F.softmax with sparse masking. + + Also registers the diffusers eager backend so that diffusion models + (which don't call F.softmax directly) route through the patched path. + """ original_softmax = F.softmax def sparse_softmax(input, dim=-1, *args, **kwargs): @@ -379,7 +384,21 @@ def sparse_softmax(input, dim=-1, *args, **kwargs): input = self.apply_sparsity(input, sparse_mask) return original_softmax(input, dim, *args, **kwargs) - return replace_function(torch.nn.functional, "softmax", sparse_softmax) + from ..kernels import set_skip_softmax_context + + stack = ExitStack() + set_skip_softmax_context(True) + stack.callback(set_skip_softmax_context, False) + + try: + from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend + + stack.enter_context(get_skip_softmax_attention_backend()) + except (ImportError, RuntimeError): + pass + + stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax)) + return stack @property def name(self) -> str: diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py index 8037146643..3cb4f9010e 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/registry.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/registry.py @@ -40,6 +40,10 @@ def __init__(self): # Video shape for VSA (T, H, W). None for non-VSA methods. 
self.video_shape: tuple[int, int, int] | None = None + def set_calibration_mode(self, enabled: bool) -> None: + """Enable or disable calibration mode (called by DynamicThresholdCalibrator).""" + self._calibration_mode = enabled + def forward_attention( self, query: torch.Tensor, diff --git a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py index 4db51e894e..1e2f3905e7 100644 --- a/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py +++ b/modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Skip-softmax method for attention via Triton kernel tile skipping.""" +"""Skip-softmax method for attention via Triton kernel tile skipping. + +Supports two modes: +- **Inference**: KV tiles with negligible scores are skipped in-kernel. +- **Calibration**: The Triton calibration kernel collects multi-threshold + sparsity statistics without skipping any tiles. 
+""" from contextlib import contextmanager +import torch + from .registry import SparseAttentionMethod, register_sparse_method @@ -39,21 +47,251 @@ def __init__(self, method_config=None): super().__init__() method_config = method_config or {} self.skip_softmax_threshold = method_config.get("skip_softmax_threshold", 0.1) + self.skip_softmax_raw_threshold: float | None = method_config.get( + "skip_softmax_raw_threshold", None + ) + # Calibration state + self._threshold_trials: list[float] | None = None + # Runtime sparsity measurement + self._measure_sparsity: bool = False + self._sparsity_total: int = 0 + self._sparsity_skipped: int = 0 @property def name(self) -> str: """Method name identifier.""" return "triton_skip_softmax" + def calculate_sparsity(self, attention_scores): + """Return a no-op mask (skip decision is made inside the Triton kernel).""" + mask = torch.ones_like(attention_scores, dtype=torch.bool) + return mask, {} + + def apply_sparsity(self, attention_scores, sparse_mask=None): + """Not supported — tile skipping is fused into the Triton kernel.""" + raise NotImplementedError( + "triton_skip_softmax applies tile skipping inside the Triton kernel. " + "Use backend='triton', not backend='pytorch'." + ) + def get_sparse_context(self, module): - """Return context manager that activates skip-softmax during forward.""" + """Return context manager that activates skip-softmax during forward. + + In calibration mode, configures the Triton backend to use the + calibration kernel which collects multi-threshold sparsity stats. + In inference mode, sets the skip threshold for tile skipping. 
+ """ + if self._calibration_mode and self._threshold_trials: + return self._triton_calibration_context(module) + return self._triton_inference_context(module) + + @contextmanager + def _triton_inference_context(self, module): + """Inference: activate skip-softmax with calibrated or fixed threshold.""" + module._apply_skip_softmax = True + + backend_kwargs: dict = {} + if self._measure_sparsity: + backend_kwargs["measure_sparsity"] = True + + # Priority: raw_threshold > scale_factor (calibrated) > static threshold + if self.skip_softmax_raw_threshold is not None: + self._set_triton_backends( + raw_threshold=self.skip_softmax_raw_threshold, **backend_kwargs + ) + else: + scale_factor = self._get_scale_factor() + if scale_factor is not None: + self._set_triton_backends(scale_factor=scale_factor, **backend_kwargs) + else: + self._set_triton_backends(threshold=self.skip_softmax_threshold, **backend_kwargs) + with self._get_diffusers_backend_context(): + try: + yield + finally: + # Collect accumulated runtime sparsity counters before clearing + if self._measure_sparsity: + self._collect_sparsity_counters() + module._apply_skip_softmax = False + self._clear_triton_backends() - @contextmanager - def _skip_softmax_context(): - module._apply_skip_softmax = True + @contextmanager + def _triton_calibration_context(self, module): + """Calibration: collect multi-threshold sparsity stats via Triton kernel.""" + module._apply_skip_softmax = True + self._set_triton_backends(calibration_mode=True, threshold_trials=self._threshold_trials) + with self._get_diffusers_backend_context(): try: yield + # After forward pass, extract counters and build stats + self._collect_calibration_stats(module) finally: module._apply_skip_softmax = False + self._clear_triton_backends() + + def _get_scale_factor(self) -> float | None: + """Compute scale_factor from calibration params, or None if uncalibrated. + + The scale_factor is sequence-length-independent. 
Backends divide by the + actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``. + """ + if self.calibration_params and self.target_sparse_ratio: + import math + import warnings + + params = self.calibration_params.get("prefill", {}) + a = params.get("a", 0) + b = params.get("b", 0) + target = self.target_sparse_ratio.get("prefill", 0.5) + if a > 0 and b > 0: + # Warn if target is outside the calibrated range + min_s = params.get("min_observed_sparsity") + max_s = params.get("max_observed_sparsity") + if min_s is not None and target < min_s: + warnings.warn( + f"Target sparsity {target:.1%} is below the minimum observed " + f"during calibration ({min_s:.1%}). The model is extrapolating " + f"and runtime sparsity will likely be higher than the target.", + stacklevel=2, + ) + elif max_s is not None and target > max_s: + warnings.warn( + f"Target sparsity {target:.1%} is above the maximum observed " + f"during calibration ({max_s:.1%}). The model is extrapolating.", + stacklevel=2, + ) + return a * math.exp(b * target) + return None + + @staticmethod + @contextmanager + def _get_diffusers_backend_context(): + """Activate the modelopt_triton diffusers backend if registered.""" + try: + from ..kernels.diffusers_triton_attention import get_triton_attention_backend + + with get_triton_attention_backend(): + yield + except (ImportError, RuntimeError): + yield + + def _set_triton_backends(self, **kwargs): + """Set config on both diffusers and LTX Triton backends.""" + try: + from ..kernels.diffusers_triton_attention import set_triton_skip_softmax_config + + set_triton_skip_softmax_config(**kwargs) + except ImportError: + pass + try: + from ..kernels.ltx_triton_attention import set_ltx_triton_context + + set_ltx_triton_context(active=True, **kwargs) + except ImportError: + pass + + def _clear_triton_backends(self): + """Clear config on both Triton backends.""" + try: + from ..kernels.diffusers_triton_attention import clear_triton_skip_softmax_config + + 
clear_triton_skip_softmax_config() + except ImportError: + pass + try: + from ..kernels.ltx_triton_attention import clear_ltx_triton_context + + clear_ltx_triton_context() + except ImportError: + pass + + def _collect_calibration_stats(self, module): + """Read Triton calibration counters and store as stats on the module.""" + counters = None + seq_k = None + + try: + from ..kernels.diffusers_triton_attention import ( + get_calibration_counters, + get_calibration_seq_k, + ) + + counters = get_calibration_counters() + seq_k = get_calibration_seq_k() + except ImportError: + pass + + if counters is None: + try: + from ..kernels.ltx_triton_attention import ( + get_calibration_counters, + get_calibration_seq_k, + ) + + counters = get_calibration_counters() + seq_k = get_calibration_seq_k() + except ImportError: + pass + + if counters is None or self._threshold_trials is None: + return + + # counters: [num_thresholds, 2] — [:, 0]=total, [:, 1]=skipped + total = counters[:, 0].float() + skipped = counters[:, 1].float() + sparsity_list = (skipped / total.clamp(min=1)).tolist() + + # Use actual KV sequence length from backend for the exponential model fit. + # The calibrator uses: scale_factor = threshold * sample_length, so this + # must be the real sequence length, not the total tile count. 
+ sample_length = seq_k if seq_k is not None else 0 + + module._last_stats = { + "sparsity": sparsity_list, + "sample_length": sample_length, + "phase": "prefill", + } + + def get_threshold_info(self) -> dict: + """Get threshold information for debugging/display.""" + scale_factor = self._get_scale_factor() + if scale_factor is not None: + return { + "type": "dynamic_calibrated", + "formula": "threshold = scale_factor / seq_k (computed at runtime)", + "scale_factor": scale_factor, + "calibration_params": self.calibration_params, + "target_sparse_ratio": self.target_sparse_ratio, + } + return { + "type": "static", + "value": self.skip_softmax_threshold, + } + + # ------------------------------------------------------------------ + # Runtime sparsity measurement + # ------------------------------------------------------------------ + + def enable_measure_sparsity(self, enabled: bool = True) -> None: + """Enable or disable runtime sparsity measurement.""" + self._measure_sparsity = enabled + + def reset_sparsity_counters(self) -> None: + """Reset accumulated sparsity counters to zero.""" + self._sparsity_total = 0 + self._sparsity_skipped = 0 + + def get_sparsity_counters(self) -> tuple[int, int]: + """Return accumulated ``(total_tiles, skipped_tiles)``.""" + return self._sparsity_total, self._sparsity_skipped + + def _collect_sparsity_counters(self) -> None: + """Read runtime sparsity counters from the backend and accumulate.""" + try: + from ..kernels.diffusers_triton_attention import get_sparsity_counters - return _skip_softmax_context() + total, skipped = get_sparsity_counters() + self._sparsity_total += total + self._sparsity_skipped += skipped + except ImportError: + pass diff --git a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py index 599832943d..d26b73f0b4 100644 --- a/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py +++ 
b/modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py @@ -16,7 +16,6 @@ """Dynamic sparse attention registration for HuggingFace models.""" import torch.nn as nn -import transformers from modelopt.torch.opt.dynamic import DynamicModule @@ -112,11 +111,22 @@ def _is_supported_model(model: nn.Module) -> bool: """ # Check for HuggingFace PreTrainedModel try: + import transformers + if isinstance(model, transformers.PreTrainedModel): return True except ImportError: pass + # Check for diffusers ModelMixin + try: + from diffusers.models.modeling_utils import ModelMixin + + if isinstance(model, ModelMixin): + return True + except ImportError: + pass + # Support any PyTorch model with attention modules return isinstance(model, nn.Module) diff --git a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py index 1eabdfe358..3b8d9e2b92 100644 --- a/modelopt/torch/sparsity/attention_sparsity/stats_manager.py +++ b/modelopt/torch/sparsity/attention_sparsity/stats_manager.py @@ -66,12 +66,13 @@ def collect(self, stats: dict): self.aggregated_stats["total_calls"] += 1 self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0) - incoming = stats["sparse_blocks"] - if "sparse_blocks" not in self.aggregated_stats: - self.aggregated_stats["sparse_blocks"] = list(incoming) - else: - for i, val in enumerate(incoming): - self.aggregated_stats["sparse_blocks"][i] += val + incoming = stats.get("sparse_blocks") + if incoming is not None: + if "sparse_blocks" not in self.aggregated_stats: + self.aggregated_stats["sparse_blocks"] = list(incoming) + else: + for i, val in enumerate(incoming): + self.aggregated_stats["sparse_blocks"][i] += val phase = stats.get("phase", "unknown") if phase in self.aggregated_stats["phase_counts"]: @@ -79,14 +80,15 @@ def collect(self, stats: dict): # In calibration mode, store per-sample stats if self.calibration_mode: - self.per_sample_stats.append( - { - "module": 
self.module_name, - "sparsity": stats.get("sparsity", 0.0), - "sample_length": stats.get("sample_length", 0), - "phase": phase, - } - ) + sample_stat = { + "module": self.module_name, + "sparsity": stats.get("sparsity", 0.0), + "sample_length": stats.get("sample_length", 0), + "phase": phase, + } + if "normalized_gaps" in stats: + sample_stat["normalized_gaps"] = stats["normalized_gaps"] + self.per_sample_stats.append(sample_stat) def get_summary(self) -> dict: """Get aggregated statistics summary. diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index c2f9d9b3d7..352cc60d79 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -32,6 +32,16 @@ except Exception: # pragma: no cover - optional diffusers models Flux2Transformer2DModel = None +try: + from diffusers.models.transformers import WanTransformer3DModel +except Exception: # pragma: no cover - optional diffusers models + WanTransformer3DModel = None + +try: + from diffusers.models.autoencoders import AutoencoderKLWan +except Exception: # pragma: no cover - optional diffusers models + AutoencoderKLWan = None + import modelopt.torch.opt as mto @@ -157,3 +167,86 @@ def df_modelopt_state_and_output_tester(model_ref, model_test): assert model_ref_state == model_test_state df_output_tester(model_ref, model_test) + + +def get_tiny_wan22_transformer(**config_kwargs): + """Create a tiny WanTransformer3DModel for testing.""" + if WanTransformer3DModel is None: + pytest.skip("WanTransformer3DModel is not available in this diffusers version.") + + kwargs = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 16, + "out_channels": 16, + "text_dim": 32, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + kwargs.update(**config_kwargs) + return 
WanTransformer3DModel(**kwargs) + + +def get_tiny_wan22_vae(**config_kwargs): + """Create a tiny AutoencoderKLWan for testing.""" + if AutoencoderKLWan is None: + pytest.skip("AutoencoderKLWan is not available in this diffusers version.") + + kwargs = { + "base_dim": 3, + "z_dim": 16, + "dim_mult": [1, 1, 1, 1], + "num_res_blocks": 1, + "temperal_downsample": [False, True, True], + } + kwargs.update(**config_kwargs) + return AutoencoderKLWan(**kwargs) + + +def create_tiny_wan22_pipeline_dir(tmp_path: Path) -> Path: + """Create and save a tiny Wan 2.2 (14B-style) pipeline to a directory. + + Uses the same tiny config as diffusers' own Wan 2.2 tests: + - Transformer: 2 heads, 12 head_dim, 2 layers (hidden_dim=24) + - VAE: base_dim=3, z_dim=16 + - Text encoder: hf-internal-testing/tiny-random-t5 (hidden_size=32) + - Dual transformer (14B style) with boundary_ratio=0.875 + + The saved directory can be loaded with ``WanPipeline.from_pretrained(path)``. + """ + from diffusers import UniPCMultistepScheduler, WanPipeline + from transformers import AutoTokenizer, T5EncoderModel + + torch.manual_seed(0) + vae = get_tiny_wan22_vae() + + torch.manual_seed(0) + transformer = get_tiny_wan22_transformer() + + torch.manual_seed(0) + transformer_2 = get_tiny_wan22_transformer() + + scheduler = UniPCMultistepScheduler( + prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0 + ) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + pipe = WanPipeline( + transformer=transformer, + transformer_2=transformer_2, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + boundary_ratio=0.875, + ) + + save_dir = tmp_path / "tiny_wan22" + pipe.save_pretrained(save_dir) + return save_dir diff --git a/tests/examples/diffusers/test_sparsity.py b/tests/examples/diffusers/test_sparsity.py new file mode 100644 index 
0000000000..d33be1df68 --- /dev/null +++ b/tests/examples/diffusers/test_sparsity.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for skip-softmax sparse attention on Wan 2.2 (examples/diffusers/sparsity/). + +Uses a tiny Wan 2.2 model (dual transformer, 2 layers, hidden_dim=24) created +from scratch. Tests run the wan22_skip_softmax.py example script in baseline, +triton-baseline, and raw-threshold modes. 
+""" + +import pytest +from _test_utils.examples.run_command import run_example_command +from _test_utils.torch.diffusers_models import create_tiny_wan22_pipeline_dir + +EXAMPLE_PATH = "diffusers/sparsity" + +# Tiny inference settings — fast but exercises all code paths +_TINY_ARGS = [ + "--num-frames", + "5", + "--height", + "16", + "--width", + "16", + "--num-steps", + "2", + "--guidance-scale", + "1.0", + "--skip-first-last", + "0", + "--negative-prompt", + "", +] + + +@pytest.fixture(scope="session") +def tiny_wan22_path(tmp_path_factory): + """Create a tiny Wan 2.2 pipeline saved to disk (session-scoped).""" + return str(create_tiny_wan22_pipeline_dir(tmp_path_factory.mktemp("tiny_wan22"))) + + +def test_wan22_baseline(tiny_wan22_path, tmp_path): + """Dense baseline — no sparsity, default diffusers attention backend.""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--baseline", + "--prompt", + "test", + "--output", + str(tmp_path / "baseline.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + +def test_wan22_triton_baseline(tiny_wan22_path, tmp_path): + """Triton kernel without skip-softmax (threshold=0, apples-to-apples).""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--triton-baseline", + "--prompt", + "test", + "--output", + str(tmp_path / "triton_baseline.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) + + +def test_wan22_raw_threshold(tiny_wan22_path, tmp_path): + """Skip-softmax with a fixed raw threshold — no calibration needed.""" + cmd = [ + "python", + "wan22_skip_softmax.py", + "--model-path", + tiny_wan22_path, + "--raw-threshold", + "-5.0", + "--report-avg-sparsity", + "--prompt", + "test", + "--output", + str(tmp_path / "raw_threshold.mp4"), + *_TINY_ARGS, + ] + run_example_command(cmd, EXAMPLE_PATH) diff --git a/tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py 
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for diffusers kernel backends and thread-local context."""

import sys
import types
from unittest.mock import MagicMock, patch

import pytest
import torch.nn as nn


def _mock_diffusers():
    """Mock diffusers.models.attention_dispatch for testing without real diffusers.

    Returns a dict suitable for ``patch.dict(sys.modules, ...)`` mapping the
    ``diffusers`` parent packages and the ``attention_dispatch`` submodule to
    fake module objects, so the code under test can import them without the
    real library installed.
    """
    m = types.ModuleType("diffusers.models.attention_dispatch")

    class FakeBackendName(str):
        # Enum-style lookup tables; presumably mirrors the internals of the
        # real AttentionBackendName enum that registration code writes into
        # — TODO(review): confirm against diffusers' attention_dispatch.
        _member_map_: dict = {}
        _value2member_map_: dict = {}

    m.AttentionBackendName = FakeBackendName

    class FakeReg:
        # Registry internals the backend-registration code populates.
        _backends: dict = {}
        _constraints: dict = {}
        _supported_arg_names: dict = {}

    m._AttentionBackendRegistry = FakeReg
    m.attention_backend = MagicMock()
    # Parent packages must also be present in sys.modules so the dotted
    # import of the submodule resolves.
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }


# ---------------------------------------------------------------------------
# Tests: thread-local skip-softmax context
# ---------------------------------------------------------------------------


class TestSkipSoftmaxContext:
    def test_default_is_false(self):
        """The skip-softmax flag defaults to False before anyone sets it."""
        from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context

        assert get_skip_softmax_context() is False

    def test_set_and_get(self):
        """Setter round-trips: getter reflects each value written."""
        from modelopt.torch.sparsity.attention_sparsity.kernels import (
            get_skip_softmax_context,
            set_skip_softmax_context,
        )

        set_skip_softmax_context(True)
        assert get_skip_softmax_context() is True
        # Restore False so later tests see the default again.
        set_skip_softmax_context(False)
        assert get_skip_softmax_context() is False


# ---------------------------------------------------------------------------
# Tests: diffusers triton attention
# ---------------------------------------------------------------------------


class TestDiffusersTritonAttention:
    @pytest.fixture(autouse=True)
    def _setup(self):
        """Import the module under test with diffusers and kernels mocked.

        The imports happen inside ``patch.dict`` so the fake modules are in
        ``sys.modules`` at import time; the bound functions are stashed on
        ``self`` for the test methods.

        NOTE(review): reconstructed from a flowed diff — confirm the ``yield``
        sits inside the ``with patch.dict(...)`` block so the mocks stay
        active while the tests run.
        """
        mocks = _mock_diffusers()
        # Fake modelopt.torch.kernels so the triton kernel wrapper imports
        # cleanly without a real triton installation.
        mk = types.ModuleType("modelopt.torch.kernels")
        mk.attention = lambda q, k, v, **kw: q  # identity stand-in for the kernel
        mk.attention_calibrate = None
        mk.IS_AVAILABLE = True
        mk.register_triton_attention = None
        mocks["modelopt.torch.kernels"] = mk

        with patch.dict(sys.modules, mocks):
            from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention import (
                _diffusers_triton_attention,
                clear_triton_skip_softmax_config,
                get_triton_attention_backend,
                register_diffusers_triton_attention,
                set_triton_skip_softmax_config,
            )

            self._fn = _diffusers_triton_attention
            self._set = set_triton_skip_softmax_config
            self._clear = clear_triton_skip_softmax_config
            self._register = register_diffusers_triton_attention
            self._get_backend = get_triton_attention_backend

            import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention as mod

            # Force a fresh registration state for every test.
            mod._BACKEND_REGISTERED = False
            yield

    def test_set_clear_config(self):
        """Setting then clearing the skip-softmax config raises nothing."""
        self._set(threshold=0.1)
        self._clear()

    def test_register_idempotent(self):
        """Registering twice is a no-op the second time, not an error."""
        self._register()
        self._register()

    def test_get_backend_before_register_raises(self):
        """Fetching the backend before registration fails loudly."""
        with pytest.raises(RuntimeError, match="not registered"):
            self._get_backend()


# ---------------------------------------------------------------------------
# Tests: conversion.py _register_diffusers_backends_if_needed
# ---------------------------------------------------------------------------


class TestRegisterDiffusersBackends:
    def test_no_diffusers_no_error(self):
        """A plain nn.Module (not a diffusers ModelMixin) is a silent no-op."""
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        _register_diffusers_backends_if_needed(nn.Linear(10, 10))

    def test_with_diffusers_model(self):
        """A ModelMixin instance triggers exactly one triton registration."""
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        # Minimal fake ModelMixin: an nn.Module subclass with the right name,
        # installed where the conversion code looks for it.
        mock_mixin = type("ModelMixin", (nn.Module,), {})
        mock_utils = types.ModuleType("diffusers.models.modeling_utils")
        mock_utils.ModelMixin = mock_mixin

        with (
            patch.dict(sys.modules, {"diffusers.models.modeling_utils": mock_utils}),
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_triton_attention",
                MagicMock(),
            ) as mock_triton,
        ):
            _register_diffusers_backends_if_needed(mock_mixin())
            mock_triton.assert_called_once()