From b6303fe8591f6e0473c2185e42b0736b5a306904 Mon Sep 17 00:00:00 2001
From: Santosh Bhavani
Date: Mon, 26 Jan 2026 10:47:10 -0800
Subject: [PATCH] docs: update cuDNN sliding window attention support

Update documentation to reflect that cuDNN now supports causal sliding
window attention (SWA) starting from version 9.2+.

Changes:
- Updated backend support matrix table to show cuDNN supports SWA
  (cuDNN 9.2+, causal masks only)
- Added SWA comparison between flash-attention and cuDNN in section 1.3
- Added clarifying note in cp_ag_thd_dpa_jax_deep_dive.ipynb that cuDNN
  supports SWA but not all striping patterns for context parallelism

Technical details:
- cuDNN 9.2+: Supports causal SWA with window_size=(left, 0)
- cuDNN 9.6+: Enhanced support for asymmetric windows (left, right)
- Constraints: Requires dropout=0.0 and bias_type="no_bias"
- Only works with causal mask types

Signed-off-by: Santosh Bhavani
---
 docs/examples/attention/attention.ipynb                   | 3 ++-
 docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb
index 4b2ed80497..a6ef9b6834 100644
--- a/docs/examples/attention/attention.ipynb
+++ b/docs/examples/attention/attention.ipynb
@@ -151,6 +151,7 @@
 "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
 "- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
 "- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
+ "- **Sliding window attention (SWA):** flash-attention has full SWA support for all mask types with dropout and bias. cuDNN attention supports causal SWA (cuDNN 9.2+) but requires `dropout=0.0` and `bias_type=\"no_bias\"`.\n",
 "- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
 "\n",
 "To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
@@ -389,7 +390,7 @@
 "\n",
 "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
 "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
- "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
+ "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | Yes (cuDNN 9.2+, causal masks only) | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
 "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
 "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
 "\n",
diff --git a/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb b/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
index 56bc3b13cf..e5c6d5dd2c 100644
--- a/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
+++ b/docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
@@ -28,6 +28,8 @@
 "source": [
 "### Question 1: Why choose Striped>1 ?\n",
 "\n",
+ "**Note:** cuDNN supports Sliding Window Attention (SWA) starting from version 9.2+ for causal masks. However, not all striping patterns for context parallelism are supported. This section explains why `stripe_size>1` is chosen over `stripe_size=1` for CP+THD+AG with SWA.\n",
+ "\n",
 "Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
 "\n",
 "#### I. Striped (`stripe_size=1`)\n",
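
The constraints this patch documents (causal mask only, `dropout=0.0`, `bias_type="no_bias"`, `window_size=(left, 0)`) can be exercised with a short sketch. The example below is illustrative and not part of the patch: it assumes a recent Transformer Engine build whose PyTorch `DotProductAttention` accepts a `window_size` argument, a GPU with cuDNN 9.2 or newer, and the `NVTE_FLASH_ATTN`/`NVTE_FUSED_ATTN` environment variables that attention.ipynb describes for backend selection; the tensor sizes and the 128-token window are arbitrary.

```python
# Illustrative sketch (not part of this patch): causal SWA on the cuDNN backend.
import os

# Backend-selection flags described in attention.ipynb: disable flash-attention,
# enable the cuDNN fused-attention backend.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"

import torch
import transformer_engine.pytorch as te

seq_len, batch, heads, head_dim = 2048, 2, 16, 64  # arbitrary sizes

# Per the patch: cuDNN SWA needs a causal mask, dropout=0.0, and no bias;
# window_size=(left, 0) attends to at most `left` preceding tokens.
attn = te.DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    attention_dropout=0.0,
    attn_mask_type="causal",
    window_size=(128, 0),  # assumed constructor kwarg in recent TE versions
    qkv_format="sbhd",
)

q, k, v = (
    torch.randn(seq_len, batch, heads, head_dim, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

with torch.no_grad():
    out = attn(q, k, v, core_attention_bias_type="no_bias")
print(out.shape)  # expected: [seq_len, batch, heads * head_dim]
```

If the installed cuDNN is older than 9.2, or if a non-causal mask, non-zero dropout, or a bias is requested, the backend selector should be expected to fall back to another backend or raise, consistent with the support matrix above; per the commit message, asymmetric `(left, right)` windows additionally require cuDNN 9.6+.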