From 60d7560a6010eeee1bab9ef66a82eb501891b74a Mon Sep 17 00:00:00 2001 From: root Date: Fri, 26 Apr 2024 20:51:29 +0000 Subject: [PATCH 1/5] adding arch detection for test_gemv_eye_4bit --- bitsandbytes/cextension.py | 11 ++++++++++- tests/test_functional.py | 3 ++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 69cf0b15f..f5924f7f9 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -19,6 +19,9 @@ import ctypes as ct import logging import os +import subprocess +import re + from pathlib import Path import torch @@ -117,8 +120,14 @@ def get_native_library() -> BNBNativeLibrary: if torch.version.hip: hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor + result = subprocess.run(['rocminfo'], capture_output=True, text=True) + match = re.search(r'Name:\s+gfx(\d+)', result.stdout) + if match: + ROCM_GPU_ARCH = "gfx" + match.group(1) + else: + ROCM_GPU_ARCH = "unknown" else: - HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0 + HIP_ENVIRONMENT, BNB_HIP_VERSION, ROCM_GPU_ARCH = False, 0, "unknown" lib = get_native_library() except Exception as e: lib = None diff --git a/tests/test_functional.py b/tests/test_functional.py index 04a898d4b..dffe724a6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,7 +11,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT +from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, get_blocksizes, get_test_dims, id_formatter torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) @@ -2242,6 +2242,7 @@ def test_managed(): @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) +@pytest.mark.skipif(HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", reason="this test is not supported on ROCm with gfx90a architecture yet") def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) From cae33c38d56d2b7a42f1b481ef60c8218374c8be Mon Sep 17 00:00:00 2001 From: root Date: Mon, 29 Apr 2024 18:57:51 +0000 Subject: [PATCH 2/5] implement get_rocm_gpu_arch --- bitsandbytes/cextension.py | 12 ++++-------- bitsandbytes/cuda_specs.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index f5924f7f9..090c6116a 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -27,7 +27,7 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch logger = logging.getLogger(__name__) @@ -116,18 +116,14 @@ def get_native_library() -> BNBNativeLibrary: return BNBNativeLibrary(dll) +ROCM_GPU_ARCH = get_rocm_gpu_arch() + try: if torch.version.hip: hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor - result = subprocess.run(['rocminfo'], capture_output=True, text=True) - match = re.search(r'Name:\s+gfx(\d+)', result.stdout) - if match: - ROCM_GPU_ARCH = "gfx" + match.group(1) - else: - ROCM_GPU_ARCH = "unknown" else: - HIP_ENVIRONMENT, BNB_HIP_VERSION, ROCM_GPU_ARCH = False, 0, "unknown" + HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0 lib = get_native_library() except Exception as e: lib = None diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 50c139317..58c43789c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,5 +1,8 @@ import dataclasses from typing import List, Optional, Tuple +import logging +import subprocess +import re import torch @@ -42,3 +45,26 @@ def get_cuda_specs() -> Optional[CUDASpecs]: cuda_version_string=(get_cuda_version_string()), cuda_version_tuple=get_cuda_version_tuple(), ) + +def get_rocm_gpu_arch() -> str: + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(['rocminfo'], capture_output=True, text=True) + match = re.search(r'Name:\s+gfx(\d+)', result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" + From da53f39fba7a6516aec228b60c7ff1199b6c510b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 30 Apr 2024 00:01:29 +0000 Subject: [PATCH 3/5] fixing lint --- bitsandbytes/archive_functional.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/archive_functional.py b/bitsandbytes/archive_functional.py index dac7430ed..b050a6018 100644 --- a/bitsandbytes/archive_functional.py +++ b/bitsandbytes/archive_functional.py @@ -170,7 +170,9 @@ def get_instance(cls): dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=torch.device("cuda", index=0)): +def get_paged(*shape, dtype=torch.float32, device=None): + if device is None: + torch.device("cuda", index=0) num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) @@ -246,8 +248,8 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): if gap == 0: return values else: - l = values.numel() // 2 - return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + l_var = values.numel() // 2 + return torch.Tensor(values[:l_var].tolist() + [0] * gap + values[l_var:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -679,7 +681,7 @@ def quantize_blockwise( def dequantize_blockwise( A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, + quant_state: Optional[Tuple[Tensor, Tensor]] = None, absmax: Tensor = None, code: Tensor = None, out: Tensor = None, @@ -857,7 +859,7 @@ def quantize_4bit( def dequantize_fp4( A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, + quant_state: Optional[Tuple[Tensor, Tensor]] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -867,7 +869,7 @@ def dequantize_fp4( def dequantize_nf4( A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, + quant_state: Optional[Tuple[Tensor, Tensor]] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -877,7 +879,7 @@ def dequantize_nf4( def dequantize_4bit( A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, + quant_state: Optional[Tuple[Tensor, Tensor]] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -979,7 +981,7 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: def dequantize( A: Tensor, - quant_state: Tuple[Tensor, Tensor] = None, + quant_state: Optional[Tuple[Tensor, Tensor]] = None, absmax: Tensor = None, code: Tensor = None, out: Tensor = None, From ae4dcec5279ca53b8dbcc624063b3f9b03b156ec Mon Sep 17 00:00:00 2001 From: root Date: Tue, 30 Apr 2024 00:22:56 +0000 Subject: [PATCH 4/5] fixing lint --- bitsandbytes/archive_functional.py | 2 +- bitsandbytes/cextension.py | 3 --- bitsandbytes/cuda_specs.py | 8 ++++---- tests/test_functional.py | 5 ++++- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/archive_functional.py b/bitsandbytes/archive_functional.py index b050a6018..53b0c3ce6 100644 --- a/bitsandbytes/archive_functional.py +++ b/bitsandbytes/archive_functional.py @@ -6,7 +6,7 @@ from functools import reduce # Required in Python 3 import itertools import operator -from typing import Tuple +from typing import Optional, Tuple import numpy as np from scipy.stats import norm diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 090c6116a..03d2cbd61 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -19,9 +19,6 @@ import ctypes as ct import logging import os -import subprocess -import re - from pathlib import Path import torch diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 58c43789c..d532b738c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,8 +1,8 @@ import dataclasses -from typing import List, Optional, Tuple import logging -import subprocess import re +import subprocess +from typing import List, Optional, Tuple import torch @@ -62,9 +62,9 @@ def get_rocm_gpu_arch() -> str: logger.error(f"Could not detect ROCm GPU architecture: {e}") if torch.cuda.is_available(): logger.warning( - """ + """ ROCm GPU architecture detection failed despite ROCm being available. - """, + """, ) return "unknown" diff --git a/tests/test_functional.py b/tests/test_functional.py index dffe724a6..b5a3fab35 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2242,7 +2242,10 @@ def test_managed(): @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) -@pytest.mark.skipif(HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", reason="this test is not supported on ROCm with gfx90a architecture yet") +@pytest.mark.skipif( + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", +) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) From 21d5ff6066389ecbafa5963a934a362091307fbd Mon Sep 17 00:00:00 2001 From: root Date: Tue, 30 Apr 2024 14:34:52 +0000 Subject: [PATCH 5/5] correct lint error --- bitsandbytes/archive_functional.py | 20 +++++++++----------- bitsandbytes/cuda_specs.py | 8 ++++---- tests/test_functional.py | 4 ++-- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/archive_functional.py b/bitsandbytes/archive_functional.py index 53b0c3ce6..dac7430ed 100644 --- a/bitsandbytes/archive_functional.py +++ b/bitsandbytes/archive_functional.py @@ -6,7 +6,7 @@ from functools import reduce # Required in Python 3 import itertools import operator -from typing import Optional, Tuple +from typing import Tuple import numpy as np from scipy.stats import norm @@ -170,9 +170,7 @@ def get_instance(cls): dtype2bytes[torch.int8] = 1 -def get_paged(*shape, dtype=torch.float32, device=None): - if device is None: - torch.device("cuda", index=0) +def get_paged(*shape, dtype=torch.float32, device=torch.device("cuda", index=0)): num_bytes = dtype2bytes[dtype] * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) @@ -248,8 +246,8 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): if gap == 0: return values else: - l_var = values.numel() // 2 - return torch.Tensor(values[:l_var].tolist() + [0] * gap + values[l_var:].tolist()) + l = values.numel() // 2 + return torch.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) def create_normal_map(offset=0.9677083, use_extra_value=True): @@ -681,7 +679,7 @@ def quantize_blockwise( def dequantize_blockwise( A: Tensor, - quant_state: Optional[Tuple[Tensor, Tensor]] = None, + quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, code: Tensor = None, out: Tensor = None, @@ -859,7 +857,7 @@ def quantize_4bit( def dequantize_fp4( A: Tensor, - quant_state: Optional[Tuple[Tensor, Tensor]] = None, + quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -869,7 +867,7 @@ def dequantize_fp4( def dequantize_nf4( A: Tensor, - quant_state: Optional[Tuple[Tensor, Tensor]] = None, + quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -879,7 +877,7 @@ def dequantize_nf4( def dequantize_4bit( A: Tensor, - quant_state: Optional[Tuple[Tensor, Tensor]] = None, + quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -981,7 +979,7 @@ def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: def dequantize( A: Tensor, - quant_state: Optional[Tuple[Tensor, Tensor]] = None, + quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, code: Tensor = None, out: Tensor = None, diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index d532b738c..e104762e3 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -46,12 +46,13 @@ def get_cuda_specs() -> Optional[CUDASpecs]: cuda_version_tuple=get_cuda_version_tuple(), ) + def get_rocm_gpu_arch() -> str: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(['rocminfo'], capture_output=True, text=True) - match = re.search(r'Name:\s+gfx(\d+)', result.stdout) + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx(\d+)", result.stdout) if match: return "gfx" + match.group(1) else: @@ -65,6 +66,5 @@ def get_rocm_gpu_arch() -> str: """ ROCm GPU architecture detection failed despite ROCm being available. """, - ) + ) return "unknown" - diff --git a/tests/test_functional.py b/tests/test_functional.py index b5a3fab35..8acd5395d 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2243,8 +2243,8 @@ def test_managed(): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.skipif( - HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", - reason="this test is not supported on ROCm with gfx90a architecture yet", + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(storage_type, dtype, double_quant): dims = 10