Add Triton kernel benchmark suite #493
Open
jlamypoirier wants to merge 30 commits into main from
Conversation
Adds `tools/benchmark/` — a micro-benchmark harness for Fast-LLM's Triton kernels that measures throughput (GB/s, % peak BW, TFLOP/s) and checks numerical correctness against an fp32 reference for each kernel variant.

Kernels covered:
- entropy loss: cross_entropy (labels), cross_entropy (logits/distillation), reverse_kl (logits), and z_loss
- normalization: LayerNorm and RMSNorm (fwd+bwd)
- MLP activation: gated SiLU fused kernel (fwd+bwd)
- rotary embeddings: in-place Triton kernel vs PyTorch eager/compiled
- pointwise: cast-add-cast fused kernel

Each benchmark compares fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, apex variants (where available), and fast_llm_triton. The runner auto-detects GPU peak bandwidth from device properties, reports % of peak BW per variant, and flags numerical deviations from the reference.

Also adds a `DataType.short` property (bf16/fp32/…) used by benchmark case names, and `tools/__init__.py` to make `tools` a package for `python -m tools.benchmark`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
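The reported metrics reduce to simple ratios; a minimal sketch of the formulas (function names are illustrative, and the real runner derives peak bandwidth from device properties rather than taking it as an argument):

```python
# Hypothetical helpers showing the throughput and %-of-peak arithmetic the
# runner reports; not the runner's actual code.
def throughput_gb_s(bytes_moved: int, ms: float) -> float:
    # bytes moved by the kernel divided by elapsed time, in GB/s
    return bytes_moved / (ms * 1e-3) / 1e9


def percent_of_peak(bytes_moved: int, ms: float, peak_gb_s: float) -> float:
    # achieved bandwidth relative to the device's theoretical peak
    return 100.0 * throughput_gb_s(bytes_moved, ms) / peak_gb_s
```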
Adds `tools/benchmark/bench_sparse_linear.py` covering the two sparse GEMM kernels in MoE FFN layers (output_sparse / up-proj and input_inner_sparse / down-proj), comparing fast_llm_triton against a PyTorch loop reference and torch.compile.

Fixes three kernel correctness bugs surfaced by the new benchmark:

1. Phantom blocks left output uninitialized. Both output_sparse_matmul_kernel and input_inner_sparse_matmul_kernel short-circuited (`return`) on blocks past the last expert, leaving those rows of the caller-allocated output buffer with whatever garbage happened to be at that GPU address. Production silently discarded the garbage at the scatter-back boundary, but it is a latent footgun for any caller that reads the full output tensor (and what made the benchmark comparison nondeterministic). The skipped blocks now write zeros instead of returning, so the output is fully defined regardless of how the caller allocated it (see the sketch below).
2. Inner-loop tl.dot accumulated in bfloat16. The first tl.dot in each of the three kernels already specified `out_dtype=tl.float32`, but the loop bodies did not, so accumulation past the first tile silently fell back to bf16 Tensor Core accumulation. Added `out_dtype=tl.float32` to every tl.dot call.
3. Backward used wrong argument order for input_row_sparse_matmul. The grad_rhs path in OutputSparseLinear and InputSparseLinear had lhs/grad_out swapped relative to what the kernel expects.

After these fixes every rel_rms in the sparse_linear benchmark goes to 0 across all six (op, shape) configurations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
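A minimal sketch of fix 1 on a toy kernel (not the production sparse GEMM; names and shapes are illustrative): blocks past the last valid block write zeros instead of returning, so a `torch.empty`-allocated output is fully defined.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _phantom_block_kernel(x_ptr, out_ptr, last_valid_block, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    if pid >= last_valid_block:
        # Previously a plain `return` here left whatever garbage was at this
        # address in the caller-allocated output buffer.
        tl.store(out_ptr + offsets, tl.zeros((BLOCK,), dtype=tl.float32))
        return
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) * 2.0)


x = torch.randn(8 * 128, device="cuda")
out = torch.empty_like(x)  # uninitialized, like the production buffer
_phantom_block_kernel[(8,)](x, out, 5, BLOCK=128)  # blocks 5..7 are "phantom"
```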
- bench_grpo_loss.py: fused triton_grpo_loss_forward_backward kernel vs PyTorch eager / compiled, swept over vocab=32K/64K/128K (matches bench_entropy_loss shapes since GRPO is structurally similar).
- bench_sparse_copy.py: MoE token dispatch (dense->sparse) and combine (sparse->dense) via copy_dense_to_sparse_autograd / copy_sparse_to_dense_autograd, against PyTorch index-scatter/gather references. Shapes match Mixtral-8x7B and fine-grained MoE.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rison

Dispatch output and combine grad_sparse both have phantom rows (padding within expert ranges and the static tail beyond expert_ends[-1]) that copy_dense_to_sparse never writes. The PyTorch reference uses new_zeros while production code uses new_empty, so comparing the full tensors produced inf rel_rms. _zero_phantom_rows zeros those ranges in all variants before the runner compares them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add output_postprocess to Variant: a callable applied only during the accuracy check (not the timing loop), so phantom-row masking doesn't inflate measured latency. Precompute a boolean phantom_mask per case; _dispatch_postprocess and _combine_postprocess use masked_fill_ (one GPU op) to zero phantom rows before RMS comparison. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
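A hedged sketch of the postprocess hook (the mask construction and shapes are illustrative; the real benchmark builds phantom_mask from the sparse map): the callable only touches the tensor being accuracy-checked, so the masking never enters the timing loop.

```python
import torch

# Illustrative phantom mask: True on rows that the kernel never writes.
phantom_mask = torch.zeros(16, dtype=torch.bool)
phantom_mask[12:] = True


def _dispatch_postprocess(out: torch.Tensor) -> torch.Tensor:
    # One GPU op: zero phantom rows in place before the RMS comparison.
    return out.masked_fill_(phantom_mask.to(out.device)[:, None], 0.0)
```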
torch.ones_like / torch.zeros_like inside timed fwd_bwd functions allocate a new GPU tensor on every timing rep, polluting measurements.

- bench_sparse_copy: add backward_grad to the dispatch/combine make_inputs (shapes sparse_map.num_rows×hidden and tokens×hidden respectively) and use inp["backward_grad"] in all four fwd_bwd functions.
- bench_sparse_linear: same fix — precompute backward_grad in both make_inputs functions and remove the _zero_padded_rows(...ones_like...) call from all four fwd_bwd functions. The zeroing was never needed: pytorch_loop uses new_zeros so phantom rows have no autograd edge, and the Triton backward already bounds itself to expert_pad_begins.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The Triton sparse-linear backward reads phantom row positions in the upstream gradient tensor. Passing all-ones for phantom rows gives a different grad_lhs than the PyTorch reference (which has no autograd edge for phantom rows), causing rel_rms ~0.28. Fix: apply _zero_padded_rows to backward_grad during make_inputs (outside the timing loop), matching the lhs_data preprocessing pattern. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
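A minimal sketch of the pattern (make_inputs and fwd_bwd are the benchmark's names, but the bodies here are stand-ins): the upstream gradient is built, and its phantom rows zeroed, once in make_inputs rather than allocated inside every timed rep.

```python
import torch


def make_inputs(tokens: int, hidden: int, device: str = "cuda") -> dict:
    backward_grad = torch.ones(tokens, hidden, device=device)
    backward_grad[tokens // 2:] = 0.0  # stand-in for _zero_padded_rows
    return {
        "lhs": torch.randn(tokens, hidden, device=device, requires_grad=True),
        "backward_grad": backward_grad,
    }


def fwd_bwd(inputs: dict) -> None:
    out = inputs["lhs"] * 2.0                 # stand-in for the sparse matmul
    out.backward(inputs["backward_grad"])     # no per-rep torch.ones_like allocation
```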
…block_size
Load the re and im halves together as a single (head_block_size, head_size) contiguous block per head rather than two separate strided half-head loads. Use a sign-flip formula (out = x*cos + sign*x_partner*sin) to avoid splitting the loaded tensor. The partner load (swapped halves via tl.where) hits L2 after the primary load. Add @triton.autotune over head_block_size ∈ {1,2,4,8,16} × num_warps ∈ {4,8} to let the tuner find the optimal block size per shape.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
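An eager-mode illustration of the sign-flip form (not the Triton kernel; it assumes the halved/rotate-half rotary convention and that cos/sin broadcast against x with each frequency repeated across both halves):

```python
import torch


def rotary_sign_flip(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # partner swaps the two halves; sign is -1 on the first half, +1 on the
    # second, so x*cos + sign*partner*sin equals the usual rotate-half formula.
    half = x.shape[-1] // 2
    partner = torch.cat([x[..., half:], x[..., :half]], dim=-1)
    sign = torch.cat([-x.new_ones(half), x.new_ones(half)])
    return x * cos + sign * partner * sin
```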
@triton.autotune is incompatible with in-place kernels: the autotuner runs all configs sequentially on the same tensor, so each trial rotates the tensor again, producing garbage results in the benchmark. Replace with a fixed head_block_size computed from POINTWISE_BLOCK_SIZE and simplify the kernel to load re and im halves separately, holding both in registers to compute out_re and out_im without a partner load. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
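A generic illustration of the conflict, under the stated assumption of a trivial in-place kernel (this is not the rotary kernel): the autotuner benchmarks every config on the same arguments, so the first logical call already applies the in-place op many times.

```python
import torch
import triton
import triton.language as tl

_configs = [
    triton.Config({"BLOCK": block}, num_warps=warps)
    for block in (128, 256)
    for warps in (4, 8)
]


@triton.autotune(configs=_configs, key=["numel"])
@triton.jit
def _inplace_scale_kernel(x_ptr, numel, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < numel
    tl.store(x_ptr + offsets, tl.load(x_ptr + offsets, mask=mask) * 2.0, mask=mask)


x = torch.ones(4096, device="cuda")
grid = lambda meta: (triton.cdiv(4096, meta["BLOCK"]),)
_inplace_scale_kernel[grid](x, 4096)
print(x[0])  # far larger than 2.0: every autotune trial scaled x again
```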
…lf load" This reverts commit 178903f.
…ne head_block_size" This reverts commit c21ef12.
…+ autotune head_block_size"" This reverts commit 0be67b1.
…o two-half load"" This reverts commit fa22f72.
The autotune is incompatible with in-place kernels during benchmarking because the timed variant calls the kernel on the same tensor multiple times — each call rotates it again, making the correctness check see a corrupted result. Fix by warming up autotune in make_inputs on a throwaway tensor, so the cached winning config is used for all timed runs. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Restore the original no-autotune kernel (benchmarking showed autotune neither helps nor can be benchmarked correctly for in-place kernels). Add tools/inspect_rotary_compile.py to dump the Triton code that torch.compile generates so we can compare it to the hand-written kernel. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In-place variants (fast_llm_triton) were calling .clone() inside the timed callable, paying ~1 full HBM read+write per rep — roughly doubling measured time relative to the actual kernel cost. Fix: pre-allocate a "work" buffer in _make_rotary_inputs and restore it between reps via reset_inputs (runner calls this outside the timed region). No in-place variant now has allocation cost in the hot path. Runner change: Variant gains reset_inputs field; bench_fn gains reset parameter that is called before flush() and before each rep's start event. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
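A simplified sketch of the reset contract (bench_fn and reset are the runner's names, but this loop omits the flush and calibration logic): reset runs before each rep's start event, so restoring the in-place work buffer is never part of the measured time.

```python
import torch


def bench_fn(fn, reset=None, num_reps: int = 10) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(num_reps):
        if reset is not None:
            reset()              # e.g. work.copy_(original); outside the timed region
        start.record()
        fn()
        end.record()
        end.synchronize()
        total_ms += start.elapsed_time(end)
    return total_ms / num_reps
```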
… variants

In entropy_loss and grpo_loss, PyTorch fwd_bwd variants call .backward(), which accumulates into inp["logits"].grad. After rep 1, .backward() on rep 2+ adds into the existing grad tensor (1 extra read+write of the full logits tensor per rep). Triton variants compute grad_logits fresh each rep without touching inp["logits"].grad — no accumulation, no extra read. For 4096×131K logits this is ~2 GB of extra HBM traffic per rep, biasing PyTorch ~33% slower than reality in fwd_bwd timing.

Fix: add reset_inputs=_reset_logits_grad to all PyTorch fwd_bwd variants in entropy_loss (all 4 groups: ce_labels, ce_dist, reverse_kl, z_loss) and grpo_loss. fp32_reference is unaffected (it detaches into a local tensor; inp["logits"].grad is never set). Other benchmarks (normalization, mlp_activation, sparse_copy/linear) are symmetric: all variants use .backward(), so gradient accumulation affects them equally.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
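_reset_logits_grad is named above; a minimal implementation (illustrative body) just drops the accumulated gradient so every rep's .backward() starts from a clean slate:

```python
def _reset_logits_grad(inputs: dict) -> None:
    # Clearing .grad avoids the extra read+write of the grad tensor on rep 2+.
    inputs["logits"].grad = None
```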
All other benchmarks (rotary, normalization, mlp_activation, entropy_loss, grpo_loss) use an fp32 reference as ground truth. sparse_copy and sparse_linear were using pytorch_eager / pytorch_loop in compute dtype (bf16) as is_reference=True — meaning Triton matching bf16 exactly gives zero RMS error even if both implementations share the same numeric drift.

Add fp32_reference variants for all four cases:
- bench_sparse_copy.py: dispatch and combine
- bench_sparse_linear.py: output_sparse (layer 1) and input_inner_sparse (layer 2)

Demote the existing pytorch_eager / pytorch_loop variants to regular comparison variants (no is_reference). Pattern matches the other benches: fp32 eager loop in float32, detaching inputs and recasting grad_output.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
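A sketch of the shared fp32-reference pattern (the op and dict keys are stand-ins): detach the compute-dtype inputs, redo the op in float32, and recast the upstream gradient, so the reference carries no bf16 rounding of its own.

```python
import torch


def fp32_reference(inputs: dict) -> torch.Tensor:
    x = inputs["x"].detach().float().requires_grad_()
    out = torch.nn.functional.silu(x)               # stand-in for the real op
    out.backward(inputs["backward_grad"].float())   # recast grad_output to fp32
    return out
```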
Issue 3 — sparse_linear warmup called once per make_inputs invocation: make_inputs is called many times per case (per variant × per fwd/fwd_bwd/memory pass). The Triton autotuning warmup only needs to fire once per shape. Add module-level sets _output_sparse_warmed_up and _input_inner_sparse_warmed_up; skip the warmup on subsequent calls with the same (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) key.

Issue 4 — private import of _fused/_fast_normalization_available: bench_normalization.py imported names with a leading underscore from fast_llm.layers.common.normalization.normalization. Drop the underscores from the source (they are Apex availability flags, already used widely inside the same module) and update the benchmark import to match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
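A sketch of the once-per-shape warmup guard for issue 3 (the set name matches the commit; the helper is illustrative):

```python
_output_sparse_warmed_up: set[tuple] = set()


def _warm_up_once(key: tuple, warmup_fn) -> None:
    # key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype)
    if key in _output_sparse_warmed_up:
        return
    warmup_fn()  # fires the Triton autotuning warmup only for unseen shapes
    _output_sparse_warmed_up.add(key)
```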
runner.py: rename l -> unit_label / unit_scale in _unit_column generator expression to avoid the E741 'ambiguous variable name' lint warning. bench_grpo_loss.py: add a comment explaining that the labels>=0 masks and clamp(min=0) mirror the production implementation's ignore_index=-100 handling; in this benchmark labels are always non-negative so the guards are unreachable. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- inp → inputs, elem → element_size, _bytes_per_elem → _bytes_per_element in all bench_*.py files
- n_reps → num_reps, n_warmup → num_warmup, sep → separator, l/s → unit_label/unit_scale in runner.py
- Replace lambda default-capture patterns (t=tokens, v=vocab, d=dtype…) with functools.partial in all eight bench_*.py files (see the example below)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
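The two patterns being swapped, on placeholder names (make_case and shapes are illustrative, not the benchmark's real identifiers): the lambda freezes loop variables via default arguments to dodge late binding, while functools.partial binds the current values directly.

```python
import functools


def make_case(tokens: int, vocab: int, dtype: str) -> str:
    return f"{tokens}x{vocab}_{dtype}"


shapes = [(4096, 32768, "bf16"), (4096, 131072, "bf16")]

# Before: default-argument capture works around lambda late binding.
cases_lambda = [lambda t=tokens, v=vocab, d=dtype: make_case(t, v, d)
                for tokens, vocab, dtype in shapes]
# After: functools.partial binds the current values and reads more clearly.
cases_partial = [functools.partial(make_case, tokens, vocab, dtype)
                 for tokens, vocab, dtype in shapes]

assert [c() for c in cases_lambda] == [c() for c in cases_partial]
```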
Replace os.path / os.makedirs / os.walk with pathlib.Path per project convention. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
test_sparse_matmul.py:
- Extend _SPARSE_TEST_DATAS with single-expert and fully-packed (no-padding) cases; all four existing kernel tests now run against these too
- Add _output_sparse_linear_ref / _input_sparse_linear_ref Python-loop references (PyTorch-differentiable)
- Add test_output_sparse_linear_autograd / test_input_sparse_linear_autograd comparing fwd output and both gradients against the loop reference — would have caught the swapped lhs/grad_out bug in the autograd backward

tests/tools/test_benchmark_smoke.py:
- test_run_benchmark_wiring: exercises Case/Variant/run_benchmark end-to-end with a trivial relu kernel; always runs without a GPU
- test_bench_pointwise_smoke: monkeypatches _SIZES_NUMEL=[1024] and calls the real bench_pointwise pipeline, asserting fp32_reference and pytorch_eager variants succeed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The runner is a dev tool; its correctness is covered by the sparse matmul autograd tests that exercise the same kernel path in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- bench_mlp_activation.py: FLOPs/element_size → FLOPs/element in a comment (the rename script incorrectly replaced elem in a comment where it meant element, not element_size)
- runner.py: Variant.reset_inputs return type None → Any (copy_() returns the destination tensor; the runner discards it, so functionally fine)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Warmup calls and the post-warmup one_rep_ms calibration were not calling reset(), so variants with reset_inputs (in-place rotary, PyTorch fwd_bwd entropy/grpo) ran their warmup reps in a dirty input state. The timed reps were already correct. This biased the num_reps estimate slightly low. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
All eight benchmark modules now accept a `shapes` parameter on their `run()` entry point (and the internal `_*_cases()` helpers), so callers can supply a custom list of input sizes without monkey-patching module globals. `tests/tools/test_triton_benchmark.py` uses this to run every module with one tiny shape and float32 dtype, keeping the full runner code path exercised without the long compile/autotune time of production shapes. The `_disable_dynamo` fixture suppresses torch.compile cold-start (~20 s per variant on CPU). The sparse_linear shape uses hidden=ffn_per_expert=256 to satisfy the Triton kernel's block-size divisibility assertions. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add benchmarks() to all 8 bench modules (returning a list of (name, cases, variants) tuples) so each sub-benchmark is addressable independently. Thread min_reps through runner.py's bench_fn and run_benchmark so callers can cap rep count. Rewrite tests/tools/test_triton_benchmark.py to build one pytest parameter per kernel from benchmarks() calls (16 tests instead of 8 file-level tests) and run each with warmup_ms=0, rep_ms=0, min_reps=1, reducing the test suite from ~80s to ~45s while covering every kernel. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
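A generic sketch of turning the benchmarks() output into one pytest parameter per kernel (the tuple shape matches the commit; the helper name and id scheme are illustrative, not the test file's exact code):

```python
import pytest


def collect_benchmark_params(modules):
    # One pytest parameter per (name, cases, variants) tuple per module.
    return [
        pytest.param(cases, variants, id=name)
        for module in modules
        for name, cases, variants in module.benchmarks()
    ]
```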
pytorch_compiled and pytorch_compiled_max aren't needed for correctness checking and slow down the suite without covering new code paths. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… patch

tl.histogram is broken in the Triton interpreter; skip sparse_copy and sparse_linear in interpreter mode rather than patching around multiple cascading bugs. The np.histogram monkeypatch is reverted as it is no longer needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Adds a Triton kernel benchmark suite under `tools/benchmark/` for comparing Fast-LLM Triton kernels against PyTorch eager, `torch.compile`, and fp32 reference implementations.

Benchmark modules
- `bench_entropy_loss`
- `bench_grpo_loss`
- `bench_mlp_activation`
- `bench_normalization`
- `bench_pointwise`
- `bench_rotary`
- `bench_sparse_copy`
- `bench_sparse_linear`

Runner (`tools/benchmark/runner.py`)
- `reset_inputs` callback is called before every warmup and timed rep, not just between cases

Bug fix: sparse matmul backward had swapped `lhs`/`grad_out` arguments
`OutputSparseLinear.backward` and `InputSparseLinear.backward` were calling the gradient kernel with `lhs` and `grad_output` in the wrong order, producing silently wrong gradients. This is fixed and covered by new autograd tests in `tests/functional/test_sparse_matmul.py`.

Parameterisable shapes for testing
All `run()` entry points accept an optional `shapes=` argument (and `sizes=` for pointwise) so that callers can pass a custom list of input sizes — no monkey-patching of module-level constants needed. `tests/tools/test_triton_benchmark.py` uses this to run all 8 modules with one tiny shape and `float32`, exercising the full runner code path quickly. A `_disable_dynamo` fixture suppresses the `torch.compile` JIT cold-start (~20 s per variant on CPU).

Test plan
- `test_triton_benchmark` parametrized cases pass on H100 (8 passed in 80s)
- `test_output_sparse_linear_autograd` / `test_input_sparse_linear_autograd` pass
- `python -m tools.benchmark <module>` runs and produces a formatted table on GPU

🤖 Generated with Claude Code