From 5ee2ec1b65b2be46bf6ea2d9200b8d3b48fa5e51 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Fri, 8 Sep 2023 14:31:58 +0200
Subject: [PATCH 1/8] Add ROCm make target

ROCM_TARGET=gfx1030 make hip

Uses the define BITS_AND_BYTES_USE_ROCM to redefine CUDA functions to their ROCm equivalents.

Credit to previous ports:
Co-authored-by: broncotc
Co-authored-by: agrocylo <130291676+agrocylo@users.noreply.github.com>
---
 Makefile               | 21 +++++++++++++++-
 csrc/kernels.cu        | 25 +++++++++++++++----
 csrc/ops.cu            |  7 +++++-
 csrc/ops.cuh           | 54 ++++++++++++++++++++++++++++++++++++++++--
 include/Algo-Direct2.h | 16 ++++++-------
 5 files changed, 107 insertions(+), 16 deletions(-)

diff --git a/Makefile b/Makefile
index 5f997a122..e27b2c24e 100644
--- a/Makefile
+++ b/Makefile
@@ -6,7 +6,11 @@ GPP:= /usr/bin/g++
 ifeq ($(CUDA_HOME),)
 	CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
 endif
+ifeq ($(ROCM_HOME),)
+	ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev)
+endif
 
+ifneq ($(CUDA_HOME),)
 ifndef CUDA_VERSION
 ifneq ($(MAKECMDGOALS),clean)
 $(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
@@ -14,9 +18,17 @@
 CUDA_VERSION:=
 endif
 endif
+else ifneq ($(ROCM_HOME),)
+ifndef ROCM_TARGET
+$(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
+ROCM_TARGET:=
+endif
+endif
 
 NVCC := $(CUDA_HOME)/bin/nvcc
+HIPCC:= $(ROCM_HOME)/bin/hipcc
 
 ###########################################
 
@@ -28,7 +40,8 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
 INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib
-
+HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
+HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt, currently only gfx90a
 # NVIDIA NVCC compilation flags
 COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
 COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
@@ -115,6 +128,12 @@ cuda12x: $(BUILD_DIR) env
 cpuonly: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
 
+hip: $(BUILD_DIR)
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
+	# HCC is deprecated, but used by the hipBLASlt header. Since BLAS isn't actually used, this doesn't matter; the define is only needed so the build compiles.
+	$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so
+
 env:
 	@echo "ENVIRONMENT"
 	@echo "============================"

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 9ebe0a69e..8fffbc33b 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -4,6 +4,23 @@
 // LICENSE file in the root directory of this source tree.
 #include
+
+#ifdef BITS_AND_BYTES_USE_ROCM
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#define cub hipcub
+#define __syncwarp __syncthreads //HIP doesn't have this, so just sync all threads
+
+#else
+#include
+#include
 #include
 #include
 #include
@@ -11,18 +28,17 @@
 #include
 #include
 #include
-#include
+#endif
+
 #include
 #include
-#include
-
 #define HLF_MAX 65504
 #define TH 1024
 #define NUM 4
 #define NUM_BLOCK 4096
-
+#ifndef BITS_AND_BYTES_USE_ROCM
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
   int* address_as_i = reinterpret_cast<int*>(address);
@@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
   } while (assumed != old);
   return __int_as_float(old);
 }
+#endif
 
 __device__ float dDequantizeFP4(unsigned char val, float absmax)
 {
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 97761216c..252cc09fa 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -5,12 +5,17 @@
 
 #include
 #include
-#include
 #include
 #include
 #include
 #include
 
+#ifdef BITS_AND_BYTES_USE_ROCM
+#include
+#else
+#include
+#endif
+
 using namespace BinSearch;
 using std::cout;
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index f37b3b3af..f1cf7f9e1 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -12,16 +12,66 @@
 #include
 #include
 
+
+#ifdef BITS_AND_BYTES_USE_ROCM
+// check rocminfo | grep "Wavefront Size". Should be supported on all newer GPUs
+// dirty hack to force wavefront_size 32 so this compiles
+// RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise
+#define __AMDGCN_WAVEFRONT_SIZE 32
+
+#include
+#include
+#include
+#include //only using the header to allow redefines
+#include
+
+#define cudaPeekAtLastError hipPeekAtLastError
+#define cudaMemset hipMemset
+#define cudaMemAttachHost hipMemAttachHost
+#define cudaMemPrefetchAsync hipMemPrefetchAsync
+#define cudaMallocManaged hipMallocManaged
+#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cublasGemmEx hipblasGemmEx
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUDA_R_8I HIPBLAS_R_8I
+#define CUDA_R_32I HIPBLAS_R_32I
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasOperation_t hipblasOperation_t
+#define cublasLtMatrixLayoutCreate hipblasLtMatrixLayoutCreate
+#define cudaError_t hipError_t
+#define cudaGetErrorString hipGetErrorString
+#define cudaSuccess hipSuccess
+#define cusparseStatus_t hipsparseStatus_t
+#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
+#define cublasStatus_t hipblasStatus_t
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasHandle_t hipblasHandle_t
+#define cublasCreate_v2 hipblasCreate
+#define cusparseHandle_t hipsparseHandle_t
+#define cusparseCreate hipsparseCreate
+#define __nv_bfloat16 hip_bfloat16
+#define cublasLtHandle_t hipblasLtHandle_t
+#define cublasLtCreate hipblasLtCreate
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+
+#else
 #include
 #include
 #include
 #include
 #include
-#include
-#include
 #include
 #include
+#endif
+#include
+#include
diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h
index d5fa58d12..f9387e5cd 100644
--- a/include/Algo-Direct2.h
+++ b/include/Algo-Direct2.h
@@ -93,8 +93,8 @@ struct AlgoVecBase::val
             __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
 #endif
             IVec<SSE,float> i(u.vec);
-            IVec<SSE,float> vlem = vz < vxm;
-            IVec<SSE,float> vlep = vz < vxp;
+            IVec<SSE,float> vlem = operator< (vz,vxm);
+            IVec<SSE,float> vlep = operator< (vz,vxp);
             i = i + vlem + vlep;
             i.store(pr);
         }
@@ -123,8 +123,8 @@ struct AlgoVecBase::val
             __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
 
             IVec<SSE,double> i(b1, b0);
-            IVec<SSE,double> vlem = (vz < vxm);
-            IVec<SSE,double> vlep = (vz < vxp);
+            IVec<SSE,double> vlem = operator< (vz, vxm);
+            IVec<SSE,double> vlep = operator< (vz, vxp);
             i = i + vlem + vlep;
 
             union {
@@ -227,8 +227,8 @@ struct AlgoVecBase::val
 #endif
 
-            IVec<AVX,float> vlem = vz < vxm;
-            IVec<AVX,float> vlep = vz < vxp;
+            IVec<AVX,float> vlem = operator< (vz, vxm);
+            IVec<AVX,float> vlep = operator< (vz, vxp);
             ip = ip + vlem + vlep;
 
             ip.store(pr);
@@ -277,8 +277,8 @@ struct AlgoVecBase::val
             // FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
 
             IVec<AVX,double> i(u.vec);
-            IVec<AVX,double> vlem = vz < vxm;
-            IVec<AVX,double> vlep = vz < vxp;
+            IVec<AVX,double> vlem = operator< (vz,vxm);
+            IVec<AVX,double> vlep = operator< (vz,vxp);
             i = i + vlem + vlep;
             i.extractLo32s().store(pr);
         }

From f9e2a843a2351a53f055442de546ab7bd911cbec Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Fri, 8 Sep 2023 14:33:34 +0200
Subject: [PATCH 2/8] Add ROCm support to python library

Disables igemmlt for now and adds the path to the compiled library
libbitsandbytes_hip_nohipblaslt.
---
 bitsandbytes/autograd/_functions.py | 2 +-
 bitsandbytes/cuda_setup/main.py     | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 19f224391..5d400b0e8 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -224,7 +224,7 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    if torch.cuda.get_device_capability(device=device) < (7, 5):
+    if torch.cuda.get_device_capability(device=device) < (7, 5) or torch.version.hip:
         return False
     device_name = torch.cuda.get_device_name(device=device)
     nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index 34c035425..59931cae2 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -338,7 +338,9 @@ def evaluate_cuda_setup():
         cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
         cuda_setup.add_log_entry('='*80)
 
+    if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
+    if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None
 
     cudart_path = determine_cuda_runtime_lib_path()
     ccs = get_compute_capabilities()

From 78eecb3861ffa15c6e46726e7a92838940818500 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Fri, 8 Sep 2023 14:33:55 +0200
Subject: [PATCH 3/8] Add ROCm information to README

---
 README.md              | 15 +++++++++++++++
 compile_from_source.md | 13 +++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/README.md b/README.md
index ebf40909f..fb3850a98 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,14 @@ Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
 
 In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below.
 
 Compilation quickstart:
+
 ```bash
 git clone https://github.com/timdettmers/bitsandbytes.git
 cd bitsandbytes
+```
+For CUDA
+```bash
 # CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120}
 # make argument in {cuda110, cuda11x, cuda12x}
 # if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
@@ -33,6 +37,17 @@
 CUDA_VERSION=117 make cuda11x
 python setup.py install
 ```
 
+For ROCm
+```bash
+# Requires ROCm 5.6+
+# Check if your GPU supports Wave32 with rocminfo | grep "Wavefront Size"
+# If this outputs 64 instead of 32, this library won't work
+
+# Your ROCm target can be found with rocminfo | grep gfx
+ROCM_TARGET=gfx1030 make hip
+pip install .
+```
+
 **Using Int8 inference with HuggingFace Transformers**
 
 ```python
diff --git a/compile_from_source.md b/compile_from_source.md
index c2f97088d..cccc448f0 100644
--- a/compile_from_source.md
+++ b/compile_from_source.md
@@ -38,3 +38,16 @@ If you have problems compiling the library with these instructions from source,
 
 Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler`
 
+## Compilation with ROCm
+
+Since this library requires hipblasLt, it only supports **ROCm 5.6+**.
+It works well with these docker images:
+- [rocm/pytorch](https://hub.docker.com/r/rocm/pytorch)
+- [rocm/pytorch-nightly](https://hub.docker.com/r/rocm/pytorch-nightly).
+
+For installation do:
+```bash
+make hip ROCM_TARGET=gfx1030
+pip install .
+```
+See https://www.llvm.org/docs/AMDGPUUsage.html#processors to find your ROCM_TARGET (e.g. gfx1030 for the 6800 XT / 6900 XT), or run `rocminfo | grep gfx`.
\ No newline at end of file

From 940c52ed81cd9b3f1cfa62d66b14162aab1c54d1 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Sat, 16 Dec 2023 11:16:54 +0100
Subject: [PATCH 4/8] Makefile: Add fallback to /opt/rocm home

---
 Makefile | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/Makefile b/Makefile
index e27b2c24e..eb6268208 100644
--- a/Makefile
+++ b/Makefile
@@ -23,6 +23,9 @@ ifndef ROCM_TARGET
 $(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
 ROCM_TARGET:=
 endif
+else
+$(warning WARNING: Unable to find hipcc in path, falling back to ROCM_HOME /opt/rocm)
+ROCM_HOME:=/opt/rocm
 endif

From a485a0265c5cf81efb97d8fdfeba32425d8218c9 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Thu, 21 Dec 2023 19:20:09 +0100
Subject: [PATCH 5/8] Makefile: strip one less path component for ROCM_HOME

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index eb6268208..214d9db4d 100644
--- a/Makefile
+++ b/Makefile
@@ -7,7 +7,7 @@ ifeq ($(CUDA_HOME),)
 	CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
 endif
 ifeq ($(ROCM_HOME),)
-	ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev)
+	ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f3- | rev)
 endif
 
 ifneq ($(CUDA_HOME),)

From 32cd5e0a4f8897f5e00626bdc48c077d903c0893 Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Thu, 11 Jan 2024 13:52:12 +0100
Subject: [PATCH 6/8] Rename BITS_AND_BYTES_USE_ROCM to BNB_USE_HIP

---
 Makefile        | 6 +++---
 csrc/kernels.cu | 4 ++--
 csrc/ops.cu     | 2 +-
 csrc/ops.cuh    | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/Makefile b/Makefile
index 214d9db4d..9b8063972 100644
--- a/Makefile
+++ b/Makefile
@@ -132,10 +132,10 @@ cpuonly: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
 
 hip: $(BUILD_DIR)
-	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu
-	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/ops.cu
+	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBNB_USE_HIP $(CSRC)/kernels.cu
 	# HCC is deprecated, but used by the hipBLASlt header. Since BLAS isn't actually used, this doesn't matter; the define is only needed so the build compiles.
-	$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so
+	$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBNB_USE_HIP -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so
 
 env:
 	@echo "ENVIRONMENT"
 	@echo "============================"
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 8fffbc33b..093466b11 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -5,7 +5,7 @@
 
 #include
 
-#ifdef BITS_AND_BYTES_USE_ROCM
+#ifdef BNB_USE_HIP
 #include
 #include
 #include
@@ -38,7 +38,7 @@
 #define NUM 4
 #define NUM_BLOCK 4096
 
-#ifndef BITS_AND_BYTES_USE_ROCM
+#ifndef BNB_USE_HIP
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 __device__ float atomicMax(float* address, float val) {
   int* address_as_i = reinterpret_cast<int*>(address);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 252cc09fa..4a7c80328 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -10,7 +10,7 @@
 #include
 #include
 
-#ifdef BITS_AND_BYTES_USE_ROCM
+#ifdef BNB_USE_HIP
 #include
 #else
 #include
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index f1cf7f9e1..3584e5982 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -13,7 +13,7 @@
 #include
 
-#ifdef BITS_AND_BYTES_USE_ROCM
+#ifdef BNB_USE_HIP
 // check rocminfo | grep "Wavefront Size". Should be supported on all newer GPUs
 // dirty hack to force wavefront_size 32 so this compiles
 // RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise

From e03a8bdd77ab68be75767b047adc8f00fe3f5f7b Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Fri, 12 Jan 2024 17:55:05 +0100
Subject: [PATCH 7/8] Adjust kQuantizeBlockwise to work with WARP size 64

---
 csrc/kernels.cu | 37 ++++++++++++++++++++++---------------
 csrc/ops.cuh    |  6 +-----
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 093466b11..104aa7be4 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -740,21 +740,28 @@ template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
 {
+#ifdef BNB_USE_HIP
+  const int CUB_NUM_PER_TH=(__AMDGCN_WAVEFRONT_SIZE == 64) ? NUM_PER_TH/2 : NUM_PER_TH;
+#else
+  const int CUB_NUM_PER_TH=NUM_PER_TH;
+#endif
+  const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH;
+
   const int n_full = gridDim.x * BLOCK_SIZE;
   int valid_items = 0;
   const int base_idx = (blockIdx.x * BLOCK_SIZE);
 
-  T vals[NUM_PER_TH];
-  float rand_vals[NUM_PER_TH];
-  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+  T vals[CUB_NUM_PER_TH];
+  float rand_vals[CUB_NUM_PER_TH];
+  unsigned char qvals[DATA_NUM_PER_TH];
 
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;
 
-  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
-  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
 
   __shared__ typename LoadT::TempStorage loadt;
   __shared__ typename LoadFloat::TempStorage loadf;
 
@@ -779,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   // 2. broadcast local max
   // 3. normalize inputs and quantize
-  #pragma unroll NUM_PER_TH
-  for(int j = 0; j < NUM_PER_TH; j++)
+  #pragma unroll CUB_NUM_PER_TH
+  for(int j = 0; j < CUB_NUM_PER_TH; j++)
     local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
 
   local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -809,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   switch(DATA_TYPE)
   {
     case General8bit:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < CUB_NUM_PER_TH; j++)
       {
         if(!STOCHASTIC)
           qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
@@ -819,8 +826,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       }
       break;
     case FP4:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH/2; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < DATA_NUM_PER_TH; j++)
       {
         packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
         packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
@@ -828,8 +835,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       }
       break;
     case NF4:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH/2; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < DATA_NUM_PER_TH; j++)
       {
         packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
         packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 3584e5982..87203edae 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -14,10 +14,6 @@
 
 #ifdef BNB_USE_HIP
-// check rocminfo | grep "Wavefront Size". Should be supported on all newer GPUs
-// dirty hack to force wavefront_size 32 so this compiles
-// RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise
-#define __AMDGCN_WAVEFRONT_SIZE 32
 
 #include
 #include
 #include
 #include //only using the header to allow redefines
 #include
@@ -58,7 +54,7 @@
 #define cublasLtHandle_t hipblasLtHandle_t
 #define cublasLtCreate hipblasLtCreate
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
-#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 
 #else
 #include

From ac20c059bd61692537ea850e944f8ddb8c59609d Mon Sep 17 00:00:00 2001
From: arlo-phoenix
Date: Mon, 22 Jan 2024 17:10:45 +0100
Subject: [PATCH 8/8] Make sure DATA_NUM_PER_TH <= CUB_NUM_PER_TH

The unrolls somehow already worked correctly before, but they shouldn't have.
---
 csrc/kernels.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 104aa7be4..77447b6e0 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -745,7 +745,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
 #else
   const int CUB_NUM_PER_TH=NUM_PER_TH;
 #endif
-  const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH;
+  const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : CUB_NUM_PER_TH;
 
   const int n_full = gridDim.x * BLOCK_SIZE;
   int valid_items = 0;
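
The bookkeeping that patches 7 and 8 introduce around NUM_PER_TH, CUB_NUM_PER_TH and DATA_NUM_PER_TH can be checked off-device with a few lines of host code. The sketch below is illustrative and not part of the patch series: QuantizeBlockwiseConfig is a hypothetical helper, and the wave64 halving in cub_num_per_th mirrors what the commit messages imply rather than quoting the actual HIP branch.

```cpp
// Standalone sketch (assumptions noted above): how the per-thread element
// counts in kQuantizeBlockwise relate to each other after patches 7 and 8.
#include <cstdio>

template <int NUM_PER_TH, int DATA_TYPE, int WAVEFRONT_SIZE>
struct QuantizeBlockwiseConfig {
    // Elements each thread hands to the cub block collectives; on a 64-wide
    // wavefront the count is halved so a block still decomposes into warps.
    static constexpr int cub_num_per_th =
        (WAVEFRONT_SIZE == 64) ? NUM_PER_TH / 2 : NUM_PER_TH;
    // Quantized bytes per thread: FP4/NF4 (DATA_TYPE > 0) pack two 4-bit
    // values per byte; 8-bit output matches the loaded element count.
    static constexpr int data_num_per_th =
        (DATA_TYPE > 0) ? NUM_PER_TH / 2 : cub_num_per_th;
    // Patch 8's invariant: the output array must not outrun the loaded values.
    static_assert(data_num_per_th <= cub_num_per_th,
                  "DATA_NUM_PER_TH must be <= CUB_NUM_PER_TH");
};

int main() {
    // NUM_PER_TH=4 corresponds to e.g. block size 4096 with 1024 threads.
    using Fp4Wave32  = QuantizeBlockwiseConfig<4, /*DATA_TYPE=*/1, 32>;
    using Fp4Wave64  = QuantizeBlockwiseConfig<4, /*DATA_TYPE=*/1, 64>;
    using Int8Wave64 = QuantizeBlockwiseConfig<4, /*DATA_TYPE=*/0, 64>;
    std::printf("FP4  wave32: cub=%d data=%d\n", Fp4Wave32::cub_num_per_th,  Fp4Wave32::data_num_per_th);
    std::printf("FP4  wave64: cub=%d data=%d\n", Fp4Wave64::cub_num_per_th,  Fp4Wave64::data_num_per_th);
    std::printf("int8 wave64: cub=%d data=%d\n", Int8Wave64::cub_num_per_th, Int8Wave64::data_num_per_th);
    return 0;
}
```

Compiled as plain C++, this prints cub=4/data=2 for FP4 on wave32 and cub=2/data=2 for both cases on wave64 — the DATA_NUM_PER_TH <= CUB_NUM_PER_TH invariant that patch 8 enforces at the kernel's array declarations.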