From 1be49b216e25d0ff426cf6cb987a4f24146c632a Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 5 Mar 2025 16:42:28 +0800 Subject: [PATCH 01/12] [webgpu] Optimize MatMulNBits for f16 Block32 prefill performance --- .../webgpu/quantization/matmul_nbits.cc | 145 ++++++++++++++++++ .../webgpu/quantization/matmul_nbits.h | 8 + 2 files changed, 153 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index be105a0fd4374..f8f5a9c3c8ef6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -523,6 +523,117 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +Status MatMulNBitsBlock32Program::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + + // memory read/write helpers + shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"; + shader.AdditionalImplementation() << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return input_a_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"; + shader.AdditionalImplementation() << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return input_b_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_read_scale(row : u32, col : u32) -> output_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"; + shader.AdditionalImplementation() << " return " << scales.GetByOffset("row * uniforms.input_b_shape[1] + col") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return output_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n"; + shader.AdditionalImplementation() << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << "}\n"; + + // declare const variables + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "const tile_m = " << workgroup_size / 8 << "u;\n"; + shader.AdditionalImplementation() << "const tile_n = " << workgroup_size << "u;\n"; + + // declare workgroup memory + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "var a_data_wg: array, tile_m>;\n"; + shader.AdditionalImplementation() << "\n"; + + // main + shader.MainFunctionBody() << R"MAIN_FN( + let batch = workgroup_id.z; + let row = workgroup_id.y * tile_m; + let col = workgroup_id.x * tile_n; + + let a_elements_per_col = uniforms.input_a_shape[2]; + // A block32 containing 8 elements of `a`. + let a_blocks_per_col = (a_elements_per_col + 7u) / 8u; + + // f32 accumulator + var results : array; + for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { + // load `a` elements into workgroup memory, TileM x 8(block32) + let a_row_idx = local_idx / 8u; + let a_col_idx = local_idx % 8u; + a_data_wg[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx); + workgroupBarrier(); + + let b_row = col + local_idx; + let b_col = a_block_idx; + + let b_data = mm_read_b(b_row, b_col); + let scale = mm_read_scale(b_row, b_col); + let zero_point = output_element_t(8.0); + + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let b_value = b_data[b_idx]; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu); + let b_quantized_values = mat2x4( + output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), + output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), + output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), + output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3])); + let b_dequantized_values = + (b_quantized_values - mat2x4(zero_point, zero_point, + zero_point, zero_point, + zero_point, zero_point, + zero_point, zero_point)) * scale; + + for (var m_idx = 0u; m_idx < tile_m; m_idx++) { + let a_data0 = a_data_wg[m_idx][b_idx * 2u]; + let a_data1 = a_data_wg[m_idx][b_idx * 2u + 1u]; + + results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) + + f32(dot(a_data1, b_dequantized_values[1u])); + } + } + + workgroupBarrier(); + } + + // write the results + for (var m_idx = 0u; m_idx < tile_m; m_idx++) { + mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx])); + } +)MAIN_FN"; + + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -569,6 +680,40 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } + // Block32 prefill program + // This program is optimized for Block32 prefill using Tile16x128. + const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points && + components_a == 4 && components_b == 4 && M > 1 && + context.AdapterInfo().vendor == std::string_view{"intel"}; + if (use_block32_program) { + // enforce components to 1. + components = 1; + + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_m = workgroup_size / 8; + constexpr uint32_t tile_n = workgroup_size; + + MatMulNBitsBlock32Program program{}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, + (M + tile_m - 1) / tile_m, + batch_count); + program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32"); + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddUniformVariable({block_size}); + return context.RunProgram(program); + } + + // Generic program // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 10221e19c7400..e877a048730dc 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,6 +35,14 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; +class MatMulNBitsBlock32Program final : public Program { + public: + MatMulNBitsBlock32Program() : Program{"MatMulNBitsBlock32"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); +}; + class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { From 4ead004353f605a8fe51257910df9517c4946a60 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 12 Mar 2025 16:10:19 +0800 Subject: [PATCH 02/12] Resolve comments - Rename to `MatMulNBitsBlockWideTileProgram` for clarity. - Enforce `M >= kMinMForTileOptimization`. - Add TODO for future improvements. --- .../webgpu/quantization/matmul_nbits.cc | 15 ++++++++------- .../webgpu/quantization/matmul_nbits.h | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index f8f5a9c3c8ef6..041b46e1c2c8e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -523,7 +523,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status MatMulNBitsBlock32Program::GenerateShaderCode(ShaderHelper& shader) const { +Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); @@ -680,12 +680,13 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } - // Block32 prefill program + // BlockWideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - const bool use_block32_program = block_size == 32 && batch_count == 1 && !has_zero_points && - components_a == 4 && components_b == 4 && M > 1 && - context.AdapterInfo().vendor == std::string_view{"intel"}; - if (use_block32_program) { + // TODO: loosen restrictions on batch_count, has_zero_points, and vendor. + const bool use_block_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points && + components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && + context.AdapterInfo().vendor == std::string_view{"intel"}; + if (use_block_wide_tile_program) { // enforce components to 1. components = 1; @@ -693,7 +694,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; - MatMulNBitsBlock32Program program{}; + MatMulNBitsBlockWideTileProgram program{}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, (M + tile_m - 1) / tile_m, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index e877a048730dc..30315fbd09c78 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,9 +35,9 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; -class MatMulNBitsBlock32Program final : public Program { +class MatMulNBitsBlockWideTileProgram final : public Program { public: - MatMulNBitsBlock32Program() : Program{"MatMulNBitsBlock32"} {} + MatMulNBitsBlockWideTileProgram() : Program{"MatMulNBitsBlockWideTileProgram"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); From 14bbe9dfbdc24e3339866a8b2c7fca308a374b2f Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Tue, 18 Mar 2025 13:45:36 +0800 Subject: [PATCH 03/12] Fix variable naming --- .../contrib_ops/webgpu/quantization/matmul_nbits.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 041b46e1c2c8e..b12b332ba939a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -569,7 +569,7 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // declare workgroup memory shader.AdditionalImplementation() << "\n"; - shader.AdditionalImplementation() << "var a_data_wg: array, tile_m>;\n"; + shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; shader.AdditionalImplementation() << "\n"; // main @@ -588,7 +588,7 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // load `a` elements into workgroup memory, TileM x 8(block32) let a_row_idx = local_idx / 8u; let a_col_idx = local_idx % 8u; - a_data_wg[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx); + a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx); workgroupBarrier(); let b_row = col + local_idx; @@ -614,8 +614,8 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) zero_point, zero_point)) * scale; for (var m_idx = 0u; m_idx < tile_m; m_idx++) { - let a_data0 = a_data_wg[m_idx][b_idx * 2u]; - let a_data1 = a_data_wg[m_idx][b_idx * 2u + 1u]; + let a_data0 = a_data_tile[m_idx][b_idx * 2u]; + let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u]; results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) + f32(dot(a_data1, b_dequantized_values[1u])); From 0f55827fd2a39950bfba561aae7d4815ab2198a6 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 09:53:04 +0800 Subject: [PATCH 04/12] Add comment on f32 accumulator --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b12b332ba939a..2ac4a333da4ea 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -582,7 +582,8 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // A block32 containing 8 elements of `a`. let a_blocks_per_col = (a_elements_per_col + 7u) / 8u; - // f32 accumulator + // Utilizing an f32 accumulator mitigated precision loss with minimal + // performance impact compared to an f16 accumulator. var results : array; for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { // load `a` elements into workgroup memory, TileM x 8(block32) From 695d9d059d26f97c07f3549b0c05bcf6bdf60d72 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 09:58:17 +0800 Subject: [PATCH 05/12] Improve comment --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 2ac4a333da4ea..902a75f873e45 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -688,7 +688,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && context.AdapterInfo().vendor == std::string_view{"intel"}; if (use_block_wide_tile_program) { - // enforce components to 1. + // Enforce output components to 1. components = 1; constexpr uint32_t workgroup_size = 128; From fd751cbd7f82c88707439f5c92bc6a7d7e26aa0a Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 10:28:58 +0800 Subject: [PATCH 06/12] More comment and avoid magic number --- .../webgpu/quantization/matmul_nbits.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 902a75f873e45..9aa8ed2006bc0 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -529,6 +529,8 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + // Bock size 32, `a` component size 4, 8 `a` components per block. + constexpr uint32_t kAComponentSizeForBlock32 = 8; const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); // memory read/write helpers @@ -564,12 +566,14 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // declare const variables shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "// A block32 containing 8 components of `a`." << "\n"; + shader.AdditionalImplementation() << "const kAComponentSizeForBlock32 = " << kAComponentSizeForBlock32 << "u;\n"; shader.AdditionalImplementation() << "const tile_m = " << workgroup_size / 8 << "u;\n"; shader.AdditionalImplementation() << "const tile_n = " << workgroup_size << "u;\n"; // declare workgroup memory shader.AdditionalImplementation() << "\n"; - shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; + shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; shader.AdditionalImplementation() << "\n"; // main @@ -579,17 +583,16 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) let col = workgroup_id.x * tile_n; let a_elements_per_col = uniforms.input_a_shape[2]; - // A block32 containing 8 elements of `a`. - let a_blocks_per_col = (a_elements_per_col + 7u) / 8u; + let a_blocks_per_col = (a_elements_per_col + kAComponentSizeForBlock32 - 1) / kAComponentSizeForBlock32; // Utilizing an f32 accumulator mitigated precision loss with minimal // performance impact compared to an f16 accumulator. var results : array; for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { - // load `a` elements into workgroup memory, TileM x 8(block32) - let a_row_idx = local_idx / 8u; - let a_col_idx = local_idx % 8u; - a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx); + // Load `a` elements into workgroup memory, TileM x kAComponentSizeForBlock32 (block32) + let a_row_idx = local_idx / kAComponentSizeForBlock32; + let a_col_idx = local_idx % kAComponentSizeForBlock32; + a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentSizeForBlock32 + a_col_idx); workgroupBarrier(); let b_row = col + local_idx; @@ -599,6 +602,7 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) let scale = mm_read_scale(b_row, b_col); let zero_point = output_element_t(8.0); + // `b` component size is 4. for (var b_idx = 0u; b_idx < 4u; b_idx++) { let b_value = b_data[b_idx]; let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); From 58d76f62065889bd7e877d84a03994f9cc915909 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 10:32:55 +0800 Subject: [PATCH 07/12] Improve variable naming --- .../webgpu/quantization/matmul_nbits.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 9aa8ed2006bc0..3befe3e89ac7e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -530,7 +530,7 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); // Bock size 32, `a` component size 4, 8 `a` components per block. - constexpr uint32_t kAComponentSizeForBlock32 = 8; + constexpr uint32_t kAComponentsForBlock32 = 8; const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); // memory read/write helpers @@ -567,13 +567,13 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // declare const variables shader.AdditionalImplementation() << "\n"; shader.AdditionalImplementation() << "// A block32 containing 8 components of `a`." << "\n"; - shader.AdditionalImplementation() << "const kAComponentSizeForBlock32 = " << kAComponentSizeForBlock32 << "u;\n"; + shader.AdditionalImplementation() << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n"; shader.AdditionalImplementation() << "const tile_m = " << workgroup_size / 8 << "u;\n"; shader.AdditionalImplementation() << "const tile_n = " << workgroup_size << "u;\n"; // declare workgroup memory shader.AdditionalImplementation() << "\n"; - shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; + shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; shader.AdditionalImplementation() << "\n"; // main @@ -583,16 +583,16 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) let col = workgroup_id.x * tile_n; let a_elements_per_col = uniforms.input_a_shape[2]; - let a_blocks_per_col = (a_elements_per_col + kAComponentSizeForBlock32 - 1) / kAComponentSizeForBlock32; + let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32; // Utilizing an f32 accumulator mitigated precision loss with minimal // performance impact compared to an f16 accumulator. var results : array; for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { - // Load `a` elements into workgroup memory, TileM x kAComponentSizeForBlock32 (block32) - let a_row_idx = local_idx / kAComponentSizeForBlock32; - let a_col_idx = local_idx % kAComponentSizeForBlock32; - a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentSizeForBlock32 + a_col_idx); + // Load `a` elements into workgroup memory, TileM x kAComponentsForBlock32 (block32) + let a_row_idx = local_idx / kAComponentsForBlock32; + let a_col_idx = local_idx % kAComponentsForBlock32; + a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx); workgroupBarrier(); let b_row = col + local_idx; From ae482a2a4458f461f247574ffd6c44a4cb5e5be0 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 10:56:06 +0800 Subject: [PATCH 08/12] Add tile_m and tile_n into constructor --- .../contrib_ops/webgpu/quantization/matmul_nbits.cc | 9 ++++++--- .../contrib_ops/webgpu/quantization/matmul_nbits.h | 7 ++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 3befe3e89ac7e..dae511fd836e9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -531,7 +531,10 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) // Bock size 32, `a` component size 4, 8 `a` components per block. constexpr uint32_t kAComponentsForBlock32 = 8; + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8."); + ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size."); // memory read/write helpers shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"; @@ -568,8 +571,8 @@ Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) shader.AdditionalImplementation() << "\n"; shader.AdditionalImplementation() << "// A block32 containing 8 components of `a`." << "\n"; shader.AdditionalImplementation() << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n"; - shader.AdditionalImplementation() << "const tile_m = " << workgroup_size / 8 << "u;\n"; - shader.AdditionalImplementation() << "const tile_n = " << workgroup_size << "u;\n"; + shader.AdditionalImplementation() << "const tile_m = " << tile_m_ << "u;\n"; + shader.AdditionalImplementation() << "const tile_n = " << tile_n_ << "u;\n"; // declare workgroup memory shader.AdditionalImplementation() << "\n"; @@ -699,7 +702,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; - MatMulNBitsBlockWideTileProgram program{}; + MatMulNBitsBlockWideTileProgram program{tile_m, tile_n}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, (M + tile_m - 1) / tile_m, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 30315fbd09c78..b0bc1a1a82f35 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -37,10 +37,15 @@ class MatMulNBitsProgram final : public Program { class MatMulNBitsBlockWideTileProgram final : public Program { public: - MatMulNBitsBlockWideTileProgram() : Program{"MatMulNBitsBlockWideTileProgram"} {} + MatMulNBitsBlockWideTileProgram(uint32_t tile_m, uint32_t tile_n) + : Program{"MatMulNBitsBlockWideTileProgram"}, tile_m_(tile_m), tile_n_(tile_n) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t tile_m_; + uint32_t tile_n_; }; class MatMulNBits final : public WebGpuKernel { From 287be7e3b3a89726e55f95cb3db12d270b644337 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 11:06:17 +0800 Subject: [PATCH 09/12] Rename to MatMulNBitsWideTileProgram --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 4 ++-- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index dae511fd836e9..1b0e40ed42ba7 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -523,7 +523,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status MatMulNBitsBlockWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { +Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); @@ -702,7 +702,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; - MatMulNBitsBlockWideTileProgram program{tile_m, tile_n}; + MatMulNBitsWideTileProgram program{tile_m, tile_n}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, (M + tile_m - 1) / tile_m, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index b0bc1a1a82f35..07c47021d516a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,10 +35,10 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; -class MatMulNBitsBlockWideTileProgram final : public Program { +class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsBlockWideTileProgram(uint32_t tile_m, uint32_t tile_n) - : Program{"MatMulNBitsBlockWideTileProgram"}, tile_m_(tile_m), tile_n_(tile_n) {} + MatMulNBitsWideTileProgram(uint32_t tile_m, uint32_t tile_n) + : Program{"MatMulNBitsWideTileProgram"}, tile_m_(tile_m), tile_n_(tile_n) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); From ca1710a9fc1ac99d4325affe00cdac519370191c Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 26 Mar 2025 12:15:36 +0800 Subject: [PATCH 10/12] Improve comment to reflect new naming --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 1b0e40ed42ba7..8824cad871f8f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -688,13 +688,13 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } - // BlockWideTileProgram + // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. // TODO: loosen restrictions on batch_count, has_zero_points, and vendor. - const bool use_block_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points && + const bool use_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && context.AdapterInfo().vendor == std::string_view{"intel"}; - if (use_block_wide_tile_program) { + if (use_wide_tile_program) { // Enforce output components to 1. components = 1; From 17c0b1f32f516e14dc5b885a2e5dfcfc0a374d75 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 2 Apr 2025 06:49:44 +0800 Subject: [PATCH 11/12] Fix lint --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 8824cad871f8f..c4211a616badd 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -692,8 +692,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context // This program is optimized for Block32 prefill using Tile16x128. // TODO: loosen restrictions on batch_count, has_zero_points, and vendor. const bool use_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points && - components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && - context.AdapterInfo().vendor == std::string_view{"intel"}; + components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && + context.AdapterInfo().vendor == std::string_view{"intel"}; if (use_wide_tile_program) { // Enforce output components to 1. components = 1; From e7f8bb433a8faa1995dc5fb48bcfe9a9563003d2 Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Thu, 3 Apr 2025 18:20:52 +0800 Subject: [PATCH 12/12] Prefer onnxruntime::narrow --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index c4211a616badd..54081ca7025d6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -714,10 +714,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context TensorShape reshaped_y_shape{batch_count, M, N / components}; program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4)}, + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow(components_b * 4)}, {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow(components)}) .AddUniformVariable({block_size}); return context.RunProgram(program); }