diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..9ef02381e00b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -948,7 +948,7 @@ def is_flash_attn_2_available() -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) # FA4 is also distributed under "flash_attn", hence we need to check the naming here is_available = is_available and "flash-attn" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] if not is_available or not (is_torch_cuda_available() or is_torch_mlu_available()): @@ -967,7 +967,7 @@ def is_flash_attn_3_available() -> bool: is_available = _is_package_available("flash_attn_interface")[0] # Resolving and ensuring the proper name of FA3 being associated is_available = is_available and "flash-attn-3" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn_interface"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn_interface", []) ] return is_available and is_torch_cuda_available() @@ -979,7 +979,7 @@ def is_flash_attn_4_available() -> bool: # NOTE: FA2 seems to distribute the `cute` subdirectory even if only FA2 has been installed # -> check for the proper (normalized) distribution name is_available = is_available and "flash-attn-4" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] return is_available and is_torch_cuda_available() @@ -990,7 +990,7 @@ def is_flash_attn_greater_or_equal(library_version: str) -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) # FA4 is also distributed under "flash_attn", hence we need to check the naming here is_available = is_available and "flash-attn" in [ - pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + pkg.replace("_", "-") for pkg in PACKAGE_DISTRIBUTION_MAPPING.get("flash_attn", []) ] if not is_available: