diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index d8a4cdc61c83..af272c43a930 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -812,11 +812,11 @@ def is_flash_attn_greater_or_equal_2_10(): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") -def is_flash_attn_greater_or_equal(version): +def is_flash_attn_greater_or_equal(pkg_version): if not _is_package_available("flash_attn"): return False - return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(version) + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(pkg_version) def is_torchdistx_available():