diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 34cbf3694b1..ee65c44aae1 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -1797,6 +1797,25 @@ class ggml_webgpu_shader_lib {
             defines.push_back("U32_DEQUANT_HELPERS");
             defines.push_back("SRC0_INNER_TYPE=u32");
 
+            switch (context.src0->type) {
+                case GGML_TYPE_IQ1_S:
+                case GGML_TYPE_IQ1_M:
+                case GGML_TYPE_IQ4_NL:
+                case GGML_TYPE_IQ4_XS:
+                    defines.push_back(type_upper + "_GRID");
+                    break;
+                case GGML_TYPE_IQ2_XXS:
+                case GGML_TYPE_IQ2_XS:
+                case GGML_TYPE_IQ2_S:
+                case GGML_TYPE_IQ3_XXS:
+                case GGML_TYPE_IQ3_S:
+                    defines.push_back(type_upper + "_GRID");
+                    defines.push_back(type_upper + "_TABLES");
+                    break;
+                default:
+                    break;
+            }
+
             variant += std::string("_") + src0_name;
             break;
         }
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 762d9f8d1b4..2dff6ccab63 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
-            use_fast = is_vec;
+            use_fast = true;
             break;
         default:
            break;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
index 15b22c4f731..51cf08f196f 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     }
 }
 #endif // INIT_SRC0_SHMEM_Q6_K
+
+#ifdef INIT_SRC0_SHMEM_IQ4_NL
+const BLOCK_SIZE = 32u;
+const BLOCK_SIZE_BYTES = 18u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_at_src0(block_byte_base);
+
+        let pos = k_in_block % 16u;
+        let nib_shift = (k_in_block / 16u) * 4u;
+        let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u);
+        let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu;
+
+        shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ4_NL
+
+#ifdef INIT_SRC0_SHMEM_IQ4_XS
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 136u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+
+        let d_scales_h = load_u32_at_src0(block_byte_base);
+        let d = bitcast<vec2<f16>>(d_scales_h).x;
+        let scales_h = d_scales_h >> 16u;
+
+        let ib = k_in_block / 32u;
+        let pos = k_in_block % 32u;
+
+        let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
+        let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu;
+        let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u;
+        let dl = d * f16(i32(ls_lo | ls_hi) - 32);
+
+        let iqs = ib * 16u + (pos % 16u);
+        let nib_shift = (pos / 16u) * 4u;
+        let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u);
+        let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu;
+
+        shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ4_XS
+
+#ifdef INIT_SRC0_SHMEM_IQ1_S
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 50u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_as_f32_at_src0(block_byte_base);
+
+        let ib = k_in_block / 32u;
+        let pos = k_in_block % 32u;
+        let l = pos / 8u;
+        let j = pos % 8u;
+
+        let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu;
+        let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0);
+        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
+
+        let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
+        let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
+
+        let gw = iq1_grid[(ig + j) / 16u];
+        let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
+        let gs = bitcast<i32>(g << 30u) >> 30u;
+
+        shmem[elem_idx] = f16(dl * (f32(gs) + delta));
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ1_S
+
+#ifdef INIT_SRC0_SHMEM_IQ1_M
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 56u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+
+        let scales0 = load_u32_at_src0(block_byte_base + 48u);
+        let scales1 = load_u32_at_src0(block_byte_base + 52u);
+        let scale_packed = ((scales0 >> 12u) & 0xFu) |
+                           ((scales0 >> 24u) & 0x00F0u) |
+                           ((scales1 >> 4u) & 0x0F00u) |
+                           ((scales1 >> 16u) & 0xF000u);
+        let d = f32(bitcast<vec2<f16>>(scale_packed).x);
+
+        let ib = k_in_block / 32u;
+        let pos = k_in_block % 32u;
+        let l = pos / 8u;
+        let j = pos % 8u;
+
+        let scales = select(scales0, scales1, ib >= 4u);
+        let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu;
+        let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u;
+        let dl = d * f32(2u * s_pair + 1u);
+
+        let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u);
+        let qh = qh_word >> (16u * (ib % 2u));
+        let qh_nib = (qh >> (4u * l)) & 0xFu;
+
+        let qs_w = load_u32_at_src0(block_byte_base + ib * 4u);
+        let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u);
+        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u);
+
+        let ig = idx * 8u;
+        let gw = iq1_grid[(ig + j) / 16u];
+        let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u;
+        let gs = bitcast<i32>(g << 30u) >> 30u;
+
+        shmem[elem_idx] = f16(dl * (f32(gs) + delta));
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ1_M
+
+#ifdef INIT_SRC0_SHMEM_IQ2_XXS
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 66u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_as_f32_at_src0(block_byte_base);
+
+        let entry_idx = k_in_block / 8u;
+        let j = k_in_block % 8u;
+
+        let ib = entry_idx & ~3u;
+        let l = entry_idx & 3u;
+
+        let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u);
+        let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u);
+        let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25;
+
+        let ig = get_byte(aux0, l) * 8u;
+        let is = (aux1 >> (7u * l)) & 127u;
+        let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
+
+        let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u);
+        let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
+
+        shmem[elem_idx] = f16(db * f32(g) * m);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ2_XXS
+
+#ifdef INIT_SRC0_SHMEM_IQ2_XS
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 74u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_as_f32_at_src0(block_byte_base);
+
+        let entry_idx = k_in_block / 8u;
+        let j = k_in_block % 8u;
+
+        let ib = entry_idx & ~3u;
+        let l = entry_idx & 3u;
+
+        let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u);
+        let s = get_byte(scales_word, (ib % 16u) / 4u);
+        let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
+        let dl = d * (0.5 + f32(s_nib)) * 0.25;
+
+        let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u);
+        let qs_val = qs_word & 0xFFFFu;
+        let ig = (qs_val & 511u) * 8u;
+        let is = qs_val >> 9u;
+        let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
+
+        let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u);
+        let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
+
+        shmem[elem_idx] = f16(dl * f32(g) * m);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ2_XS
+
+#ifdef INIT_SRC0_SHMEM_IQ2_S
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 82u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_as_f32_at_src0(block_byte_base);
+
+        let ib = k_in_block / 32u;
+        let l = (k_in_block % 32u) / 8u;
+        let j = k_in_block % 8u;
+
+        let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u);
+        let s = get_byte(scales_word, ib % 4u);
+        let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u);
+        let dl = d * (0.5 + f32(s_nib)) * 0.25;
+
+        let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u);
+        let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u);
+        let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u;
+        let ig = (get_byte(qs_word, l) | qh_b) * 8u;
+
+        let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u);
+        let signs = get_byte(signs_word, l);
+
+        let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u);
+        let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u);
+
+        shmem[elem_idx] = f16(dl * f32(g) * m);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ2_S
+
+#ifdef INIT_SRC0_SHMEM_IQ3_XXS
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 98u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
+        let d = load_f16_as_f32_at_src0(block_byte_base);
+
+        let ib_pair = k_in_block / 32u;
+        let in_pair = k_in_block % 32u;
+        let l = in_pair / 8u;
+        let in_l = in_pair % 8u;
+        let k2 = in_l / 4u;
+        let j = in_l % 4u;
+
+        let ib = ib_pair * 2u;
+        let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u;
+        let sc_sign = load_u32_at_src0(sc_sign_off);
+        let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5;
+        let is = (sc_sign >> (7u * l)) & 127u;
+        let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u);
+
+        let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu;
+        let ig_byte = get_byte(ig_word, k2);
+        let g = get_byte(iq3xxs_grid[ig_byte], j);
+        let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u);
+
+        shmem[elem_idx] = f16(db * f32(g) * m);
+    }
+}
+#endif // INIT_SRC0_SHMEM_IQ3_XXS
+
+#ifdef INIT_SRC0_SHMEM_IQ3_S
+const BLOCK_SIZE = 256u;
+const BLOCK_SIZE_BYTES = 110u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S
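For reference, the IQ4_NL shmem-init path above reduces to the per-element dequantization below. This is a minimal scalar sketch, not part of the patch: it assumes the block_iq4_nl layout from ggml-common.h (an f16 scale followed by 16 bytes of packed 4-bit indices) and the kvalues_iq4nl codebook; the helper name dequant_iq4_nl_elem is illustrative only.

#include <cstdint>

// Non-linear 4-bit codebook used by IQ4_NL (values as in ggml-common.h).
static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

// Dequantize element k_in_block (0..31) of one 32-value IQ4_NL block:
// byte k_in_block % 16 holds the nibble, low half for the first 16 elements,
// high half for the last 16, which is the same indexing the shader performs
// on 4-byte-aligned loads via get_byte().
static float dequant_iq4_nl_elem(float d, const uint8_t qs[16], uint32_t k_in_block) {
    const uint32_t pos       = k_in_block % 16u;
    const uint32_t nib_shift = (k_in_block / 16u) * 4u;
    const uint32_t nib       = (qs[pos] >> nib_shift) & 0xFu;
    return d * (float) kvalues_iq4nl[nib];
}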