Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,6 +113,8 @@ def get_native_library() -> BNBNativeLibrary:
return BNBNativeLibrary(dll)


ROCM_GPU_ARCH = get_rocm_gpu_arch()

try:
if torch.version.hip:
hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
Expand Down
26 changes: 26 additions & 0 deletions bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import dataclasses
import logging
import re
import subprocess
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -42,3 +45,26 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
cuda_version_string=(get_cuda_version_string()),
cuda_version_tuple=get_cuda_version_tuple(),
)


def get_rocm_gpu_arch() -> str:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
match = re.search(r"Name:\s+gfx(\d+)", result.stdout)
if match:
return "gfx" + match.group(1)
else:
return "unknown"
else:
return "unknown"
except Exception as e:
logger.error(f"Could not detect ROCm GPU architecture: {e}")
if torch.cuda.is_available():
logger.warning(
"""
ROCm GPU architecture detection failed despite ROCm being available.
""",
)
return "unknown"
6 changes: 5 additions & 1 deletion tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT
from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH
from tests.helpers import BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, get_blocksizes, get_test_dims, id_formatter

torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
Expand Down Expand Up @@ -2242,6 +2242,10 @@ def test_managed():
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
@pytest.mark.skipif(
HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
reason="this test is not supported on ROCm with gfx90a architecture yet",
)
def test_gemv_eye_4bit(storage_type, dtype, double_quant):
dims = 10
torch.random.manual_seed(np.random.randint(0, 412424242))
Expand Down