10 changes: 6 additions & 4 deletions bitsandbytes/cextension.py
@@ -99,7 +99,7 @@ def get_native_library() -> BNBNativeLibrary:
if cuda_binary_path.exists():
binary_path = cuda_binary_path
else:
logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path)
logger.warning(f"Could not find the bitsandbytes {BNB_BACKEND} binary at {cuda_binary_path}")
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
dll = ct.cdll.LoadLibrary(str(binary_path))

@@ -120,21 +120,23 @@ def get_native_library() -> BNBNativeLibrary:
hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}"
BNB_BACKEND = "ROCM"
else:
HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
BNB_HIP_VERSION_SHORT = ""
BNB_BACKEND = "CUDA"
lib = get_native_library()
except Exception as e:
lib = None
logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
if torch.cuda.is_available():
logger.warning(
"""
CUDA Setup failed despite CUDA being available. Please run the following command to get more information:
f"""
{BNB_BACKEND} Setup failed despite {BNB_BACKEND} being available. Please run the following command to get more information:

python -m bitsandbytes

Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
Inspect the output of the command and see if you can locate {BNB_BACKEND} libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
""",
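The switch above keys everything on `torch.version.hip`: a ROCm build of PyTorch exposes a HIP version string, while a CUDA build leaves it unset. A minimal standalone sketch of that detection (illustrative only, not the library's code, assuming just a working PyTorch import):

import torch

def detect_backend():
    # ROCm builds of PyTorch report a HIP version string such as "6.1.40091-...";
    # CUDA builds leave torch.version.hip as None.
    if torch.version.hip:
        hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
        return True, hip_major * 100 + hip_minor, f"{hip_major}{hip_minor}", "ROCM"
    return False, 0, "", "CUDA"

HIP_ENVIRONMENT, BNB_HIP_VERSION, BNB_HIP_VERSION_SHORT, BNB_BACKEND = detect_backend()

The log and warning strings then interpolate BNB_BACKEND instead of hard-coding "CUDA", which is all the surrounding hunks change.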
124 changes: 82 additions & 42 deletions bitsandbytes/diagnostics/cuda.py
@@ -5,7 +5,7 @@

import torch

from bitsandbytes.cextension import get_cuda_bnb_library_path
from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT, get_cuda_bnb_library_path
from bitsandbytes.consts import NONPYTORCH_DOC_URL
from bitsandbytes.cuda_specs import CUDASpecs
from bitsandbytes.diagnostics.utils import print_dedented
@@ -38,6 +38,9 @@
"nvcuda*.dll", # Windows
)

if HIP_ENVIRONMENT:
CUDA_RUNTIME_LIB_PATTERNS = ("libamdhip64.so*",)

logger = logging.getLogger(__name__)


@@ -56,8 +59,8 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
except OSError: # Assume an esoteric error trying to poke at the directory
pass
for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS:
for pth in dir.glob(lib_pattern):
if pth.is_file():
for pth in dir.rglob(lib_pattern):
if pth.is_file() and not pth.is_symlink():
yield pth
except (OSError, PermissionError):
pass
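Two behavioural changes sit in this small hunk: the search now recurses into subdirectories (rglob instead of glob), and symlinks are skipped, so a runtime installed as one real file plus versioned symlinks is reported once. Combined with the libamdhip64.so* pattern added for ROCm above, the scan amounts to roughly the following sketch; the helper name and the colon-separated path handling are assumptions, while the pattern tuples mirror the diff:

from pathlib import Path

CUDA_RUNTIME_LIB_PATTERNS = ("cudart64*.dll", "libcudart*.so*", "nvcuda*.dll")
HIP_RUNTIME_LIB_PATTERNS = ("libamdhip64.so*",)

def find_runtime_libs(path_list, hip=False):
    patterns = HIP_RUNTIME_LIB_PATTERNS if hip else CUDA_RUNTIME_LIB_PATTERNS
    for entry in path_list.split(":"):
        directory = Path(entry)
        if not directory.is_dir():
            continue
        for pattern in patterns:
            # rglob descends into subdirectories (e.g. lib/ under a ROCm root);
            # skipping symlinks avoids counting libamdhip64.so -> libamdhip64.so.6 twice.
            for candidate in directory.rglob(pattern):
                if candidate.is_file() and not candidate.is_symlink():
                    yield candidate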
@@ -105,37 +108,58 @@ def find_cudart_libraries() -> Iterator[Path]:


def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
print(
f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
)
if not HIP_ENVIRONMENT:
print(
f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
)
else:
print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")

binary_path = get_cuda_bnb_library_path(cuda_specs)
if not binary_path.exists():
print_dedented(
f"""
Library not found: {binary_path}. Maybe you need to compile it from source?
If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
for example, `make CUDA_VERSION=113`.

The CUDA version for the compile might depend on your conda install, if using conda.
Inspect CUDA version via `conda list | grep cuda`.
""",
)

cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
if cuda_major < 11:
print_dedented(
"""
WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
You will only be able to use 8-bit optimizers and quantization routines!
if not HIP_ENVIRONMENT:
print_dedented(
f"""
Library not found: {binary_path}. Maybe you need to compile it from source?
If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
for example, `make CUDA_VERSION=113`.

The CUDA version for the compile might depend on your conda install, if using conda.
Inspect CUDA version via `conda list | grep cuda`.
""",
)
)
else:
print_dedented(
f"""
Library not found: {binary_path}.
Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION
in PyTorch Settings matches your ROCM install. If not, reinstall PyTorch for your ROCm version
and rebuild bitsandbytes.
""",
)

print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
if not HIP_ENVIRONMENT:
if cuda_major < 11:
print_dedented(
"""
WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
You will only be able to use 8-bit optimizers and quantization routines!
""",
)

print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
else:
if (cuda_major, cuda_minor) < (6, 1):
print_dedented(
"""
WARNING: bitsandbytes is fully supported only from ROCm 6.1.
""",
)

# 7.5 is the minimum CC for cublaslt
if not cuda_specs.has_cublaslt:
if not cuda_specs.has_cublaslt and not HIP_ENVIRONMENT:
print_dedented(
"""
WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
@@ -152,25 +176,41 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
def print_cuda_runtime_diagnostics() -> None:
cudart_paths = list(find_cudart_libraries())
if not cudart_paths:
print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
print(f"{BNB_BACKEND} SETUP: WARNING! {BNB_BACKEND} runtime files not found in any environmental path.")
elif len(cudart_paths) > 1:
backend_version = torch.version.cuda if not HIP_ENVIRONMENT else torch.version.hip
print_dedented(
f"""
Found duplicate CUDA runtime files (see below).
Found duplicate {BNB_BACKEND} runtime files (see below).

We select the PyTorch default CUDA runtime, which is {torch.version.cuda},
but this might mismatch with the CUDA version that is needed for bitsandbytes.
To override this behavior set the `BNB_CUDA_VERSION=<version string, e.g. 122>` environmental variable.

For example, if you want to use the CUDA version 122,
BNB_CUDA_VERSION=122 python ...

OR set the environmental variable in your .bashrc:
export BNB_CUDA_VERSION=122

In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
We select the PyTorch default {BNB_BACKEND} runtime, which is {backend_version},
but this might mismatch with the {BNB_BACKEND} version that is needed for bitsandbytes.
""",
)
if not HIP_ENVIRONMENT:
print_dedented(
"""
To override this behavior set the `BNB_CUDA_VERSION=<version string, e.g. 122>` environmental variable.

For example, if you want to use the CUDA version 122,
BNB_CUDA_VERSION=122 python ...

OR set the environmental variable in your .bashrc:
export BNB_CUDA_VERSION=122

In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
""",
)
else:
print_dedented(
"""
To resolve it, install PyTorch built for the ROCm version you want to use
and set LD_LIBRARY_PATH to your ROCm install path, e.g.
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/rocm-6.1.2,
""",
)

for pth in cudart_paths:
print(f"* Found CUDA runtime at: {pth}")
print(f"* Found {BNB_BACKEND} runtime at: {pth}")
23 changes: 15 additions & 8 deletions bitsandbytes/diagnostics/main.py
@@ -3,6 +3,7 @@

import torch

from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT
from bitsandbytes.consts import PACKAGE_GITHUB_URL
from bitsandbytes.cuda_specs import get_cuda_specs
from bitsandbytes.diagnostics.cuda import (
@@ -16,12 +17,13 @@ def sanity_check():
from bitsandbytes.cextension import lib

if lib is None:
compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip"
print_dedented(
"""
f"""
Couldn't load the bitsandbytes library, likely due to missing binaries.
Please ensure bitsandbytes is properly installed.

For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`.
For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND={compute_backend} -S .`.
See the documentation for more details if needed.

Trying a simple check anyway, but this will likely fail...
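With the backend known, the source-build hint in sanity_check() now points ROCm users at the hip compute backend instead of cuda. The selection reduces to roughly this (illustrative; HIP_ENVIRONMENT is imported as in the hunk above):

from bitsandbytes.cextension import HIP_ENVIRONMENT

compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip"
build_hint = f"cmake -DCOMPUTE_BACKEND={compute_backend} -S ."
# On a ROCm build this prints: cmake -DCOMPUTE_BACKEND=hip -S .
print(build_hint)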
@@ -49,19 +51,24 @@ def main():

print_header("OTHER")
cuda_specs = get_cuda_specs()
print("CUDA specs:", cuda_specs)
if HIP_ENVIRONMENT:
rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}',"
rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
print(f"{BNB_BACKEND} specs:{rocm_specs}")
else:
print(f"{BNB_BACKEND} specs:{cuda_specs}")
if not torch.cuda.is_available():
print("Torch says CUDA is not available. Possible reasons:")
print("1. CUDA driver not installed")
print("2. CUDA not installed")
print("3. You have multiple conflicting CUDA libraries")
print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
print(f"1. {BNB_BACKEND} driver not installed")
print(f"2. {BNB_BACKEND} not installed")
print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
if cuda_specs:
print_cuda_diagnostics(cuda_specs)
print_cuda_runtime_diagnostics()
print_header("")
print_header("DEBUG INFO END")
print_header("")
print("Checking that the library is importable and CUDA is callable...")
print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")
try:
sanity_check()
print("SUCCESS!")
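Finally, main() relabels the spec dump and the availability hints per backend; under ROCm the CUDA-named fields of the spec object are simply presented as ROCm versions. A sketch of that relabelling, assuming any object with the two attributes used below (this is not the library's code):

def format_specs(cuda_specs, hip, backend):
    if hip:
        # The spec object keeps its CUDA-named attributes; only the label changes.
        return (f"{backend} specs: rocm_version_string='{cuda_specs.cuda_version_string}',"
                f" rocm_version_tuple={cuda_specs.cuda_version_tuple}")
    return f"{backend} specs: {cuda_specs}"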