Skip to content

[webgpu] Optimize MatMulNBits for f16 Block32 prefill performance#23908

Merged
guschmue merged 12 commits intomicrosoft:mainfrom
daijh:matmul-f16-block32-prefill
Apr 4, 2025
Merged

[webgpu] Optimize MatMulNBits for f16 Block32 prefill performance#23908
guschmue merged 12 commits intomicrosoft:mainfrom
daijh:matmul-f16-block32-prefill

Conversation

@daijh
Copy link
Contributor

@daijh daijh commented Mar 6, 2025

Description

This commit improve the MatMulNBits f16 Block32 prefill performance, by increasing tiling size and enhancing memory efficiency. Achieved a +2x performance boost on Intel iGPUs for Phi-3.5-mini f16 model.

Motivation and Context

See above.

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

Tests:

model_benchmark.exe -i Phi-3.5-mini-instruct-onnx-web -l 1000
Prompt-1000 Prefill-default (tps) Prefill-opt (tps)
LNL 14.5829 327.627
MTL 61.4833 160.695
ADL 45.1106 101.871

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

@qjia7 @sushraja-msft @jchen10
Please take a look, thanks.

@daijh
Copy link
Contributor Author

daijh commented Mar 6, 2025

Add shader for easy review.

enable f16;
enable subgroups_f16;
enable subgroups;
const workgroup_size_x: u32 = 128;
const workgroup_size_y: u32 = 1;
const workgroup_size_z: u32 = 1;
@group(0) @binding(0) var<storage, read> input_a: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read> input_b: array<vec4<u32>>;
@group(0) @binding(2) var<storage, read> scales: array<f16>;
@group(0) @binding(3) var<storage, read_write> output: array<f16>;
struct Uniforms {
  input_a_shape: vec3<u32>,
  input_a_stride: vec2<u32>,
  input_b_shape: vec3<u32>,
  input_b_stride: vec2<u32>,
  output_shape: vec3<u32>,
  output_stride: vec2<u32>,
  block_size: u32
};
@group(0) @binding(4) var<uniform> uniforms: Uniforms;

alias input_a_value_t = vec4<f16>;
alias input_a_indices_t = vec3<u32>;
fn i2o_input_a(indices : input_a_indices_t)->u32 {
  return indices[0] * uniforms.input_a_stride[0] + indices[1] * uniforms.input_a_stride[1] + indices[2];
}
fn get_input_a_by_indices(indices: input_a_indices_t)->input_a_value_t {
  return input_a[i2o_input_a(indices)];
}
alias input_b_value_t = vec4<u32>;
alias input_b_indices_t = vec3<u32>;
fn i2o_input_b(indices : input_b_indices_t)->u32 {
  return indices[0] * uniforms.input_b_stride[0] + indices[1] * uniforms.input_b_stride[1] + indices[2];
}
fn get_input_b_by_indices(indices: input_b_indices_t)->input_b_value_t {
  return input_b[i2o_input_b(indices)];
}
alias output_value_t = f16;
alias output_indices_t = vec3<u32>;
alias output_element_t = f16;
fn i2o_output(indices : output_indices_t)->u32 {
  return indices[0] * uniforms.output_stride[0] + indices[1] * uniforms.output_stride[1] + indices[2];
}
fn set_output_by_indices(indices: output_indices_t, value: output_value_t) {
  output[i2o_output(indices)]=value;
}

fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {
  if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {
    return get_input_a_by_indices(input_a_indices_t(batch, row, col));
  }
  return input_a_value_t(0);
}

fn mm_read_b(row : u32, col : u32) -> input_b_value_t {
  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {
    return get_input_b_by_indices(input_b_indices_t(row, col, 0));
  }
  return input_b_value_t(0);
}

fn mm_read_scale(row : u32, col : u32) -> output_value_t {
  if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {
    return scales[row * uniforms.input_b_shape[1] + col];
  }
  return output_value_t(0);
}

fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {
  if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {
    set_output_by_indices(output_indices_t(batch, row, col), value);
  }
}

const tile_m = 16u;
const tile_n = 128u;

var<workgroup> a_data_wg: array<array<input_a_value_t, 8u>, tile_m>;

@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>,
        @builtin(local_invocation_index) local_idx : u32,
        @builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(subgroup_invocation_id) sg_id : u32,
        @builtin(subgroup_size) sg_size : u32,
        @builtin(num_workgroups) num_workgroups : vec3<u32>) {
  let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;
  let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;

  let batch = workgroup_id.z;
  let row = workgroup_id.y * tile_m;
  let col = workgroup_id.x * tile_n;

  let a_elements_per_col = uniforms.input_a_shape[2];
  // A block32 containing 8 elements of `a`.
  let a_blocks_per_col = (a_elements_per_col + 7u) / 8u;

  // f32 accumulator
  var results : array<f32, tile_m>;
  for (var a_block_idx = 0u; a_block_idx < a_blocks_per_col; a_block_idx++) {
    // load `a` elements into workgroup memory, TileM x 8(block32).
    let a_row_idx = local_idx / 8u;
    let a_col_idx = local_idx % 8u;
    a_data_wg[a_row_idx][a Effect_col_idx] = mm_read_a(batch, row + a_row_idx, a_block_idx * 8u + a_col_idx);
    workgroupBarrier();

    let b_row = col + local_idx;
    let b_col = a_block_idx;

    let b_data = mm_read_b(b_row, b_col);
    let scale = mm_read_scale(b_row, b_col);
    let zero_point = output_element_t(8.0);

    for (var b_idx = 0u; b_idx < 4u; b_idx++) {
      let b_value = b_data[b_idx];
      let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
      let b_value_upper = unpack4xU8((b_value >> 4u) & 0x0F0F0F0Fu);
      let b_quantized_values = mat2x4<output_element_t>(
          output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]),
          output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]),
          output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]),
          output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));
      let b_dequantized_values =
          (b_quantized_values - mat2x4<output_element_t>(zero_point, zero_point,
                                                        zero_point, zero_point,
                                                        zero_point, zero_point,
                                                        zero_point, zero_point)) * scale;

      for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
        let a_data0 = a_data_wg[m_idx][b_idx * 2u];
        let a_data1 = a_data_wg[m_idx][b_idx * 2u + 1u];

        results[m_idx] += f32(dot(a_data0, b_dequantized_values[0u])) +
                          f32(dot(a_data1, b_dequantized_values[1u]));
      }
    }

    workgroupBarrier();
  }

  // write the results
  for (var m_idx = 0u; m_idx < tile_m; m_idx++) {
    mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx]));
  }

}

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 6, 2025
@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2025

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI

@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@azure-pipelines
Copy link

Azure Pipelines successfully started running 9 pipeline(s).

@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@daijh
Copy link
Contributor Author

daijh commented Mar 7, 2025

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

I'm trying to avoid making too many modifications in a single PR to keep it easier review, and comparable with previous shader.
If accepted, I'll subsequently integrate its improvements into the default shader prefill path (as decode performance is not improved).

What are your thoughts?

@daijh
Copy link
Contributor Author

daijh commented Mar 7, 2025

As to why you are seeing great prefill speed, its because our prefill fp16 shader is not based on co-operative matmul (we havent got around to rewriting that shader that way, if you can pick that up that would be amazing as well). The DP4A matmul shader is using techniques of co-operative matmul, and we are using that for many models by passing accuracy_level 4 with model_builder.py.

Yes, we observed quite good performance at accuracy level 4 using the DP4A shader. I'll investigate similar for f16.

@guschmue
Copy link
Contributor

I can capture some perf numbers as well

@sushraja-msft
Copy link
Contributor

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

I'm trying to avoid making too many modifications in a single PR to keep it easier review, and comparable with previous shader. If accepted, I'll subsequently integrate its improvements into the default shader prefill path (as decode performance is not improved).

What are your thoughts?

that's acceptable, perhaps name this MatMulNBitsBlock32Program > MatMulNBitsBlockWideTileProgram and land this PR and then work towards making this the default prefill program on all platforms. Ill review the shader

@daijh
Copy link
Contributor Author

daijh commented Mar 12, 2025

From what I can tell yours is a generation mode shader, if you are seeing good perf with this tile size -we should just replace the current generation shader with yours. Even better if we can make these shaders have the tile sizes as tunable.

I'm trying to avoid making too many modifications in a single PR to keep it easier review, and comparable with previous shader. If accepted, I'll subsequently integrate its improvements into the default shader prefill path (as decode performance is not improved).
What are your thoughts?

that's acceptable, perhaps name this MatMulNBitsBlock32Program > MatMulNBitsBlockWideTileProgram and land this PR and then work towards making this the default prefill program on all platforms. Ill review the shader

Sure. Thanks.

@daijh
Copy link
Contributor Author

daijh commented Mar 12, 2025

I can capture some perf numbers as well

@guschmue thanks, please let me know if any issues.

@daijh daijh force-pushed the matmul-f16-block32-prefill branch from 8a250db to 74da290 Compare March 12, 2025 07:54
@guschmue
Copy link
Contributor

/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline

@guschmue
Copy link
Contributor

/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline, Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline

@guschmue
Copy link
Contributor

/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models

@azure-pipelines
Copy link

Azure Pipelines successfully started running 2 pipeline(s).

@guschmue
Copy link
Contributor

/azp run Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI

@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

1 similar comment
@azure-pipelines
Copy link

Azure Pipelines successfully started running 4 pipeline(s).

@azure-pipelines
Copy link

Azure Pipelines successfully started running 9 pipeline(s).

@daijh
Copy link
Contributor Author

daijh commented Mar 26, 2025

Resolved existing comment. Please take another look, thanks.

@sushraja-msft @guschmue @qjia7

@daijh
Copy link
Contributor Author

daijh commented Mar 28, 2025

@guschmue could you have a look as well, and apply this PR?

@guschmue
Copy link
Contributor

guschmue commented Apr 1, 2025

CI pipelines changes - can you merge with main?

@daijh daijh force-pushed the matmul-f16-block32-prefill branch from 4d3801f to ca1710a Compare April 1, 2025 04:04
@daijh
Copy link
Contributor Author

daijh commented Apr 1, 2025

CI pipelines changes - can you merge with main?

Rebase to main. Please help to re-trigger the CI, thanks.

@guschmue
Copy link
Contributor

guschmue commented Apr 1, 2025

lint issue, wants you to run
lintrunner -a
If your local lintrunner doesn't complain it is maybe an older version (I ran into this earlier today)

@daijh
Copy link
Contributor Author

daijh commented Apr 1, 2025

lint issue, wants you to run lintrunner -a If your local lintrunner doesn't complain it is maybe an older version (I ran into this earlier today)

Fixed the lint issues.

@daijh
Copy link
Contributor Author

daijh commented Apr 2, 2025

The logs of CI failure shows it's likely a result of infrastructure instability, and does not to be related to the changes.

Copy link
Contributor

@guschmue guschmue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more change

@guschmue
Copy link
Contributor

guschmue commented Apr 3, 2025

/azp run Windows x64 QNN CI Pipeline,Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines
Copy link

Azure Pipelines successfully started running 5 pipeline(s).

@guschmue guschmue merged commit 3dfc2ae into microsoft:main Apr 4, 2025
69 checks passed
@daijh daijh deleted the matmul-f16-block32-prefill branch April 4, 2025 03:51
zhaoxul-qti pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Apr 17, 2025
…crosoft#23908)

### Description
This commit improve the MatMulNBits f16 Block32 prefill performance, by
increasing tiling size and enhancing memory efficiency. Achieved a +2x
performance boost on Intel iGPUs for Phi-3.5-mini f16 model.

### Motivation and Context
See above.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants