perf: Optimize quantized matmul for small decode batches#122
Open
Li0k wants to merge 1 commit into
Open
Conversation
2db7941 to
79c9c41
Compare
Author
|
@Connor1996 Could you take a look? Looking forward to your comments. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR optimizes the reference 4-bit quantized matmul Metal launch shape for small-M decode workloads.
The existing kernel always uses a 32-row tile in the M dimension:
During decode, M is often small (
M = 1for single-request decode, or small batch sizes for continuous batching). In those cases, most M-dimension lanes are idle. This PR keeps the total threadgroup size unchanged, but shifts lanes from the M dimension to the K/output-column dimension when M is small:For larger M, this keeps the original 32-row tile.
Benchmark support
I also added a synthetic-token batch decode benchmark path to
bench.py:--batch-decoderuns the Week 2 continuous-batching decode path.BatchingKvCacheand batched offsets, matching the Week 2 batching task shape.--num-seqs >= --batch-sizeso larger batch-size measurements are not accidentally under-filled.The original non-batch benchmark path remains the default.
Benchmark setup
Model and command shape:
Fixed workload:
Results
The improvement is concentrated in small-M decode cases. At batch size 16, the result is essentially flat, which is expected because this PR only changes the tile shape for
M <= 8.Validation
Correctness tests:
PYTHONPATH=src pdm run test-refsol --week 2 --day 2 -- -k quantized_matmul -q PYTHONPATH=src pdm run test-refsol --week 2 --day 4 -- -k flash_attention -q PYTHONPATH=src pdm run test-refsol --week 2 --day 6 -- -k "batching_kv_cache or qwen3_0_6b" -qSmoke checks: