From 7d4854edcee2f20944466729e944d40809d4554c Mon Sep 17 00:00:00 2001 From: sstamenk Date: Fri, 25 Jul 2025 15:40:13 +0200 Subject: [PATCH] warpSize is being made non constexpr in ROCm 7.0 --- csrc/kernels.hip | 20 ++++++++++++-------- csrc/ops.hip | 8 +++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index ec3f7f025..58f6ed065 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2109,7 +2109,11 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 -#define WARP_SIZE warpSize +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { @@ -2708,13 +2712,13 @@ template __global__ void kgemm_4bit_inferenc // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] // 4 warps -> 4 loads per iter // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block - typedef hipcub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE]; - const int warp_idx = threadIdx.x / warpSize; - const int warp_lane = threadIdx.x % warpSize; - const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; - const int offset_B = ldb*row_B; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int warp_lane = threadIdx.x % WARP_SIZE; + const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx; + const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit/2; float local_C = 0.0f; @@ -2732,7 +2736,7 @@ template __global__ void kgemm_4bit_inferenc // A: [1, K] // B: [M, K] - for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit) { const int inner_idx_halved = inner_idx/2; diff --git a/csrc/ops.hip b/csrc/ops.hip index 260b74b30..b26d138e1 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -20,6 +20,12 @@ #define ERR_NOT_IMPLEMENTED 100 +#if defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif + using namespace BinSearch; using std::cout; using std::endl; @@ -692,7 +698,7 @@ template void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 - if (warpSize == 64) { + if (WARP_SIZE == 64) { num_blocks = (m+1)/2; }