154 changes: 154 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -523,6 +523,125 @@
return Status::OK();
}

Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);

// Block size 32, `a` component size 4, 8 `a` components per block.
constexpr uint32_t kAComponentsForBlock32 = 8;

const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
ORT_ENFORCE(tile_m_ == workgroup_size / 8, "tile_m must be workgroup_size / 8.");
ORT_ENFORCE(tile_n_ == workgroup_size, "tile_n must be workgroup_size.");
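// Each workgroup computes one tile_m x tile_n tile of the output: each of the
// tile_n threads owns a single output column within the tile and accumulates
// tile_m partial sums, one per output row.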

// memory read/write helpers
shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n";
shader.AdditionalImplementation() << " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n";
shader.AdditionalImplementation() << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n";
shader.AdditionalImplementation() << " }\n";
shader.AdditionalImplementation() << " return input_a_value_t(0);\n";
shader.AdditionalImplementation() << "}\n";

shader.AdditionalImplementation() << "\n";
shader.AdditionalImplementation() << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n";
shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n";
shader.AdditionalImplementation() << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n";
shader.AdditionalImplementation() << " }\n";
shader.AdditionalImplementation() << " return input_b_value_t(0);\n";
shader.AdditionalImplementation() << "}\n";

shader.AdditionalImplementation() << "\n";
shader.AdditionalImplementation() << "fn mm_read_scale(row : u32, col : u32) -> output_value_t {\n";
shader.AdditionalImplementation() << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n";
shader.AdditionalImplementation() << " return " << scales.GetByOffset("row * uniforms.input_b_shape[1] + col") << ";\n";
shader.AdditionalImplementation() << " }\n";
shader.AdditionalImplementation() << " return output_value_t(0);\n";
shader.AdditionalImplementation() << "}\n";

shader.AdditionalImplementation() << "\n";
shader.AdditionalImplementation() << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n";
shader.AdditionalImplementation() << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n";
shader.AdditionalImplementation() << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n";
shader.AdditionalImplementation() << " }\n";
shader.AdditionalImplementation() << "}\n";

// declare const variables
shader.AdditionalImplementation() << "\n";
shader.AdditionalImplementation() << "// A block32 containing 8 components of `a`." << "\n";
shader.AdditionalImplementation() << "const kAComponentsForBlock32 = " << kAComponentsForBlock32 << "u;\n";
shader.AdditionalImplementation() << "const tile_m = " << tile_m_ << "u;\n";
shader.AdditionalImplementation() << "const tile_n = " << tile_n_ << "u;\n";

// declare workgroup memory
shader.AdditionalImplementation() << "\n";
shader.AdditionalImplementation() << "var<workgroup> a_data_tile: array<array<input_a_value_t, kAComponentsForBlock32>, tile_m>;\n";
shader.AdditionalImplementation() << "\n";

// main
shader.MainFunctionBody() << R"MAIN_FN(
let batch = workgroup_id.z;
let row = workgroup_id.y * tile_m;
let col = workgroup_id.x * tile_n;

let a_elements_per_col = uniforms.input_a_shape[2];
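// K is consumed in block32 chunks of kAComponentsForBlock32 vec4 components,
// hence the ceiling division.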
let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32;

// Using an f32 accumulator mitigates precision loss with minimal
// performance impact compared to an f16 accumulator.
var results : array<f32, tile_m>;
for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
// Cooperatively load the `a` tile (tile_m x kAComponentsForBlock32, one block32 per row)
// into workgroup memory, one element per thread.
let a_row_idx = local_idx / kAComponentsForBlock32;
let a_col_idx = local_idx % kAComponentsForBlock32;
a_data_tile[a_row_idx][a_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * kAComponentsForBlock32 + a_col_idx);
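// Ensure the whole `a` tile is resident before any thread reads it.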
workgroupBarrier();

let b_row = col + local_idx;
let b_col = a_block_idx;

let b_data = mm_read_b(b_row, b_col);
let scale = mm_read_scale(b_row, b_col);
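// Without explicit zero points, 4-bit weights use the implicit zero point 8,
// the midpoint of the unsigned range [0, 15].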
let zero_point = output_element_t(8.0);

// Each `b` value has 4 u32 components; process one u32 (eight 4-bit weights) per iteration.
for (var b_idx = 0u; b_idx < 4u; b_idx++) {
let b_value = b_data[b_idx];
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu);
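// mat2x4 is column-major: column 0 holds the four weights from bytes 0-1
// (low nibble first), column 1 those from bytes 2-3, matching the order in
// which the eight `a` elements are consumed below.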
let b_quantized_values = mat2x4<output_element_t>(
output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]),
output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]),
output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]),
output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
let b_dequantized_values =
(b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point,
zero_point, zero_point,
zero_point, zero_point,
zero_point, zero_point)) * scale;

for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
let a_data0 = a_data_tile[m_idx][b_idx * 2u];
let a_data1 = a_data_tile[m_idx][b_idx * 2u + 1u];

results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) +
f32(dot(a_data1, b_dequantized_values[1u]));
}
}

workgroupBarrier();
}

// write the results
for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
}
)MAIN_FN";

return Status::OK();
}
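For reference, a minimal scalar sketch (illustrative only, not part of this change) of the dequantization each u32 of `b` undergoes in the shader above, assuming the implicit zero point of 8 used on this no-zero-points path:

#include <array>
#include <cstdint>

// Unpack one u32 of `b` (eight 4-bit weights) and dequantize, mirroring the
// unpack4xU8 low/high-nibble order used in the WGSL main loop.
std::array<float, 8> DequantizeBlock32Word(uint32_t packed, float scale) {
  std::array<float, 8> out{};
  for (int i = 0; i < 4; ++i) {
    const uint32_t byte = (packed >> (8 * i)) & 0xFFu;
    out[2 * i] = (static_cast<float>(byte & 0x0Fu) - 8.0f) * scale;    // low nibble
    out[2 * i + 1] = (static_cast<float>(byte >> 4u) - 8.0f) * scale;  // high nibble
  }
  return out;
}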

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -569,6 +688,41 @@
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y);
}

// WideTileProgram
// This program is optimized for Block32 prefill using Tile16x128.
// TODO: loosen restrictions on batch_count, has_zero_points, and vendor.

Lint annotation (GitHub Actions / Optional Lint C++, cpplint via reviewdog, line 693): Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo]
const bool use_wide_tile_program = block_size == 32 && batch_count == 1 && !has_zero_points &&
components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization &&
context.AdapterInfo().vendor == std::string_view{"intel"};
if (use_wide_tile_program) {
// Force the output component size to 1.
components = 1;

constexpr uint32_t workgroup_size = 128;
constexpr uint32_t tile_m = workgroup_size / 8;
constexpr uint32_t tile_n = workgroup_size;
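// A 128-thread workgroup gives tile_m = 16 rows and tile_n = 128 columns:
// the Tile16x128 configuration named in the cache hint.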

MatMulNBitsWideTileProgram program{tile_m, tile_n};
program.SetWorkgroupSize(workgroup_size);
program.SetDispatchGroupSize((N + tile_n - 1) / tile_n,
(M + tile_m - 1) / tile_m,
batch_count);
program.CacheHint("Tile" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "_Block32");

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};
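// With components_b == 4, the uint8 blob is grouped 16 bytes per element, so
// each input_b fetch is a vec4<u32>: one full block of 32 4-bit weights,
// which pairs with a single scale.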

program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, onnxruntime::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow<int>(components_b * 4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow<int>(components)})
.AddUniformVariable({block_size});
return context.RunProgram(program);
}

// Generic program
// TODO: Support output_number > 1. Some cases fail when output_number > 1.
constexpr uint32_t output_number = 1;
const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
13 changes: 13 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -35,6 +35,19 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_subgroup_;
};

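// Wide-tile MatMulNBits program (tile_m x tile_n per workgroup), specialized
// for block size 32 without zero points.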
class MatMulNBitsWideTileProgram final : public Program<MatMulNBitsWideTileProgram> {
public:
MatMulNBitsWideTileProgram(uint32_t tile_m, uint32_t tile_n)
: Program{"MatMulNBitsWideTileProgram"}, tile_m_(tile_m), tile_n_(tile_n) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_m_;
uint32_t tile_n_;
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {