diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index d2da82501..21c4c1895 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -107,7 +107,7 @@ jobs:
         os: [ubuntu-latest]
         arch: [x86_64]
         rocm_version:
-          ["6.1.2"]
+          ["6.1.2", "6.2"]
     runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
     steps:
       - uses: actions/checkout@v4
diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml
index 6497caf2d..707705297 100644
--- a/.github/workflows/upload_pr_documentation.yml
+++ b/.github/workflows/upload_pr_documentation.yml
@@ -6,6 +6,10 @@ on:
     types:
       - completed
 
+permissions:
+  contents: read
+  pull-requests: write # Allows posting comments on pull requests
+
 jobs:
   build:
     uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 0fcfffa07..0d865b541 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
-        # lowp_mode: lowest precision for computation
-        lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
-        state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
-            out.reshape([input_shape[0], input_shape[1] // 2]),
-            ipex_cpu.quantization.WoqWeightDtype.NF4,
-            input_shape,  # weight shape
-            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
-            None,  # zero_points
-            None,  # bias
-            None,  # g_idx
-            None,  # batch_size
-            blocksize,
-            int(lowp_mode),
-            -1,  # act_quant_mode. -1 means don't quant activation
-        )
-        state.absmax = torch.Tensor()
-        return torch.empty([1, 0], dtype=torch.uint8), state
-
     return out.unsqueeze(0), state
 
 
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index e322693b5..cc5d8deff 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -106,10 +106,6 @@ def get_native_library() -> BNBNativeLibrary:
     if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
         return CudaBNBNativeLibrary(dll)
 
-    logger.warning(
-        "The installed version of bitsandbytes was compiled without GPU support. "
-        "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
-    )
     return BNBNativeLibrary(dll)
 
 
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 2348d0791..ad424a6f4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -19,6 +19,7 @@
     INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
+    enable_ipex_fusion,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
+        if (
+            getattr(self.weight, "quant_state", None) is not None
+            and getattr(self.weight.quant_state, "op_context", None) is not None
+        ):
+            context = self.weight.quant_state.op_context
+            self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
+
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
+            if (
+                self.weight.quant_state.absmax.shape.numel() == 0
+                and getattr(self.weight.quant_state, "op_context", None) is not None
+            ):
+                self.weight.quant_state.absmax = context.get_scales().reshape(-1)
+                delattr(self.weight.quant_state, "op_context")
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
-            if getattr(self.weight.quant_state, "op_context", None) is not None:
-                context = self.weight.quant_state.op_context
-                destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
-                self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
 
     def forward(self, x: torch.Tensor):
+        # Check if ipex fusion can be used
+        if (
+            x.device.type == "cpu"
+            and not hasattr(self.weight.quant_state, "op_context")
+            and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
+            and self.weight.quant_state.quant_type == "nf4"
+        ):
+            enable_ipex_fusion(self.weight, self.weight.quant_state)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index fa9a7eb70..9e52c915d 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
+def enable_ipex_fusion(weight, quant_state):
+    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq
+
+    if _ipex_cpu_version_prereq(2, 3):
+        import intel_extension_for_pytorch as ipex
+
+        lowp_mode = ipex.quantization.WoqLowpMode.BF16
+        quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
+            weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            ipex.quantization.WoqWeightDtype.NF4,
+            quant_state.shape,  # weight shape
+            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
+            None,  # zero_points
+            None,  # bias
+            None,  # g_idx
+            None,  # batch_size
+            quant_state.blocksize,
+            int(lowp_mode),
+            -1,  # act_quant_mode. -1 means don't quant activation
+        )
+        quant_state.absmax = torch.Tensor()
+        weight.data = torch.empty([1, 0], dtype=torch.uint8)
+
+
 class QuantState:
     """container for quantization state components to work with Params4bit and similar classes"""
 
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index fdfe19ee4..a72eb1967 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -12,6 +12,8 @@
     title: 8-bit optimizers
   - local: algorithms
     title: Algorithms
+  - local: non_cuda_backends
+    title: Non-CUDA compute backends
  - local: fsdp_qlora
    title: FSDP-QLoRA
  - local: integrations
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 0e8da0cda..60419b38a 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -134,14 +134,23 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7
 
 3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.
 
-## Multi-backend preview release compilation[[multi-backend]]
+## Multi-backend[[multi-backend]]
+
+> [!TIP]
+> This functionality is currently in preview and therefore not yet production-ready!
 
 Please follow these steps to install bitsandbytes with device-specific backend support other than CUDA:
 
+### Pip install the pre-built wheel (recommended for most)
+
+WIP (will be added in the coming days)
+
+### Compilation
+
-### AMD GPU
+#### AMD GPU
 
 bitsandbytes is fully supported from ROCm 6.1 onwards (currently in alpha release).
@@ -179,7 +188,7 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
 
-### Intel CPU
+#### Intel CPU
 
 > [!TIP]
 > Intel CPU backend only supports building from source; for now, please follow the instructions below.
 
@@ -200,6 +209,8 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
 
+#### Apple Silicon
+
 WIP
 
diff --git a/docs/source/non_cuda_backends.mdx b/docs/source/non_cuda_backends.mdx
new file mode 100644
index 000000000..fca586534
--- /dev/null
+++ b/docs/source/non_cuda_backends.mdx
@@ -0,0 +1,27 @@
+# Multi-backend support (non-CUDA backends)
+
+As part of a recent refactoring effort, we will soon offer official multi-backend support. Currently, this feature is available in a preview alpha release, allowing us to gather early feedback from users to improve the functionality and identify any bugs.
+
+At present, the Intel CPU and AMD ROCm backends are considered fully functional. The Intel XPU backend has limited functionality and is less mature.
+
+Please refer to the [installation instructions](./installation#multi-backend) for details on installing the backend you intend to test (and hopefully provide feedback on).
+
+> [!Tip]
+> Apple Silicon support is planned for Q4 2024. We are actively seeking contributors to help implement this, develop a concrete plan, and create a detailed list of requirements. Due to limited resources, we rely on community contributions for this implementation effort. To discuss further, please spell out your thoughts and discuss in [this GitHub discussion](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1340) and tag `@Titus-von-Koeller` and `@matthewdouglas`. Thank you!
+
+## Alpha Release
+
+As we are currently in the alpha testing phase, bugs are expected, and performance might not meet expectations. However, this is exactly what we want to discover from **your** perspective as the end user!
+
+Please share and discuss your feedback with us here:
+
+- [Github Discussion: Multi-backend refactor: Alpha release ( AMD ROCm ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1339)
+- [Github Discussion: Multi-backend refactor: Alpha release ( Intel ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1338)
+
+Thank you for your support!
+
+## Benchmarks
+
+### Intel
+
+### AMD
diff --git a/tests/test_functional.py b/tests/test_functional.py
index a9d926b89..35187db78 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2303,6 +2303,7 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
         assert maxratio < 1.02 and maxratio > 0.98
 
 
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
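For reviewers who want to exercise the relocated IPEX path end to end, here is a minimal usage sketch (not part of the diff). It assumes the multi-backend preview build from this branch plus `intel_extension_for_pytorch` >= 2.3 are installed and uses Hugging Face `transformers` as the front end; the checkpoint name is only an example. With NF4 weights on CPU, the first forward pass hits the new check in `Linear4bit.forward()` and calls `enable_ipex_fusion()`.

```python
# Illustrative sketch only, under the assumptions stated above.
import torch
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # forward() only enables the IPEX path for NF4
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",                   # example checkpoint, assumed for illustration
    quantization_config=quant_config,
    device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

inputs = tokenizer("The CPU backend of bitsandbytes", return_tensors="pt")

# The first CPU forward pass triggers the new check in Linear4bit.forward(): for NF4
# weights whose second dimension is divisible by the blocksize, enable_ipex_fusion()
# prepacks the weight via torch.ops.ipex_prepack.weight_only_qlinear_prepack and stores
# the op_context on the quant_state (absmax and weight.data are emptied, as in utils.py).
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# After that first pass, each eligible 4-bit linear layer should carry an IPEX op_context.
layer = next(m for m in model.modules() if isinstance(m, bnb.nn.Linear4bit))
print(hasattr(layer.weight.quant_state, "op_context"))
```

Compared to the previous approach of prepacking inside `quantize_4bit_impl`, deferring fusion to the first forward call means serialization (`_save_to_state_dict`) only has to unpack an `op_context` when fusion has actually been used.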