diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index be105a0fd4374..54081ca7025d6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -523,6 +523,125 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + // Bock size 32, `a` component size 4, 8 `a` components per block. + constexpr uint32_t kAComponentsForBlock32 = 8; + + const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); + ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8."); + ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size."); + + // memory read/write helpers + shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"; + shader.AdditionalImplementation() << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return input_a_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"; + shader.AdditionalImplementation() << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return input_b_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_read_scale(row : u32, col : u32) -> output_value_t {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n"; + shader.AdditionalImplementation() << " return " << scales.GetByOffset("row * uniforms.input_b_shape[1] + col") << ";\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << " return output_value_t(0);\n"; + shader.AdditionalImplementation() << "}\n"; + + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n"; + shader.AdditionalImplementation() << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n"; + shader.AdditionalImplementation() << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n"; + shader.AdditionalImplementation() << " }\n"; + shader.AdditionalImplementation() << "}\n"; + + // declare const variables + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "// A block32 containing 8 components of `a`." << "\n"; + shader.AdditionalImplementation() << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n"; + shader.AdditionalImplementation() << "const tile_m = " << tile_m_ << "u;\n"; + shader.AdditionalImplementation() << "const tile_n = " << tile_n_ << "u;\n"; + + // declare workgroup memory + shader.AdditionalImplementation() << "\n"; + shader.AdditionalImplementation() << "var a_data_tile: array, tile_m>;\n"; + shader.AdditionalImplementation() << "\n"; + + // main + shader.MainFunctionBody() << R"MAIN_FN( + let batch = workgroup_id.z; + let row = workgroup_id.y * tile_m; + let col = workgroup_id.x * tile_n; + + let a_elements_per_col = uniforms.input_a_shape[2]; + let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32; + + // Utilizing an f32 accumulator mitigated precision loss with minimal + // performance impact compared to an f16 accumulator. + var results : array; + for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) { + // Load `a` elements into workgroup memory, TileM x kAComponentsForBlock32 (block32) + let a_row_idx = local_idx / kAComponentsForBlock32; + let a_col_idx = local_idx % kAComponentsForBlock32; + a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx); + workgroupBarrier(); + + let b_row = col + local_idx; + let b_col = a_block_idx; + + let b_data = mm_read_b(b_row, b_col); + let scale = mm_read_scale(b_row, b_col); + let zero_point = output_element_t(8.0); + + // `b` component size is 4. + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let b_value = b_data[b_idx]; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu); + let b_quantized_values = mat2x4( + output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), + output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), + output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), + output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3])); + let b_dequantized_values = + (b_quantized_values - mat2x4(zero_point, zero_point, + zero_point, zero_point, + zero_point, zero_point, + zero_point, zero_point)) * scale; + + for (var m_idx = 0u; m_idx < tile_m; m_idx++) { + let a_data0 = a_data_tile[m_idx][b_idx * 2u]; + let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u]; + + results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) + + f32(dot(a_data1, b_dequantized_values[1u])); + } + } + + workgroupBarrier(); + } + + // write the results + for (var m_idx = 0u; m_idx < tile_m; m_idx++) { + mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx])); + } +)MAIN_FN"; + + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -569,6 +688,41 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } + // WideTileProgram + // This program is optimized for Block32 prefill using Tile16x128. + // TODO: loosen restrictions on batch_count, has_zero_points, and vendor. + const bool use_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points && + components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && + context.AdapterInfo().vendor == std::string_view{"intel"}; + if (use_wide_tile_program) { + // Enforce output components to 1. + components = 1; + + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_m = workgroup_size / 8; + constexpr uint32_t tile_n = workgroup_size; + + MatMulNBitsWideTileProgram program{tile_m, tile_n}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, + (M + tile_m - 1) / tile_m, + batch_count); + program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32"); + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow(components_b * 4)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow(components)}) + .AddUniformVariable({block_size}); + return context.RunProgram(program); + } + + // Generic program // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 10221e19c7400..07c47021d516a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,6 +35,19 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; +class MatMulNBitsWideTileProgram final : public Program { + public: + MatMulNBitsWideTileProgram(uint32_t tile_m, uint32_t tile_n) + : Program{"MatMulNBitsWideTileProgram"}, tile_m_(tile_m), tile_n_(tile_n) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t tile_m_; + uint32_t tile_n_; +}; + class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {