diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 23b852e8f19..cf4fc5ee6df 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -3,6 +3,78 @@
 #include "quantize.cuh"
 #include "mmid.cuh"
 
+// Compaction kernel for the MMQ path: copies the leading Q5_K payload (176 bytes) of every
+// Q5_K_HIFI_RES8 block (196 bytes) into a densely packed array of block_q5_K.
+// One thread copies one block as 4-byte words (44 of the 49 words of the source block).
+// Both block sizes are multiples of 4 bytes (checked below), so block starts stay
+// uint32_t-aligned for every block index as long as src and dst themselves are 4-byte aligned.
+static_assert(sizeof(block_q5_K) % sizeof(uint32_t) == 0, "Q5_K size not a multiple of 4");
+static_assert(sizeof(block_q5_k_hifi_res8) % sizeof(uint32_t) == 0, "Q5_K_HIFI_RES8 size not a multiple of 4");
+static __global__ void ggml_cuda_compact_q5_k_hifi_res8_to_q5_k(
+    const void * __restrict__ src, void * __restrict__ dst, int64_t n_blocks) {
+    // flat 1D launch: global thread index == block index
+    const int64_t ib = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+    if (ib < n_blocks) {
+        const char * src_block = (const char *)src + ib * sizeof(block_q5_k_hifi_res8);
+        char * dst_block = (char *)dst + ib * sizeof(block_q5_K);
+        const uint32_t * in = (const uint32_t *)src_block;
+        uint32_t * out = (uint32_t *)dst_block;
+        const int n_words = (int)(sizeof(block_q5_K) / sizeof(uint32_t));
+        #pragma unroll
+        for (int w = 0; w < n_words; ++w) {
+            out[w] = in[w];
+        }
+    }
+}
+
+// Add Q5_K_HIFI_RES8 INT8 residual corrections to MMQ output using F32 activations.
+// Parallelised at the (row, block) level rather than (row, batch):
+// - 92% of threads hit the early-exit (outlier_count==0) before touching src1 or dst.
+// - The 8% of threads that do have outliers loop over all batch slots and atomicAdd
+// their contribution. Contention is negligible (~1 writer per output cell on average).
+// Kernel contract (1D launch, one thread per (weight row, Q5_K_HIFI_RES8 block) pair, no shared memory):
+//   x            - first weight block of one 2D slice; row r, block b is read at x[r*stride_row_x + b]
+//   src1         - F32 activations; element (batch, col) is read at src1[batch*stride_src1 + col]
+//   dst          - F32 output that already holds the MMQ result; element (batch, row) is updated
+//                  with atomicAdd at dst[batch*stride_dst + row]
+//   nrows_x      - number of weight rows, ncols_x - number of weight columns (ncols_x/QK_K blocks per row)
+//   ncols_dst    - number of batch slots to correct
+// The grid must provide at least nrows_x * (ncols_x / QK_K) threads; surplus threads return immediately.
+static __global__ void ggml_cuda_add_q5_k_hifi_res8_residuals(
+    const block_q5_k_hifi_res8 * __restrict__ x,
+    const float * __restrict__ src1, float * __restrict__ dst,
+    int64_t nrows_x, int64_t ncols_x, int64_t ncols_dst,
+    int64_t stride_row_x, int64_t stride_src1, int64_t stride_dst) {
+
+    const int64_t n_blocks = ncols_x / QK_K;
+    const int64_t rb = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+    if (rb >= nrows_x * n_blocks) return;
+
+    const int64_t row = rb / n_blocks;
+    const int64_t b = rb % n_blocks;
+
+    const block_q5_k_hifi_res8 * block = x + row * stride_row_x + b;
+    const int n_out = (block->outlier_count & 0x7F);
+    if (n_out == 0) return; // fast path: ~92% of blocks exit here
+
+    const uint8_t e4m3 = block->residual_scale_e4m3;
+    if (e4m3 == 0) return;
+
+    // Decode E4M3 FP8 residual scale once, in registers
+    // NOTE(review): an exponent field of 0 is decoded with the implicit leading 1 (i.e. not as an
+    // E4M3 subnormal) - confirm this matches the encoder used when the blocks are quantized.
+    const int sign = (e4m3 >> 7) & 0x01;
+    const int exp = (e4m3 >> 3) & 0x0F;
+    const int mantissa = e4m3 & 0x07;
+    const float res_scale = (1.0f + (float)mantissa * 0.125f)
+        * exp2f((float)exp - 7.0f)
+        * (sign ? -1.0f : 1.0f)
+        * (1.0f / 127.0f);
+
+    // Cache per-outlier column indices and scaled residual values in registers
+    // so the inner batch loop only reads src1 (no repeated block struct accesses).
+    const int n_valid = (n_out < Q5_K_HIFI_RES8_MAX_OUTLIERS) ? n_out : Q5_K_HIFI_RES8_MAX_OUTLIERS;
+    int cols [Q5_K_HIFI_RES8_MAX_OUTLIERS];
+    float rvals[Q5_K_HIFI_RES8_MAX_OUTLIERS];
+    for (int k = 0; k < n_valid; ++k) {
+        cols [k] = (int)b * QK_K + block->outlier_idx[k];
+        rvals[k] = res_scale * (float)block->residual_vals[k];
+    }
+
+    // Accumulate residual dot-products over all batch slots and atomicAdd to dst.
+    // Low contention: at most ~1.3 enhanced blocks per row on average.
+    for (int64_t batch = 0; batch < ncols_dst; ++batch) {
+        float sum = 0.0f;
+        for (int k = 0; k < n_valid; ++k) {
+            sum += rvals[k] * src1[batch * stride_src1 + cols[k]];
+        }
+        atomicAdd(&dst[batch * stride_dst + row], sum);
+    }
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
@@ -147,6 +219,33 @@ void ggml_cuda_mul_mat_q(
             ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
 
+        if (src0->type == GGML_TYPE_Q5_K_HIFI_RES8) {
+            // Step 1: strip the residual extension so the stock Q5_K MMQ kernels can be used.
+            // The MMQ call below still receives ne02/ne03 and the src0 strides s01/s02/s03, so the
+            // compact copy must contain every channel/sample slice of src0, not only the first one.
+            // NOTE(review): the flat copy and the reuse of the src0 strides assume src0 is densely
+            // packed (contiguous) - confirm against the callers.
+            const int64_t n_blocks_per_row = ne00 / QK_K;
+            const int64_t n_blocks = n_blocks_per_row * ne01 * ne02 * ne03;
+            ggml_cuda_pool_alloc q5_k_compact(ctx.pool(), n_blocks * sizeof(block_q5_K));
+            const int nth = 256;
+            ggml_cuda_compact_q5_k_hifi_res8_to_q5_k<<<(n_blocks + nth - 1) / nth, nth, 0, stream>>>
+                (src0_d, q5_k_compact.get(), n_blocks);
+            CUDA_CHECK(cudaGetLastError());
+            // Step 2: regular Q5_K MMQ on the compact copy (same stream, so it runs after the copy).
+            const mmq_args args_q5 = {
+                q5_k_compact.get(), GGML_TYPE_Q5_K, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
+                ne00, ne01, ne1, s01, ne11, s1,
+                ne02, ne12, s02, s12, s2,
+                ne03, ne13, s03, s13, s3,
+                use_stream_k, ne1};
+            ggml_cuda_mul_mat_q_switch_type(ctx, args_q5, stream);
+            // Step 3: add the INT8 residual corrections on top of the MMQ result.
+            // The residual kernel handles one 2D slice, so it is launched once per (i12, i13) slice
+            // of src1/dst; src0 is broadcast over channels/samples the same way MMQ does it.
+            const int64_t stride_src1 = src1->nb[1] / (int64_t)sizeof(float);
+            const int64_t stride_dst = dst->nb[1] / (int64_t)sizeof(float);
+            // Launch one thread per (weight-row, block) pair.
+            // ~92% of threads exit immediately (no outliers); only ~8% touch src1/dst.
+            const int64_t n_rb = ne01 * n_blocks_per_row;
+            const int64_t channel_ratio = ne12 / ne02;
+            const int64_t sample_ratio = ne13 / ne03;
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    // tensor nb[] strides are in bytes, so slice offsets are applied on char pointers
+                    const char * x_slice = (const char *)src0_d
+                        + (i12 / channel_ratio) * src0->nb[2] + (i13 / sample_ratio) * src0->nb[3];
+                    const char * y_slice = (const char *)src1_d + i12 * src1->nb[2] + i13 * src1->nb[3];
+                    char * d_slice = (char *)dst_d + i12 * dst->nb[2] + i13 * dst->nb[3];
+                    ggml_cuda_add_q5_k_hifi_res8_residuals<<<(n_rb + 255) / 256, 256, 0, stream>>>
+                        ((const block_q5_k_hifi_res8 *)x_slice, (const float *)y_slice, (float *)d_slice,
+                         ne01, ne00, ne1, s01, stride_src1, stride_dst);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+            }
+            return;
+        }
+
         const mmq_args args = {
             src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
             ne00, ne01, ne1, s01, ne11, s1,
@@ -278,6 +377,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
         // Q3_K_HIFI excluded - uses MMVQ/dequant path instead
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q5_K_HIFI_RES8: // Use Q5_K MMQ path (compact copy + residual kernel)
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index a382e6a6979..efe9e03459c 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -74,6 +74,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q5_K_HIFI_RES8: // uses Q5_K MMQ kernel after compact copy
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_IQ2_XXS:
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 9f77d7adfcc..cda71256341 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -144,11 +144,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_met
     return res;
 }
 
+static const char * ggml_metal_type_name_for_kernel(ggml_type type); // forward declaration
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
+    // Use
ggml_metal_type_name_for_kernel for HIFI types so the kernel name matches
+    // the dedicated kernels registered in ggml-metal.metal (e.g. "q5_K_hifi_res8")
+    snprintf(base, 256, "kernel_get_rows_%s", ggml_metal_type_name_for_kernel(tsrc));
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -532,9 +536,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
-// Map HIFI types to their base types for kernel name generation
-// Since HIFI types are based on Q6_K/Q5_K, they can use the same kernels
-// Q3_K_HIFI has its own dedicated kernel, so it needs its own name
+// Map HIFI types to the type suffix used in Metal kernel names ("kernel_<op>_<suffix>...").
+// Q3_K_HIFI, Q4_K_HIFI, Q5_K_HIFI_RES8 and Q6_K_HIFI_RES8 have dedicated kernels with correct block strides.
+// Q6_K_HIFI and Q6_K_HIFI_DYNAMIC still reuse the Q6_K kernels (TODO: fix stride mismatch for these types).
+// The returned suffix must match the host_name strings in ggml-metal.metal exactly (case-sensitive).
 static const char * ggml_metal_type_name_for_kernel(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q3_K_HIFI:
@@ -543,10 +547,11 @@ static const char * ggml_metal_type_name_for_kernel(ggml_type type) {
             return "q4_k_hifi";
         case GGML_TYPE_Q6_K_HIFI:
         case GGML_TYPE_Q6_K_HIFI_DYNAMIC:
-        case GGML_TYPE_Q6_K_HIFI_RES8:
             return "q6_K";
+        case GGML_TYPE_Q6_K_HIFI_RES8:
+            return "q6_K_hifi_res8";
         case GGML_TYPE_Q5_K_HIFI_RES8:
-            return "q5_K";
+            return "q5_K_hifi_res8";
         default:
             return ggml_type_name(type);
     }
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 4072aa53525..93c50b5ecac 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -677,6 +677,14 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
     }
 }
 
+// Q6_K_HIFI_RES8: Q6_K layout + 22-byte INT8 residual extension (232 bytes total)
+// The base Q6_K fields (ql, qh, scales, d) are at identical byte offsets.
+// Residual corrections are not applied in the Metal path (only in CPU path).
+template
+void dequantize_q6_k_hifi_res8(device const block_q6_k_hifi_res8 * xb, short il, thread type4x4 & reg) {
+    // reinterpret the block as plain Q6_K and decode only the base quantization
+    dequantize_q6_K((device const block_q6_K *)xb, il, reg);
+}
+
 template
 void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -948,6 +956,14 @@ void dequantize_q4_k_hifi(device const block_q4_k_hifi * xb, short il, thread ty
     }
 }
 
+// Q5_K_HIFI_RES8: Q5_K layout + 20-byte INT8 residual extension (196 bytes total)
+// The base Q5_K fields (d, dmin, scales, qh, qs) are at identical byte offsets.
+// Residual corrections are not applied in the Metal path: this wrapper only decodes the Q5_K base,
+// so Metal results omit the outlier corrections that other backends add.
+template
+void dequantize_q5_k_hifi_res8(device const block_q5_k_hifi_res8 * xb, short il, thread type4x4 & reg) {
+    // reinterpret the block as plain Q5_K and decode only the base quantization
+    dequantize_q5_K((device const block_q5_K *)xb, il, reg);
+}
+
 enum ggml_sort_order {
     GGML_SORT_ORDER_ASC,
     GGML_SORT_ORDER_DESC,
@@ -3810,11 +3826,25 @@ template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
 template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
 template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
 
+// mul_mv_ext variants for Q5_K_HIFI_RES8: same dispatcher, instantiated on the 196-byte block type
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_k_hifi_res8, 256, dequantize_q5_k_hifi_res8>) mul_mv_ext_q5_K_hifi_res8_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_hifi_res8_f32_r1_2")]] kernel mul_mv_ext_q5_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_k_hifi_res8, 256, dequantize_q5_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q5_K_hifi_res8_f32_r1_3")]] kernel mul_mv_ext_q5_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_k_hifi_res8, 256, dequantize_q5_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q5_K_hifi_res8_f32_r1_4")]] kernel mul_mv_ext_q5_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_k_hifi_res8, 256, dequantize_q5_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q5_K_hifi_res8_f32_r1_5")]] kernel mul_mv_ext_q5_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_k_hifi_res8, 256, dequantize_q5_k_hifi_res8>;
+
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
 
+// mul_mv_ext variants for Q6_K_HIFI_RES8: same dispatcher, instantiated on the 232-byte block type
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_k_hifi_res8, 256, dequantize_q6_k_hifi_res8>) mul_mv_ext_q6_K_hifi_res8_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_hifi_res8_f32_r1_2")]] kernel mul_mv_ext_q6_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_k_hifi_res8, 256, dequantize_q6_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q6_K_hifi_res8_f32_r1_3")]] kernel mul_mv_ext_q6_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_k_hifi_res8, 256, dequantize_q6_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q6_K_hifi_res8_f32_r1_4")]] kernel mul_mv_ext_q6_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_k_hifi_res8, 256, dequantize_q6_k_hifi_res8>;
+template [[host_name("kernel_mul_mv_ext_q6_K_hifi_res8_f32_r1_5")]] kernel mul_mv_ext_q6_K_hifi_res8_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_k_hifi_res8, 256, dequantize_q6_k_hifi_res8>;
+
 template
 void kernel_mul_mv_t_t_impl(
     args_t args,
@@ -7867,6 +7897,140 @@ kernel void kernel_mul_mv_q5_K_f32(
kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// Q5_K_HIFI_RES8: identical to Q5_K mul_mv but uses block_q5_k_hifi_res8 pointer (196-byte stride)
+// The base Q5_K fields are at identical byte offsets; HIFI residual extension is ignored here.
+// Work split as read from the code: tgpig.x/sgitg select a group of nr0 consecutive src0 rows,
+// tgpig.y selects the src1 row (r1) and tgpig.z the flattened (i12, i13) batch index.
+// shmem is not used by this implementation.
+template
+void kernel_mul_mv_q5_K_hifi_res8_f32_impl(
+    args_t args,
+    device const char * src0,
+    device const char * src1,
+    device char * dst,
+    threadgroup char * shmem,
+    uint3 tgpig,
+    ushort tiisg,
+    ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * NSG + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    // byte offsets of the first row of this simdgroup in src0 and of the activation row in src1
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+    // KEY FIX: use correct 196-byte struct stride instead of block_q5_K (176 bytes)
+    device const block_q5_k_hifi_res8 * x = (device const block_q5_k_hifi_res8 *) (src0 + offset0);
+    device const float * yy = (device const float *) (src1 + offset1);
+
+    float sumf[nr0]={0.f};
+
+    float yl[16], yh[16];
+
+    constexpr uint16_t kmask1 = 0x3f3f;
+    constexpr uint16_t kmask2 = 0x0f0f;
+    constexpr uint16_t kmask3 = 0xc0c0;
+
+    // lane decomposition: ix (0..3) picks which blocks of the row this lane visits,
+    // iq/ir pick the position inside a block
+    const short tid = tiisg/4;
+    const short ix = tiisg%4;
+    const short iq = tid/4;
+    const short ir = tid%4;
+
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;
+
+    // bit masks selecting this lane's high bits inside the qh bytes
+    const uint8_t hm1 = 1u << (2*iq);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    // each lane visits blocks i = ix, ix+4, ix+8, ... of the row
+    for (int i = ix; i < nb; i += 4) {
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (short l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
+
+        for (short row = 0; row < nr0; ++row) {
+            device const uint8_t * q2 = q1 + 64;
+
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
+            FOR_UNROLL (short l = 0; l < 8; ++l) {
+                uint8_t h = qh[l];
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+            }
+
+            // dh[0] is x[i].d; dh[1] reads the half stored directly after it
+            sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                  sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+                         dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            // advance every src0 pointer by one row: nb01 bytes (nb01/2 elements for 2-byte types)
+            q1 += args.nb01;
+            qh += args.nb01;
+            dh += args.nb01/2;
+            a += args.nb01/2;
+        }
+
+        y1 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    // reduce the per-lane partial sums across the simdgroup; lane 0 stores the row result
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+// Kernel entry point for Q5_K_HIFI_RES8 mul_mv; forwards to the implementation without shared memory.
+[[host_name("kernel_mul_mv_q5_K_hifi_res8_f32")]]
+kernel void kernel_mul_mv_q5_K_hifi_res8_f32(
+    constant ggml_metal_kargs_mul_mv & args,
+    device const char * src0,
+    device const char * src1,
+    device char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    ushort tiisg[[thread_index_in_simdgroup]],
+    ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_hifi_res8_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
 template
 void kernel_mul_mv_q6_K_f32_impl(
     args_t args,
@@ -7975,6 +8139,114 @@ kernel void kernel_mul_mv_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// Q6_K_HIFI_RES8 mul_mv: indexes src0 as block_q6_k_hifi_res8 so x[i] uses that struct's stride.
+// Only the ql, qh, scales and d fields are read; the residual extension is not used here.
+// shmem is not used by this implementation.
+template
+void kernel_mul_mv_q6_K_hifi_res8_f32_impl(
+    args_t args,
+    device const char * src0,
+    device const char * src1,
+    device char * dst,
+    threadgroup char * shmem,
+    uint3 tgpig,
+    ushort tiisg,
+    ushort sgitg) {
+    const short NSG = FC_mul_mv_nsg;
+
+    constexpr uint8_t kmask1 = 0x03;
+    constexpr uint8_t kmask2 = 0x0C;
+    constexpr uint8_t kmask3 = 0x30;
+    constexpr uint8_t kmask4 = 0xC0;
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * NSG + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    // byte offsets of the first row of this simdgroup in src0 and of the activation row in src1
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+    device const block_q6_k_hifi_res8 * x = (device const block_q6_k_hifi_res8 *) (src0 + offset0);
+    device const float * yy = (device const float *) (src1 + offset1);
+
+    float sumf[nr0] = { 0.f };
+
+    float yl[16];
+
+    const short tid = tiisg/2;
+    const short ix = tiisg%2;
+    const short ip = tid/8; // 0 or 1
+    const short il = tid%8;
+    const short l0 = 4*il;
+    const short is = 8*ip + l0/16;
+
+    const short y_offset = 128*ip + l0;
+    const short q_offset_l = 64*ip + l0;
+    const short q_offset_h = 32*ip + l0;
+
+    // each lane visits blocks i = ix, ix+2, ix+4, ... of the row
+    for (int i = ix; i < nb; i += 2) {
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t * sc = x[i].scales + is;
+        device const half * dh = &x[i].d;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        for (short l = 0; l < 4; ++l) {
+            yl[4*l + 0] = y[l + 0];
+            yl[4*l + 1] = y[l + 32];
+            yl[4*l + 2] = y[l + 64];
+            yl[4*l + 3] = y[l + 96];
+        }
+
+        for (short row = 0; row < nr0; ++row) {
+            float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+            // rebuild each 6-bit quant from 4 low bits (ql) and 2 high bits (qh), then subtract 32
+            FOR_UNROLL (short l = 0; l < 4; ++l) {
+                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            }
+
+            sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+            // advance every src0 pointer by one row: nb01 bytes (nb01/2 elements for the half pointer)
+            q1 += args.nb01;
+            q2 += args.nb01;
+            qh += args.nb01;
+            sc += args.nb01;
+            dh += args.nb01/2;
+        }
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    // reduce the per-lane partial sums across the simdgroup; lane 0 stores the row result
+    for (int row = 0; row < nr0 && first_row + row < args.ne0;
++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+// Kernel entry point for Q6_K_HIFI_RES8 mul_mv; forwards to the implementation without shared memory.
+[[host_name("kernel_mul_mv_q6_K_hifi_res8_f32")]]
+kernel void kernel_mul_mv_q6_K_hifi_res8_f32(
+    constant ggml_metal_kargs_mul_mv & args,
+    device const char * src0,
+    device const char * src1,
+    device char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    ushort tiisg[[thread_index_in_simdgroup]],
+    ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_hifi_res8_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
 // ======================= "True" 2-bit
 
 template
@@ -9892,7 +10164,9 @@ template [[host_name("kernel_get_rows_q3_k_hifi")]] kernel get_rows_q_t kernel_g
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q4_k_hifi")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_K_hifi_res8")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q6_K_hifi_res8")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
@@ -9956,7 +10230,9 @@ template [[host_name("kernel_mul_mm_q3_k_hifi_f32")]] kernel mul_mm_t kernel_mul
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q4_k_hifi_f32")]] kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_hifi_res8_f32")]] kernel mul_mm_t kernel_mul_mm;
 template
[[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_hifi_res8_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9981,7 +10257,9 @@ template [[host_name("kernel_mul_mm_q3_k_hifi_f16")]] kernel mul_mm_t kernel_mul template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_k_hifi_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_hifi_res8_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_hifi_res8_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10015,7 +10293,9 @@ template [[host_name("kernel_mul_mm_id_q3_k_hifi_f32")]] kernel mul_mm_id kernel template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_k_hifi_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_hifi_res8_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_hifi_res8_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id 
kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10040,7 +10320,9 @@ template [[host_name("kernel_mul_mm_id_q3_k_hifi_f16")]] kernel mul_mm_id kernel template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_k_hifi_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_hifi_res8_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_hifi_res8_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10197,7 +10479,9 @@ template [[host_name("kernel_mul_mv_id_q3_k_hifi_f32")]] kernel kernel_mul_mv_id template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_k_hifi_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_K_hifi_res8_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q6_K_hifi_res8_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel 
kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index d2a5969972c..8997e661379 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -220,6 +220,15 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { + if (!std::filesystem::exists(imatrix_file)) { + fprintf(stderr, "%s: imatrix file '%s' not found\n", __func__, imatrix_file.c_str()); + exit(1); + } + if (!std::filesystem::is_regular_file(imatrix_file)) { + fprintf(stderr, "%s: imatrix path '%s' is not a regular file\n", __func__, imatrix_file.c_str()); + exit(1); + } + struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { /* .no_alloc = */ false, // the data is needed