
Add Triton kernel benchmark suite #493

Open
jlamypoirier wants to merge 30 commits into main from worktree-triton_benchmark

Conversation


jlamypoirier (Collaborator) commented Apr 25, 2026

Summary

Adds a Triton kernel benchmark suite under tools/benchmark/ for comparing Fast-LLM Triton kernels against PyTorch eager, torch.compile, and fp32 reference implementations.

Benchmark modules

| Module | Kernels |
| --- | --- |
| bench_entropy_loss | cross-entropy (labels + logits), reverse-KL, z-loss |
| bench_grpo_loss | GRPO loss (fused softmax + ratio + clipped pg) |
| bench_mlp_activation | gated-SiLU fused MLP activation |
| bench_normalization | LayerNorm, RMSNorm (+ Apex variants) |
| bench_pointwise | copy, fill, add |
| bench_rotary | rotary embeddings |
| bench_sparse_copy | MoE token dispatch / combine |
| bench_sparse_linear | MoE output-sparse and input-inner-sparse matmuls |

Runner (tools/benchmark/runner.py)

  • Measures fwd-only and fwd+bwd separately; reports throughput as % of HBM bandwidth and peak FLOP/s
  • Verifies numerical closeness of each variant against the fp32 reference
  • Tracks peak GPU memory per case
  • CPU fallback (1 warmup + 1 timed rep) so tests pass without a GPU
  • reset_inputs callback is called before every warmup and timed rep, not just between cases (see the timing sketch below)
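
A minimal sketch of the per-variant timing pattern described above. The function and parameter names are assumptions for illustration; the actual `tools/benchmark/runner.py` API may differ.

```python
import torch

def _time_variant(fn, reset=None, num_warmup=3, num_reps=10):
    """Hypothetical timing loop: reset runs outside the timed region, before every rep."""
    for _ in range(num_warmup):
        if reset is not None:
            reset()                      # clean inputs before every warmup rep
        fn()
    times_ms = []
    for _ in range(num_reps):
        if reset is not None:
            reset()                      # and before every timed rep
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))
    return min(times_ms)
```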

Bug fix: sparse matmul backward had swapped lhs / grad_out arguments

OutputSparseLinear.backward and InputSparseLinear.backward were calling the gradient kernel with lhs and grad_output in the wrong order, producing silently wrong gradients. This is fixed and covered by new autograd tests in tests/functional/test_sparse_matmul.py.

Parameterisable shapes for testing

All run() entry points accept an optional shapes= argument (and sizes= for pointwise) so that callers can pass a custom list of input sizes — no monkey-patching of module-level constants needed.

tests/tools/test_triton_benchmark.py uses this to run all 8 modules with one tiny shape and float32, exercising the full runner code path quickly. A _disable_dynamo fixture suppresses the torch.compile JIT cold-start (~20 s per variant on CPU).
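
A sketch of what such a fixture could look like, assuming it toggles `torch._dynamo.config.disable`; the real fixture in tests/tools/test_triton_benchmark.py may be implemented differently.

```python
import pytest
import torch._dynamo

@pytest.fixture(autouse=True)
def _disable_dynamo():
    # Make torch.compile a no-op for the duration of the test so compiled
    # variants do not pay the JIT cold start on CPU.
    previous = torch._dynamo.config.disable
    torch._dynamo.config.disable = True
    yield
    torch._dynamo.config.disable = previous
```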

Test plan

  • All 8 test_triton_benchmark parametrized cases pass on H100 (8 passed in 80s)
  • New test_output_sparse_linear_autograd / test_input_sparse_linear_autograd pass
  • python -m tools.benchmark <module> runs and produces a formatted table on GPU

🤖 Generated with Claude Code

jlamypoirier and others added 27 commits April 29, 2026 23:30
Adds `tools/benchmark/` — a micro-benchmark harness for Fast-LLM's Triton
kernels that measures throughput (GB/s, % peak BW, TFLOP/s) and checks
numerical correctness against a fp32 reference for each kernel variant.

Kernels covered:
- entropy loss: cross_entropy (labels), cross_entropy (logits/distillation),
  reverse_kl (logits), and z_loss
- normalization: LayerNorm and RMSNorm (fwd+bwd)
- MLP activation: gated SiLU fused kernel (fwd+bwd)
- rotary embeddings: in-place Triton kernel vs PyTorch eager/compiled
- pointwise: cast-add-cast fused kernel

Each benchmark compares fp32_reference, pytorch_eager, pytorch_compiled,
pytorch_compiled_max, apex variants (where available), and fast_llm_triton.
The runner auto-detects GPU peak bandwidth from device properties, reports
% of peak BW per variant, and flags numerical deviations from the reference.

Also adds `DataType.short` property (bf16/fp32/…) used by benchmark case names,
and `tools/__init__.py` to make `tools` a package for `python -m tools.benchmark`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
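
The "% of peak BW" column in the commit above reduces to simple arithmetic; a sketch, where `peak_bw_gb_s` is assumed to come from the runner's auto-detected device properties and the bytes-moved estimate from the benchmark case.

```python
def bandwidth_utilization(bytes_moved: int, elapsed_ms: float, peak_bw_gb_s: float) -> float:
    """Fraction of peak HBM bandwidth achieved, as a percentage."""
    achieved_gb_s = bytes_moved / (elapsed_ms * 1e-3) / 1e9
    return 100.0 * achieved_gb_s / peak_bw_gb_s

# Example: 2 GB of traffic in 1.5 ms on a roughly 3.35 TB/s H100 is about 40% of peak.
print(bandwidth_utilization(2e9, 1.5, 3350))
```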
Adds tools/benchmark/bench_sparse_linear.py covering the two sparse
GEMM kernels in MoE FFN layers (output_sparse / up-proj and
input_inner_sparse / down-proj), comparing fast_llm_triton against a
PyTorch loop reference and torch.compile.

Fixes three kernel correctness bugs surfaced by the new benchmark:

1. Phantom blocks left output uninitialized. Both
   output_sparse_matmul_kernel and input_inner_sparse_matmul_kernel
   short-circuited (`return`) on blocks past the last expert, leaving
   those rows of the caller-allocated output buffer with whatever
   garbage happened to be at that GPU address. Production silently
   discarded the garbage at the scatter-back boundary, but it is a
   latent footgun for any caller that reads the full output tensor
   (and what made the benchmark comparison nondeterministic). The
   skipped blocks now write zeros instead of returning, so the output
   is fully defined regardless of how the caller allocated it.

2. Inner-loop tl.dot accumulated in bfloat16. The first tl.dot in
   each of the three kernels already specified `out_dtype=tl.float32`,
   but the loop bodies did not, so accumulation past the first tile
   silently fell back to bf16 Tensor Core accumulation. Added
   `out_dtype=tl.float32` to every tl.dot call.

3. Backward used wrong argument order for input_row_sparse_matmul.
   The grad_rhs path in OutputSparseLinear and InputSparseLinear had
   lhs/grad_out swapped relative to what the kernel expects.

After these fixes every rel_rms in the sparse_linear benchmark goes
to 0 across all six (op, shape) configurations.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
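
Item 2 in the commit above is the classic mixed-precision matmul pitfall. An illustrative Triton kernel (not the Fast-LLM kernel; it assumes block-divisible shapes and omits masking) showing fp32 accumulation across K tiles:

```python
import triton
import triton.language as tl

@triton.jit
def tiled_matmul_kernel(a_ptr, b_ptr, c_ptr, N, K,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    inner = tl.arange(0, BLOCK_K)
    # Keep the running sum in fp32 even when a and b are bf16 tiles.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + rows[:, None] * K + (k + inner)[None, :])
        b = tl.load(b_ptr + (k + inner)[:, None] * N + cols[None, :])
        # Without out_dtype=tl.float32 the dot's own accumulation can drop to bf16.
        acc += tl.dot(a, b, out_dtype=tl.float32)
    tl.store(c_ptr + rows[:, None] * N + cols[None, :], acc)
```

Item 1's zero-fill fix is in the same spirit: a block that would otherwise return early first stores zeros, so the caller-allocated output is always fully defined.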
- bench_grpo_loss.py: fused triton_grpo_loss_forward_backward kernel
  vs PyTorch eager / compiled, swept over vocab=32K/64K/128K (matches
  bench_entropy_loss shapes since GRPO is structurally similar).
- bench_sparse_copy.py: MoE token dispatch (dense->sparse) and combine
  (sparse->dense) via copy_dense_to_sparse_autograd /
  copy_sparse_to_dense_autograd, against PyTorch index-scatter/gather
  references. Shapes match Mixtral-8x7B and fine-grained MoE.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rison

Dispatch output and combine grad_sparse both have phantom rows (padding
within expert ranges and the static tail beyond expert_ends[-1]) that
copy_dense_to_sparse never writes. The PyTorch reference uses new_zeros
while production code uses new_empty, so comparing the full tensors
produced inf rel_rms. _zero_phantom_rows zeros those ranges in all
variants before the runner compares them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add output_postprocess to Variant: a callable applied only during the
accuracy check (not the timing loop), so phantom-row masking doesn't
inflate measured latency. Precompute a boolean phantom_mask per case;
_dispatch_postprocess and _combine_postprocess use masked_fill_ (one
GPU op) to zero phantom rows before RMS comparison.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
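
A minimal sketch of such a postprocess, assuming a boolean per-row `phantom_mask`; the real helpers in bench_sparse_copy.py may differ.

```python
import torch

def _dispatch_postprocess(output: torch.Tensor, phantom_mask: torch.Tensor) -> torch.Tensor:
    # Single GPU op, applied only during the accuracy check: zero phantom rows
    # in place so every variant is compared on the defined rows only.
    return output.masked_fill_(phantom_mask[:, None], 0.0)
```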
torch.ones_like / torch.zeros_like inside timed fwd_bwd functions
allocate a new GPU tensor on every timing rep, polluting measurements.

bench_sparse_copy: add backward_grad to dispatch/combine make_inputs
(shapes sparse_map.num_rows×hidden and tokens×hidden respectively) and
use inp["backward_grad"] in all four fwd_bwd functions.

bench_sparse_linear: same fix — precompute backward_grad in both
make_inputs functions and remove the _zero_padded_rows(...ones_like...)
call from all four fwd_bwd functions. The zeroing was never needed:
pytorch_loop uses new_zeros so phantom rows have no autograd edge, and
the Triton backward already bounds itself to expert_pad_begins.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The Triton sparse-linear backward reads phantom row positions in the
upstream gradient tensor. Passing all-ones for phantom rows gives a
different grad_lhs than the PyTorch reference (which has no autograd
edge for phantom rows), causing rel_rms ~0.28.

Fix: apply _zero_padded_rows to backward_grad during make_inputs (outside
the timing loop), matching the lhs_data preprocessing pattern.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…block_size

Load re and im halves together as a single (head_block_size, head_size) contiguous
block per head rather than two separate strided half-head loads.  Use a sign-flip
formula (out = x*cos + sign*x_partner*sin) to avoid splitting the loaded tensor.
The partner load (swapped halves via tl.where) hits L2 after the primary load.

Add @triton.autotune over head_block_size ∈ {1,2,4,8,16} × num_warps ∈ {4,8}
to let the tuner find the optimal block size per shape.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@triton.autotune is incompatible with in-place kernels: the autotuner
runs all configs sequentially on the same tensor, so each trial rotates
the tensor again, producing garbage results in the benchmark.

Replace with a fixed head_block_size computed from POINTWISE_BLOCK_SIZE
and simplify the kernel to load re and im halves separately, holding
both in registers to compute out_re and out_im without a partner load.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…+ autotune head_block_size""

This reverts commit 0be67b1.
The autotune is incompatible with in-place kernels during benchmarking
because the timed variant calls the kernel on the same tensor multiple
times — each call rotates it again, making the correctness check see a
corrupted result.  Fix by warming up autotune in make_inputs on a
throwaway tensor, so the cached winning config is used for all timed runs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
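
A sketch of the warm-up-on-a-throwaway-tensor idea from the commit above; the kernel entry point is passed in as a hypothetical parameter rather than named.

```python
import torch

def _warm_up_autotune(kernel_fn, shape, dtype, device="cuda"):
    # kernel_fn is the in-place Triton entry point (hypothetical parameter).
    # Autotuner trials may rotate this scratch tensor several times; that is
    # fine, because only the cached winning config is reused by timed runs,
    # which operate on fresh inputs.
    scratch = torch.randn(shape, dtype=dtype, device=device)
    kernel_fn(scratch)
    del scratch
```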
Restore the original no-autotune kernel (benchmarking showed autotune
neither helps nor can be benchmarked correctly for in-place kernels).
Add tools/inspect_rotary_compile.py to dump the Triton code that
torch.compile generates so we can compare it to the hand-written kernel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In-place variants (fast_llm_triton) were calling .clone() inside the
timed callable, paying ~1 full HBM read+write per rep — roughly doubling
measured time relative to the actual kernel cost.

Fix: pre-allocate a "work" buffer in _make_rotary_inputs and restore it
between reps via reset_inputs (runner calls this outside the timed region).
No in-place variant now has allocation cost in the hot path.

Runner change: Variant gains reset_inputs field; bench_fn gains reset
parameter that is called before flush() and before each rep's start event.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
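
A minimal sketch of the work-buffer pattern, with hypothetical key names; the real `_make_rotary_inputs` likely carries more state (frequencies, strides, and so on).

```python
import torch

def _make_rotary_inputs(shape, dtype, device="cuda"):
    pristine = torch.randn(shape, dtype=dtype, device=device)
    work = pristine.clone()        # allocated once, outside the timed loop
    return {"pristine": pristine, "work": work}

def _reset_rotary_inputs(inputs):
    # The runner calls this outside the timed region: restore the input the
    # in-place kernel rotated, without allocating anything in the hot path.
    inputs["work"].copy_(inputs["pristine"])
```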
… variants

In entropy_loss and grpo_loss, PyTorch fwd_bwd variants call .backward()
which accumulates into inp["logits"].grad. After rep 1, .backward() on
rep 2+ adds into the existing grad tensor (1 extra read+write of the full
logits tensor per rep). Triton variants compute grad_logits fresh each rep
without touching inp["logits"].grad — no accumulation, no extra read.

For 4096×131K logits this is ~2 GB extra HBM traffic per rep, biasing
PyTorch ~33% slower than reality in fwd_bwd timing.

Fix: add reset_inputs=_reset_logits_grad to all PyTorch fwd_bwd variants
in entropy_loss (all 4 groups: ce_labels, ce_dist, reverse_kl, z_loss)
and grpo_loss. fp32_reference is unaffected (it detaches into a local
tensor; inp["logits"].grad is never set).

Other benchmarks (normalization, mlp_activation, sparse_copy/linear) are
symmetric: all variants use .backward(), so gradient accumulation affects
them equally.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
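
The fix amounts to a one-line reset hook; a sketch, assuming the benchmark's input dict exposes the logits tensor under the key `"logits"`.

```python
def _reset_logits_grad(inputs):
    # Drop the accumulated gradient so each rep's backward() writes a fresh
    # .grad instead of adding into an existing logits-sized tensor, which
    # costs an extra full read+write of HBM per rep.
    inputs["logits"].grad = None
```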
All other benchmarks (rotary, normalization, mlp_activation, entropy_loss,
grpo_loss) use an fp32 reference as ground truth. sparse_copy and
sparse_linear were using pytorch_eager / pytorch_loop in compute dtype
(bf16) as is_reference=True — meaning Triton matching bf16 exactly gives
zero RMS error even if both implementations share the same numeric drift.

Add fp32_reference variants for all four cases:
- bench_sparse_copy.py: dispatch and combine
- bench_sparse_linear.py: output_sparse (layer 1) and input_inner_sparse (layer 2)

Demote the existing pytorch_eager / pytorch_loop variants to regular
comparison variants (no is_reference). Pattern matches the other benches:
fp32 eager loop in float32, detaching inputs and recasting grad_output.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Issue 3 — sparse_linear warmup called once per make_inputs invocation:
make_inputs is called many times per case (per variant × per fwd/fwd_bwd/
memory pass). The Triton autotuning warmup only needs to fire once per
shape. Add module-level sets _output_sparse_warmed_up and
_input_inner_sparse_warmed_up; skip the warmup on subsequent calls with
the same (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) key.

Issue 4 — private import of _fused/_fast_normalization_available:
bench_normalization.py imported names with a leading underscore from
fast_llm.layers.common.normalization.normalization. Drop the underscores
from the source (they are Apex availability flags, already used widely
inside the same module) and update the benchmark import to match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
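
A sketch of the once-per-shape guard from Issue 3; the set name is taken from the commit, while the warmup callable is a hypothetical parameter.

```python
_output_sparse_warmed_up: set = set()

def _maybe_warm_up_output_sparse(warmup_fn, tokens, top_k, num_experts,
                                 hidden, ffn_per_expert, dtype):
    key = (tokens, top_k, num_experts, hidden, ffn_per_expert, str(dtype))
    if key in _output_sparse_warmed_up:
        return
    warmup_fn()                      # triggers Triton autotuning once per shape
    _output_sparse_warmed_up.add(key)
```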
runner.py: rename l -> unit_label / unit_scale in _unit_column generator
expression to avoid the E741 'ambiguous variable name' lint warning.

bench_grpo_loss.py: add a comment explaining that the labels>=0 masks and
clamp(min=0) mirror the production implementation's ignore_index=-100
handling; in this benchmark labels are always non-negative so the guards
are unreachable.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- inp → inputs, elem → element_size, _bytes_per_elem → _bytes_per_element
  in all bench_*.py files
- n_reps → num_reps, n_warmup → num_warmup, sep → separator,
  l/s → unit_label/unit_scale in runner.py
- Replace lambda default-capture patterns (t=tokens, v=vocab, d=dtype…)
  with functools.partial in all eight bench_*.py files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
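
Illustrative before/after for the lambda-vs-partial change above, using a made-up `build_inputs` helper rather than the real benchmark factories.

```python
import functools

def build_inputs(tokens, vocab):
    return {"tokens": tokens, "vocab": vocab}

variants = []
for tokens, vocab in [(4096, 32_000), (4096, 131_072)]:
    # Before: default-argument capture works around Python's late binding of
    # loop variables, but hides which values are bound.
    variants.append(lambda t=tokens, v=vocab: build_inputs(t, v))
    # After: functools.partial makes the binding explicit.
    variants.append(functools.partial(build_inputs, tokens, vocab))
```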
Replace os.path / os.makedirs / os.walk with pathlib.Path per project convention.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
test_sparse_matmul.py:
- Extend _SPARSE_TEST_DATAS with single-expert and fully-packed (no-padding)
  cases; all four existing kernel tests now run against these too
- Add _output_sparse_linear_ref / _input_sparse_linear_ref Python-loop
  references (PyTorch-differentiable)
- Add test_output_sparse_linear_autograd / test_input_sparse_linear_autograd
  comparing fwd output and both gradients against the loop reference —
  would have caught the swapped lhs/grad_out bug in the autograd backward

tests/tools/test_benchmark_smoke.py:
- test_run_benchmark_wiring: exercises Case/Variant/run_benchmark end-to-end
  with a trivial relu kernel; always runs without a GPU
- test_bench_pointwise_smoke: monkeypatches _SIZES_NUMEL=[1024] and calls
  the real bench_pointwise pipeline, asserting fp32_reference and
  pytorch_eager variants succeed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The runner is a dev tool; its correctness is covered by the sparse matmul
autograd tests that exercise the same kernel path in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- bench_mlp_activation.py: FLOPs/element_size → FLOPs/element in comment
  (rename script incorrectly replaced elem in a comment where it meant
  element, not element_size)
- runner.py: Variant.reset_inputs return type None → Any (copy_() returns
  the destination tensor; runner discards it, so functionally fine)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Warmup calls and the post-warmup one_rep_ms calibration were not calling
reset(), so variants with reset_inputs (in-place rotary, PyTorch fwd_bwd
entropy/grpo) ran their warmup reps in a dirty input state. The timed reps
were already correct. This biased the num_reps estimate slightly low.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
All eight benchmark modules now accept a `shapes` parameter on their
`run()` entry point (and the internal `_*_cases()` helpers), so callers
can supply a custom list of input sizes without monkey-patching module
globals.

`tests/tools/test_triton_benchmark.py` uses this to run every module
with one tiny shape and float32 dtype, keeping the full runner code
path exercised without the long compile/autotune time of production
shapes.  The `_disable_dynamo` fixture suppresses torch.compile
cold-start (~20 s per variant on CPU).  The sparse_linear shape uses
hidden=ffn_per_expert=256 to satisfy the Triton kernel's block-size
divisibility assertions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
jlamypoirier force-pushed the worktree-triton_benchmark branch from 1e96654 to 8441615 on April 30, 2026 at 03:30
jlamypoirier and others added 2 commits April 29, 2026 23:55
Add benchmarks() to all 8 bench modules (returning a list of
(name, cases, variants) tuples) so each sub-benchmark is addressable
independently. Thread min_reps through runner.py's bench_fn and
run_benchmark so callers can cap rep count.

Rewrite tests/tools/test_triton_benchmark.py to build one pytest
parameter per kernel from benchmarks() calls (16 tests instead of 8
file-level tests) and run each with warmup_ms=0, rep_ms=0, min_reps=1,
reducing the test suite from ~80s to ~45s while covering every kernel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
pytorch_compiled and pytorch_compiled_max aren't needed for correctness
checking and slow down the suite without covering new code paths.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… patch

tl.histogram is broken in the Triton interpreter; skip sparse_copy and
sparse_linear in interpreter mode rather than patching around multiple
cascading bugs. The np.histogram monkeypatch is reverted as it is no
longer needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>