20 changes: 12 additions & 8 deletions ggml/src/ggml-cuda/mmq.cu
@@ -6,12 +6,12 @@
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
// TODO: Q1_0/Q1_0_g128 MMQ disabled due to accuracy issues; for now commenting these to use cuBLAS fallback
// case GGML_TYPE_Q1_0:
// mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
// break;
// case GGML_TYPE_Q1_0_g128:
// mul_mat_q_case<GGML_TYPE_Q1_0_g128>(ctx, args, stream);
// break;
case GGML_TYPE_Q1_0:
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
break;
case GGML_TYPE_Q1_0_g128:
mul_mat_q_case<GGML_TYPE_Q1_0_g128>(ctx, args, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
@@ -275,8 +275,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t

switch (type) {
// TODO: Q1_0 and Q1_0_g128 MMQ implementation exists but is currently disabled due to accuracy issues
// case GGML_TYPE_Q1_0:
// case GGML_TYPE_Q1_0_g128:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q1_0_g128:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -307,6 +307,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
return false;
}

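// Q1_0 and Q1_0_g128 only provide an MMA tile loader (the DP4A path is disabled below),
// so MMQ for these types needs Turing-level mma support; otherwise the non-MMQ fallback is used.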
if ((type == GGML_TYPE_Q1_0 || type == GGML_TYPE_Q1_0_g128) && !turing_mma_available(cc)) {
return false;
}

if (turing_mma_available(cc)) {
return true;
}
163 changes: 102 additions & 61 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -174,8 +174,6 @@ static constexpr __device__ int get_mmq_y_device() {
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
#define MMQ_TILE_NE_K 32

#define MMQ_DP4A_TXS_Q1_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI1_0 + mmq_y/QI1_0, 0}
#define MMQ_DP4A_TXS_Q1_0_g128 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*4 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI1_0_g128 + mmq_y/(QI1_0_g128/4), 0}
#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
@@ -189,8 +187,6 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {

static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q1_0;
case GGML_TYPE_Q1_0_g128: return MMQ_DP4A_TXS_Q1_0_g128;
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -234,7 +230,7 @@ static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size fo
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q1_0_g128: return MMQ_MMA_TILE_X_K_Q8_0_g128;
case GGML_TYPE_Q1_0_g128: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -309,20 +305,19 @@ static constexpr __device__ int mmq_get_nwarps_device() {

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check);
NO_DEVICE_CODE;
#else
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q1_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)

constexpr int threads_per_row = MMQ_ITER_K_Q1_0 / (4 * QR1_0);
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
constexpr int threads_per_row = blocks_per_iter * QI1_0;
constexpr int nrows = warp_size / threads_per_row;
constexpr int scale_entries_per_row = blocks_per_iter * (QK1_0 / QK8_1);
const int txi = threadIdx.x % threads_per_row;
const int kbx = txi / QI1_0;

@@ -341,7 +336,6 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
// Read all 4 bytes safely to avoid alignment issues
const int qs0 = bxi->qs[0] | (bxi->qs[1] << 8) | (bxi->qs[2] << 16) | (bxi->qs[3] << 24);

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
// For MMA: unpack 1-bit values to signed bytes (-1 or +1)
// Process all 32 bits, 4 at a time
int unpacked_bytes[8];
@@ -358,34 +352,98 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
// Store unpacked values
#pragma unroll
for (int j = 0; j < 8; ++j) {
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*8 + j] = unpacked_bytes[j];
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*QI8_0 + j] = unpacked_bytes[j];
}
#else
// For DP4A: store raw bits, will unpack in vec_dot
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI1_0;
constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
constexpr int rows_per_warp = warp_size / scale_entries_per_row;
const int kbxd = threadIdx.x % scale_entries_per_row;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / scale_entries_per_row;

if (need_check) {
i = min(i, i_max);
}

const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbxd;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
#else
x_df[i*(MMQ_TILE_NE_K/QI1_0) + i/QI1_0 + kbxd] = bxi->d;
}
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0_g128(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
GGML_UNUSED_VARS(x, x_tile, kbx0, i_max, stride, mmq_y, need_check);
NO_DEVICE_CODE;
#else
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);

constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0_g128;
constexpr int threads_per_row = blocks_per_iter * QI1_0_g128;
constexpr int nrows = warp_size / threads_per_row;
constexpr int scale_entries_per_block = QK1_0_g128 / QK8_1;
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;

const int txi = threadIdx.x % threads_per_row;
const int kbx = txi / QI1_0_g128;
const int kqsx = txi % QI1_0_g128;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

if (need_check) {
i = min(i, i_max);
}

const block_q1_0_g128 * bxi = (const block_q1_0_g128 *) x + kbx0 + i*stride + kbx;
const int qs_offset = 4*kqsx;
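// Read the 4 quant bytes individually to avoid alignment issues (same as in load_tiles_q1_0)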
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);

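// Unpack the 32 one-bit weights of this word into signed bytes (-1 or +1), four per output int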
int unpacked_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (qs0 >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}

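// Destination offset within the tile row: kbx selects the block, kqsx the 8-int group inside it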
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
}
}

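// Load block scales: each block's single scale d is replicated across its QK8_1-sized sub-tiles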
const int ksx = threadIdx.x % scale_entries_per_row;
const int scale_block = ksx / scale_entries_per_block;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;

if (need_check) {
i = min(i, i_max);
}

const block_q1_0_g128 * bxi = (const block_q1_0_g128 *) x + kbx0 + i*stride + scale_block;

x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
}
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
}

template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
@@ -450,38 +508,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
}

template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q1_0_q8_1_dp4a(
static __device__ __forceinline__ void vec_dot_q1_mmq_dp4a_disabled(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q1_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;

// #pragma unroll
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q1_0_Q8_1_MMQ) {
const int k0 = k00 + k01;

#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;

// For MMQ, we have separate float scales, need to combine them into half2 format
const half2 ds8 = __floats2half2_rn(y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)], 0.0f);

sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q1_0_q8_1_impl<VDR_Q1_0_Q8_1_MMQ>
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
x_df[i*(MMQ_TILE_NE_K/QI1_0) + i/QI1_0 + k0/QI1_0], ds8);
}
}
}
// Q1_0 and Q1_0_g128 intentionally target the MMA path only on this branch.
// If DP4A support is needed later for older GPUs, it should be reintroduced and validated separately.
GGML_UNUSED_VARS(x, y, sum, k00, mmq_x, mmq_y);
NO_DEVICE_CODE;
}

template <int mmq_x, int mmq_y>
@@ -3341,7 +3373,16 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_0_q8_1_dp4a<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0_g128> {
static constexpr int vdr = VDR_Q1_0_g128_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0_g128<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
// The DP4A path is intentionally disabled; keep the MMA path as the validated route.
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q1_mmq_dp4a_disabled<mmq_x, mmq_y>;
};

template <int mmq_x, int mmq_y, bool need_check>
@@ -4199,8 +4240,8 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

// extern DECL_MMQ_CASE(GGML_TYPE_Q1_0); // disabled, cuBLAS fallback used
// extern DECL_MMQ_CASE(GGML_TYPE_Q1_0_g128); // disabled, cuBLAS fallback used
extern DECL_MMQ_CASE(GGML_TYPE_Q1_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q1_0_g128);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
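For reference, the core trick shared by both new tile loaders is expanding each stored bit into a signed byte of +1 or -1 so the existing Q8_0 MMA dot product (vec_dot_q8_0_q8_1_mma) can be reused unchanged. A minimal standalone sketch of that expansion; the helper name and host-side framing are illustrative only, not part of the patch:

#include <stdint.h>

// Expand one 32-bit word of 1-bit weights into 8 ints, each packing four signed bytes (+1/-1).
// Mirrors the unpacking done in load_tiles_q1_0 / load_tiles_q1_0_g128 above.
static inline void unpack_q1_word(uint32_t qs0, int32_t out[8]) {
    for (int j = 0; j < 8; ++j) {
        const int bits4 = (qs0 >> (4*j)) & 0x0F; // next 4 bits
        const int b0 = (bits4 & 0x01) ? 1 : -1;
        const int b1 = (bits4 & 0x02) ? 1 : -1;
        const int b2 = (bits4 & 0x04) ? 1 : -1;
        const int b3 = (bits4 & 0x08) ? 1 : -1;
        out[j] = (int32_t)(((uint32_t)(b0 & 0xFF)) |
                           ((uint32_t)(b1 & 0xFF) <<  8) |
                           ((uint32_t)(b2 & 0xFF) << 16) |
                           ((uint32_t)(b3 & 0xFF) << 24));
    }
}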
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/quantize.cu
@@ -297,6 +297,7 @@ void quantize_mmq_q8_1_cuda(
const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);

switch (mmq_get_q8_1_ds_layout(type_src0)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -32,6 +32,7 @@
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"

TYPES_MMQ = [
"GGML_TYPE_Q1_0", "GGML_TYPE_Q1_0_g128",
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_Q1_0);
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0_g128.cu
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_Q1_0_g128);
4 changes: 2 additions & 2 deletions src/llama-model-loader.cpp
@@ -31,8 +31,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0 - 1.5 bpw";
case LLAMA_FTYPE_MOSTLY_Q1_0_g128: return "Q1_0_g128 - 1.125 bpw";
case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0";
case LLAMA_FTYPE_MOSTLY_Q1_0_g128: return "Q1_0_g128";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";