Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -2109,7 +2109,11 @@ __global__ void kdequant_mm_int32_fp16(
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#define WARP_SIZE warpSize
#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
Expand Down Expand Up @@ -2708,13 +2712,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];

const int warp_idx = threadIdx.x / warpSize;
const int warp_lane = threadIdx.x % warpSize;
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
const int offset_B = ldb*row_B;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int warp_lane = threadIdx.x % WARP_SIZE;
const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit/2;
float local_C = 0.0f;

Expand All @@ -2732,7 +2736,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

// A: [1, K]
// B: [M, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
{
const int inner_idx_halved = inner_idx/2;

Expand Down
8 changes: 7 additions & 1 deletion csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@

#define ERR_NOT_IMPLEMENTED 100

#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

using namespace BinSearch;
using std::cout;
using std::endl;
Expand Down Expand Up @@ -692,7 +698,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
//warpsize - 32
int num_blocks = (m+3)/4;
//warpsize - 64
if (warpSize == 64) {
if (WARP_SIZE == 64) {
num_blocks = (m+1)/2;
}

Expand Down