Fix flash_attention_3 detection and import for hopper wheel installs #45387
albertorkive wants to merge 2 commits into huggingface:main from
Conversation
is_flash_attn_3_available() checked package distribution metadata for "flash-attn-3", but the hopper wheel (built from flash-attention/hopper/) doesn't register under that distribution name. Replace metadata check with actual import probing of both known module paths:

1. flash_attn_interface (standalone flash-attn-3 wheel)
2. hopper.flash_attn_interface (built from flash-attention/hopper/)

Also fix _lazy_imports() to try both paths, and correct the compatibility matrix minimum CUDA version from 8 (Ampere) to 9 (Hopper) — FA3 kernels require sm90+.
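A minimal sketch of what this import probing could look like, assuming importlib-based checks on the two module names listed above (the helper name is hypothetical, not the actual transformers code):

```python
import importlib.util


def _find_fa3_module():
    """Hypothetical helper: return the first importable FA3 interface module name, or None."""
    for name in ("flash_attn_interface", "hopper.flash_attn_interface"):
        try:
            # find_spec() checks whether the module could be imported without
            # actually importing the FA3 kernels themselves.
            if importlib.util.find_spec(name) is not None:
                return name
        except ModuleNotFoundError:
            # For the dotted name, find_spec() first imports the parent package
            # ("hopper"); this exception is raised if that package is absent.
            continue
    return None
```

Detection would then reduce to `_find_fa3_module() is not None`.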
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45387&sha=d4c3f2
cc @ArthurZucker @Cyrilvallez @vasqu but not sure if it's real or code agent hallucination |
vasqu
left a comment
This is not the most common way to install FA3; you should use python setup.py install, e.g. see the recommended way in https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release
This seems more like a broken install tbh and we shouldn't support it: the KV cache function is missing, and the official functions aren't exposed the way the docs (readme) show.
I tend to say it might be some agent tbh
# 1. "flash_attn_interface" — standalone flash-attn-3 wheel
# 2. "hopper.flash_attn_interface" — built from flash-attention/hopper/
https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release shows the usage to be
import flash_attn_interface
flash_attn_interface.flash_attn_func()
So I don't see why hopper.flash_attn_interface should be valid - that implies some weird structure inherited from pip which shouldn't be used
| "general_availability_check": is_flash_attn_3_available, | ||
| "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn_interface") is not None | ||
| and "flash-attn-3" in [pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"]], | ||
| "pkg_availability_check": is_flash_attn_3_available, # defers to actual import probing |
This would also run CUDA checks; we only want to check for metadata here.
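For contrast, a metadata-only probe (no imports, no CUDA) could look roughly like this; the distribution name comes from the removed check above, and the function name is purely illustrative:

```python
from importlib.metadata import PackageNotFoundError, version


def _fa3_metadata_present(dist_name: str = "flash-attn-3") -> bool:
    """Illustrative metadata-only check: consults installed distribution
    metadata without importing any FA3 module or touching CUDA."""
    try:
        version(dist_name)
        return True
    except PackageNotFoundError:
        return False
```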
| "pkg_availability_check": is_flash_attn_3_available, # defers to actual import probing | ||
| "supported_devices": ((is_torch_cuda_available, "cuda"),), | ||
| "cuda_min_major_version": 8, # Ampere | ||
| "cuda_min_major_version": 9, # Hopper (sm90+) — FA3 does NOT run on Ampere |
Plain wrong, and we've checked for ages that Ampere works
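For reference, a minimum-major gate like cuda_min_major_version is typically evaluated against the device's compute capability, roughly as sketched below (8 = Ampere, 9 = Hopper; this is an illustration, not the transformers implementation):

```python
import torch


def _meets_min_major(cuda_min_major_version: int = 8) -> bool:
    # get_device_capability() returns (major, minor), e.g. (8, 0) for A100
    # or (9, 0) for H100; the gate only compares the major version.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= cuda_min_major_version
```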
except ImportError:
    from hopper.flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # type: ignore[no-redef]

    flash_attn_with_kvcache = None  # hopper wheel may not expose this
Honestly, given that this weird install doesn't have the kv cache function, it implies something is seriously wrong because the existing functions are not properly exposed.
Fair, closing this. The problem I ran into was with the metadata check in is_flash_attn_3_available(): I installed FA3 correctly from the hopper wheel, but PACKAGE_DISTRIBUTION_MAPPING didn't pick it up, so detection returned False. I went too broad with the fix. I'll verify and open an issue if it still reproduces.
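For anyone hitting the same mismatch, a quick way to see which distribution name an installed wheel registers for the flash_attn_interface import package is the stdlib packages_distributions() mapping (Python 3.10+); this is just a diagnostic sketch, not part of the PR:

```python
from importlib.metadata import packages_distributions

# Maps top-level import packages to the distribution(s) that provide them,
# e.g. {"flash_attn_interface": ["flash-attn-3"]} for the standalone wheel.
dist_map = packages_distributions()
print(dist_map.get("flash_attn_interface"))  # None -> the wheel registers under a different layout
```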
What does this PR do?
Fixes attn_implementation="flash_attention_3", which is currently broken for the most common FA3 install method — the hopper wheel built from flash-attention/hopper/.

Three issues fixed:

1. is_flash_attn_3_available() returns False even when FA3 is installed. The check looks for "flash-attn-3" in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"], but the hopper wheel doesn't register under that distribution name. Fix: try the actual imports (flash_attn_interface and hopper.flash_attn_interface) instead of relying on package metadata.
2. _lazy_imports() fails at runtime even if detection passes. It only tries from flash_attn_interface import ..., which fails for the hopper wheel (which exposes hopper.flash_attn_interface). Fix: try both import paths with a fallback.
3. Compatibility matrix lists the wrong minimum CUDA version. cuda_min_major_version is set to 8 (Ampere) for FA3, but FA3 kernels require Hopper (sm90+). Fix: set to 9.

Context: FA3 support was added in #38972, but the detection relies on package distribution metadata that doesn't match how most users actually install FA3 (building from the hopper/ subdirectory of the flash-attention repo, or using pre-built hopper wheels). The flash-dispatch package (a Flash Attention ecosystem tool) solves this same problem by probing both import paths — this PR applies the same approach directly in transformers.

Files changed

- src/transformers/utils/import_utils.py — rewrite is_flash_attn_3_available() to probe actual imports
- src/transformers/modeling_flash_attention_utils.py — try both FA3 import paths in _lazy_imports(), fix compat matrix

How to reproduce the bug
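The original reproduction snippet isn't preserved here; a plausible minimal check based on the description above would be something like the following (the checkpoint name is only an example):

```python
from transformers import AutoModelForCausalLM
from transformers.utils.import_utils import is_flash_attn_3_available

# With FA3 installed from the hopper wheel, this returned False before the patch.
print(is_flash_attn_3_available())

# Requesting FA3 explicitly then fails at load time for the same reason.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # example checkpoint; any FA3-capable model works
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_3",
)
```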
After this PR, both work correctly.