3 changes: 2 additions & 1 deletion docs/examples/attention/attention.ipynb
@@ -151,6 +151,7 @@
"- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n",
"- flash-attention supports KV-caching and paged attention, and cuDNN attention does not.\n",
"- flash-attention uses bottom right diagonal for `causal` mask in cross attention (see [change log](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag)), and cuDNN attention supports both top left and bottom right.\n",
"- **Sliding window attention (SWA):** flash-attention has full SWA support for all mask types with dropout and bias. cuDNN attention supports causal SWA (cuDNN 9.2+) but requires `dropout=0.0` and `bias_type=\"no_bias\"`.\n",
"- flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.\n",
"\n",
"To compare cuDNN attention and flash-attention, users can modify the `model_configs` dictionary in [benchmarks/attention/benchmark_attention.py](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/attention/benchmark_attention.py) to collect performance numbers. The script runs each entry in `model_configs` for `num_iters` times, each time with one forward pass and one backward pass. Both backends are tried, and if one backend does not have support for the specific user input, the runtimes and speedups in the final table would be 0."
@@ -389,7 +390,7 @@
"\n",
"| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n",
"| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n",
"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
"| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | Yes (cuDNN 9.2+, causal masks only) | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n",
"| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n",
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
2 changes: 2 additions & 0 deletions docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
@@ -28,6 +28,8 @@
"source": [
"### Question 1: Why choose Striped>1 ?\n",
"\n",
"**Note:** cuDNN supports Sliding Window Attention (SWA) starting from version 9.2+ for causal masks. However, not all striping patterns for context parallelism are supported. This section explains why `stripe_size>1` is chosen over `stripe_size=1` for CP+THD+AG with SWA.\n",
"\n",
"Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
"\n",
"#### I. Striped (`stripe_size=1`)\n",