From 6eed2b08cb22a8d091162ad5a14bcfee8f75a597 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Tue, 1 Jul 2025 10:28:29 +0000
Subject: [PATCH] fix triton kernel on the correct device

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/triton/ops.py | 67 +++++++++++++++++------------
 1 file changed, 39 insertions(+), 28 deletions(-)

diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index 1e2802ab5..058c2747d 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -9,6 +9,8 @@
 # from bitsandbytes.functional import get_4bit_type
 # _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
 # _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
+device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
 
 
 def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty_like(A.flatten(), dtype=torch.uint8)
 
-    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
+
     out = out.reshape(A.shape)
 
     return out, absmax.float()
@@ -35,13 +39,14 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     out = torch.empty_like(A, dtype=dtype, device=A.device)
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
     return out
 
@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
 
-    triton_kernels.dequant_int8_blockwise(
-        A,
-        code,
-        absmax,
-        out,
-        blocksize,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.dequant_int8_blockwise(
+            A,
+            code,
+            absmax,
+            out,
+            blocksize,
+        )
 
 
 def quantize_4bit(
@@ -84,9 +90,10 @@ def quantize_4bit(
     absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
     out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
 
-    triton_kernels.quantize_4bit_blockwise_triton(
-        A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels.quantize_4bit_blockwise_triton(
+            A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
+        )
     packed = out
 
     if quant_storage != torch.uint8:
@@ -119,7 +126,9 @@ def dequantize_4bit(
 
     out = torch.empty(shape, dtype=dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+
     return out
 
 
@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
 ) -> None:
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
-    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
 
 
 def gemv_4bit(
@@ -150,14 +160,15 @@ def gemv_4bit(
 
     B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
 
-    triton_kernels._dequantize_4bit_impl_passing_code(
-        B,
-        absmax,
-        blocksize,
-        code,
-        dtype=A.dtype,
-        out=B_dq_triton,
-    )
+    with torch_accelerator_module.device(A.device):
+        triton_kernels._dequantize_4bit_impl_passing_code(
+            B,
+            absmax,
+            blocksize,
+            code,
+            dtype=A.dtype,
+            out=B_dq_triton,
+        )
 
     return torch.nn.functional.linear(
         A,
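
Note on the pattern above: the two module-level lines added at the top of ops.py resolve the active accelerator backend once, and every kernel launch is then wrapped in that backend's device context so the Triton kernel runs on the device that actually holds the tensors rather than on the process-wide default device. Below is a minimal, illustrative sketch of the same pattern (not taken from the patch); it assumes PyTorch 2.6+ where torch.accelerator exists, and the helper name launch_on_tensor_device is hypothetical.

    import torch

    # Resolve the active accelerator backend ("cuda", "xpu", ...). On builds
    # without torch.accelerator (PyTorch < 2.6), fall back to CUDA.
    device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    torch_accelerator_module = getattr(torch, device_type, torch.cuda)

    def launch_on_tensor_device(kernel, t: torch.Tensor, *args, **kwargs):
        # Hypothetical helper mirroring the with-blocks in the patch: switch the
        # backend's current device to the tensor's device before launching, so
        # multi-GPU/XPU setups do not silently run the kernel on device 0.
        with torch_accelerator_module.device(t.device):
            return kernel(t, *args, **kwargs)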