def get_rocm_gpu_arch() -> str:
    """Return the ROCm GPU architecture name, or ``"unknown"``.

    On a HIP build of PyTorch (``torch.version.hip`` is set) this runs the
    ``rocminfo`` command and returns the first ``gfx*`` agent name found in
    its output, e.g. ``"gfx90a"`` or ``"gfx1100"``.

    Returns:
        The full architecture name including any trailing letters
        (``"gfx90a"``, not ``"gfx90"``), or ``"unknown"`` when PyTorch is
        not a HIP build, when no ``gfx*`` name appears in the output, or
        when detection fails for any reason.

    This function never raises: every failure is logged and reported as
    ``"unknown"``.
    """
    logger = logging.getLogger(__name__)
    try:
        if not torch.version.hip:
            return "unknown"

        result = subprocess.run(["rocminfo"], capture_output=True, text=True)
        # Architecture names may carry a letter suffix (gfx90a, gfx90c), so
        # capture letters as well as digits; a digits-only pattern would
        # truncate "gfx90a" to "gfx90" and break comparisons against the
        # full name.
        match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
        if match:
            return "gfx" + match.group(1)
        return "unknown"
    except Exception as e:
        # Deliberately broad: detection is best-effort (e.g. `rocminfo` may
        # not be installed) and must degrade to "unknown" instead of raising.
        logger.error(f"Could not detect ROCm GPU architecture: {e}")
        if torch.cuda.is_available():
            logger.warning(
                """
ROCm GPU architecture detection failed despite ROCm being available.
                """,
            )
        return "unknown"