From 22be6ec525364655c367b77e3197ecaa6c5f40c8 Mon Sep 17 00:00:00 2001 From: SAY-5 Date: Sun, 19 Apr 2026 22:44:33 -0700 Subject: [PATCH] utils: stop crashing with KeyError when flash_attn is importable but not in the distribution map is_flash_attn_2_available / _3 / _4 / _greater_or_equal do two checks: is_available, _ = _is_package_available("flash_attn", return_version=True) is_available = is_available and "flash-attn" in [ pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] ] Step 1 uses importlib.util.find_spec, which returns a spec if any "flash_attn" import is findable (an editable install, a namespace package, a bundled shim, or a stub module under another project). Step 2 then assumes that every findable import name also has an entry in importlib.metadata.packages_distributions(). That assumption does not hold. On Python 3.13 with ComfyUI setups (#45520), and in any environment where the import is resolvable via a non-pip source, packages_distributions() has no "flash_attn" key. Because the list comprehension is evaluated before the `in` operator, short-circuit evaluation of the outer `and` does not protect us - the KeyError fires during `transformers` import and takes down the whole process before any model is loaded. Swap the four raising subscripts for `.get(name, [])`. If the name is missing from the distribution map we simply conclude that the requested flash-attention flavour is not properly installed - which is the same answer is_flash_attn_*_available() would have returned anyway - instead of raising. The inner helper `_is_package_available` already wraps the same subscript in a try/except, so we are only making the outer call sites match that contract. Fixes #45520 --- src/transformers/utils/import_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..9ef02381e00b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -948,7 +948,7 @@ def is_flash_attn_2_available() -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) # FA4 is also distributed under "flash_attn", hence we need to check the naming here is_available = is_available and "flash-attn" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] if not is_available or not (is_torch_cuda_available() or is_torch_mlu_available()): @@ -967,7 +967,7 @@ def is_flash_attn_3_available() -> bool: is_available = _is_package_available("flash_attn_interface")[0] # Resolving and ensuring the proper name of FA3 being associated is_available = is_available and "flash-attn-3" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn_interface", []) ] return is_available and is_torch_cuda_available() @@ -979,7 +979,7 @@ def is_flash_attn_4_available() -> bool: # NOTE: FA2 seems to distribute the `cute` subdirectory even if only FA2 has been installed # -> check for the proper (normalized) distribution name is_available = is_available and "flash-attn-4" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] return is_available and is_torch_cuda_available() @@ -990,7 +990,7 @@ def is_flash_attn_greater_or_equal(library_version: str) -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) # FA4 is also distributed under "flash_attn", hence we need to check the naming here is_available = is_available and "flash-attn" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] if not is_available: