From c76ee4d8418f13e9578ffc7f9ccbb4039237d08c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 24 Apr 2026 22:35:00 -0400 Subject: [PATCH 01/41] Add Triton kernel benchmark suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `tools/benchmark/` — a micro-benchmark harness for Fast-LLM's Triton kernels that measures throughput (GB/s, % peak BW, TFLOP/s) and checks numerical correctness against a fp32 reference for each kernel variant. Kernels covered: - entropy loss: cross_entropy (labels), cross_entropy (logits/distillation), reverse_kl (logits), and z_loss - normalization: LayerNorm and RMSNorm (fwd+bwd) - MLP activation: gated SiLU fused kernel (fwd+bwd) - rotary embeddings: in-place Triton kernel vs PyTorch eager/compiled - pointwise: cast-add-cast fused kernel Each benchmark compares fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, apex variants (where available), and fast_llm_triton. The runner auto-detects GPU peak bandwidth from device properties, reports % of peak BW per variant, and flags numerical deviations from the reference. Also adds `DataType.short` property (bf16/fp32/…) used by benchmark case names, and `tools/__init__.py` to make `tools` a package for `python -m tools.benchmark`. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/config_utils/data_type.py | 6 + tools/__init__.py | 0 tools/benchmark/__init__.py | 0 tools/benchmark/__main__.py | 79 +++ tools/benchmark/bench_entropy_loss.py | 461 +++++++++++++++ tools/benchmark/bench_mlp_activation.py | 180 ++++++ tools/benchmark/bench_normalization.py | 346 ++++++++++++ tools/benchmark/bench_pointwise.py | 141 +++++ tools/benchmark/bench_rotary.py | 117 ++++ tools/benchmark/gpu_specs.py | 99 ++++ tools/benchmark/runner.py | 650 ++++++++++++++++++++++ tools/benchmark/utils.py | 85 +++ 12 files changed, 2164 insertions(+) create mode 100644 tools/__init__.py create mode 100644 tools/benchmark/__init__.py create mode 100644 tools/benchmark/__main__.py create mode 100644 tools/benchmark/bench_entropy_loss.py create mode 100644 tools/benchmark/bench_mlp_activation.py create mode 100644 tools/benchmark/bench_normalization.py create mode 100644 tools/benchmark/bench_pointwise.py create mode 100644 tools/benchmark/bench_rotary.py create mode 100644 tools/benchmark/gpu_specs.py create mode 100644 tools/benchmark/runner.py create mode 100644 tools/benchmark/utils.py diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 27709a8bb..c8543b782 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -83,6 +83,11 @@ def triton(self) -> "tl.dtype": _set_triton_dtype_map() return _TRITON_DTYPE_MAP[self] + @property + def short(self) -> str: + """Abbreviated name (bf16, fp32, ...) 
when one is defined, else the full name.""" + return _DTYPE_SHORT_NAME_MAP.get(self, self.value) + _KNOWN_DATA_TYPE_PREFIXES = {"DataType", "numpy", "np", "torch", "triton.language", "tl"} @@ -92,6 +97,7 @@ def triton(self) -> "tl.dtype": "fp16": DataType.float16, "bf16": DataType.bfloat16, } +_DTYPE_SHORT_NAME_MAP = {v: k for k, v in _DTYPE_ALT_NAME_MAP_INV.items()} _TORCH_DTYPE_MAP: dict[DataType, "torch.dtype"] = {} _TORCH_DTYPE_MAP_INV: dict["torch.dtype", DataType] = {} diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/benchmark/__init__.py b/tools/benchmark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/benchmark/__main__.py b/tools/benchmark/__main__.py new file mode 100644 index 000000000..102d72e63 --- /dev/null +++ b/tools/benchmark/__main__.py @@ -0,0 +1,79 @@ +""" +CLI entry point for the Fast-LLM Triton kernel benchmarking suite. + +Usage: + python -m tools.benchmark + +Available kernels are discovered dynamically from `bench_*.py` files in this +package. Each such module must expose a `run(verbose: bool = False)` callable. +""" + +import argparse +import importlib +import logging +import pkgutil +import warnings + +# Each bench file compiles the same function with multiple shapes/dtypes; the +# default cache size (8) is too small, causing dynamo to give up and fall back +# to eager. Bump it before any `torch.compile`-decorated code runs. +import torch._dynamo + +import tools.benchmark as _pkg +from fast_llm.engine.config_utils.data_type import DataType + +torch._dynamo.config.cache_size_limit = 64 + +# In-place ops (copy_, fill_, add with out=) emit "skipping cudagraphs due to +# mutated inputs" when using max-autotune. The fallback is correct; suppress noise. +warnings.filterwarnings("ignore", message=".*[Ss]kipping (cuda|CUDA)[Gg]raphs.*") +logging.getLogger("torch._inductor.cudagraph_trees").setLevel(logging.ERROR) + + +def _list_benchmarks() -> dict[str, str]: + """Map short kernel name → fully-qualified bench module name.""" + names = {} + for info in pkgutil.iter_modules(_pkg.__path__): + if info.name.startswith("bench_"): + names[info.name.removeprefix("bench_")] = f"tools.benchmark.{info.name}" + return names + + +def main() -> None: + benches = _list_benchmarks() + parser = argparse.ArgumentParser( + prog="python -m tools.benchmark", + description="Benchmark Fast-LLM Triton kernels against PyTorch alternatives.", + ) + parser.add_argument( + "kernels", + nargs="*", + choices=sorted(benches), + metavar="kernel", + help="Which kernels to benchmark. If omitted, run all kernels.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Show additional timing columns (mean, min, max).", + ) + parser.add_argument( + "-d", + "--dtypes", + nargs="+", + default=["bfloat16"], + metavar="DTYPE", + help="Fast-LLM DataType names to sweep (default: bfloat16). " + "Examples: bfloat16 float32 fp16. Accepts alternate names like bf16.", + ) + args = parser.parse_args() + dtypes = [DataType(d).torch for d in args.dtypes] + + selected = args.kernels or sorted(benches) + for kernel in selected: + importlib.import_module(benches[kernel]).run(verbose=args.verbose, dtypes=dtypes) + + +if __name__ == "__main__": + main() diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py new file mode 100644 index 000000000..13fec495e --- /dev/null +++ b/tools/benchmark/bench_entropy_loss.py @@ -0,0 +1,461 @@ +""" +Benchmark entropy loss kernels. 
+ +All Triton kernels fuse fwd+bwd into a single logits-tensor pass; `grad_output=1.0` +triggers gradient computation alongside the loss. + +Three main training cases benchmarked: + + cross_entropy + labels — standard LM training (integer targets) + cross_entropy + logits — distillation CE with soft targets, p=softmax(target_logits) + reverse_kl + logits — reverse KL divergence KL(q||p), p=softmax(target_logits) + +z_loss is also included (shared input structure with the labels case). + +Shapes fix tokens=4096, sweep vocab size from Llama-2 (32K) to Llama-3 (128K). +""" + +import torch +import torch.nn.functional as F + +from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# (tokens, vocab_size) +_SHAPES = [ + (4096, 32768), # 7B / Llama-2 vocab + (4096, 65536), # 64K vocab + (4096, 131072), # Llama-3 vocab +] +_DEFAULT_DTYPES = (torch.bfloat16,) + + +# --------------------------------------------------------------------------- inputs + + +def _make_label_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: + return { + "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), + "labels": torch.randint(0, vocab, (tokens,), dtype=torch.long, device=device()), + } + + +def _make_distribution_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: + return { + "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), + # target_logits: teacher logits; no gradient needed w.r.t. these. 
+ "target_logits": torch.randn(tokens, vocab, dtype=dtype, device=device()), + } + + +# --------------------------------------------------------------------------- cross_entropy (labels) + + +def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + return F.cross_entropy(logits, labels) + + +_ce_labels_compiled_default = torch.compile(_ce_labels_eager, mode="default", dynamic=False) +_ce_labels_compiled_max = torch.compile(_ce_labels_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_ce_labels_fwd(inp: dict, fn) -> dict: + return {"loss": fn(inp["logits"], inp["labels"])} + + +def _run_ce_labels_fwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + return {"loss": _ce_labels_eager(logits_fp32, inp["labels"])} + + +def _run_ce_labels_fwd_bwd(inp: dict, fn) -> dict: + loss = fn(inp["logits"], inp["labels"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + + +def _run_ce_labels_fwd_bwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + loss = _ce_labels_eager(logits_fp32, inp["labels"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} + + +def _run_ce_labels_fwd_triton(inp: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward(inp["logits"], inp["labels"], loss_mask=None, grad_output=None) + return {"loss": loss} + + +def _run_ce_labels_fwd_bwd_triton(inp: dict) -> dict: + loss, grad_logits = triton_entropy_loss_forward_backward( + inp["logits"], inp["labels"], loss_mask=None, grad_output=1.0 + ) + return {"loss": loss, "grad_logits": grad_logits} + + +def _ce_labels_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=_run_ce_labels_fwd_fp32, + fwd_bwd=_run_ce_labels_fwd_bwd_fp32, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_eager), + fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_default), + fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_max), + fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_ce_labels_fwd_triton, + fwd_bwd=_run_ce_labels_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- cross_entropy (logits / distribution) + + +def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: + """CE(p, q) where p = softmax(target_logits), q = softmax(logits).""" + return F.cross_entropy(logits, target_logits.softmax(dim=-1)) + + +_ce_dist_compiled_default = torch.compile(_ce_dist_eager, mode="default", dynamic=False) +_ce_dist_compiled_max = torch.compile(_ce_dist_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_dist_fwd(inp: dict, fn) -> dict: + return {"loss": fn(inp["logits"], inp["target_logits"])} + + +def _run_ce_dist_fwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + return {"loss": _ce_dist_eager(logits_fp32, inp["target_logits"].float())} + + +def _run_dist_fwd_bwd(inp: dict, fn) -> dict: + loss = 
fn(inp["logits"], inp["target_logits"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + + +def _run_ce_dist_fwd_bwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + loss = _ce_dist_eager(logits_fp32, inp["target_logits"].float()) + loss.backward() + return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} + + +def _run_ce_dist_fwd_triton(inp: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward( + inp["logits"], + inp["target_logits"], + loss_mask=None, + grad_output=None, + target_format=TargetFormat.logits, + entropy_loss_type=EntropyLossType.cross_entropy, + ) + return {"loss": loss} + + +def _run_ce_dist_fwd_bwd_triton(inp: dict) -> dict: + loss, grad_logits = triton_entropy_loss_forward_backward( + inp["logits"], + inp["target_logits"], + loss_mask=None, + grad_output=1.0, + target_format=TargetFormat.logits, + entropy_loss_type=EntropyLossType.cross_entropy, + ) + return {"loss": loss, "grad_logits": grad_logits} + + +def _ce_dist_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=_run_ce_dist_fwd_fp32, + fwd_bwd=_run_ce_dist_fwd_bwd_fp32, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_eager), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_default), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_max), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_ce_dist_fwd_triton, + fwd_bwd=_run_ce_dist_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- reverse_kl (logits / distribution) + + +def _reverse_kl_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: + """KL(q||p) where q = softmax(logits), p = softmax(target_logits).""" + return F.kl_div( + target_logits.log_softmax(dim=-1), + logits.softmax(dim=-1), + reduction="batchmean", + ) + + +_reverse_kl_compiled_default = torch.compile(_reverse_kl_eager, mode="default", dynamic=False) +_reverse_kl_compiled_max = torch.compile(_reverse_kl_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_rkl_fwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + return {"loss": _reverse_kl_eager(logits_fp32, inp["target_logits"].float())} + + +def _run_rkl_fwd_bwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + loss = _reverse_kl_eager(logits_fp32, inp["target_logits"].float()) + loss.backward() + return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} + + +def _run_rkl_fwd_triton(inp: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward( + inp["logits"], + inp["target_logits"], + loss_mask=None, + grad_output=None, + target_format=TargetFormat.logits, + entropy_loss_type=EntropyLossType.reverse_kl, + ) + return {"loss": loss} + + +def _run_rkl_fwd_bwd_triton(inp: dict) -> dict: + loss, grad_logits = triton_entropy_loss_forward_backward( + inp["logits"], + inp["target_logits"], + loss_mask=None, + grad_output=1.0, + target_format=TargetFormat.logits, + 
entropy_loss_type=EntropyLossType.reverse_kl, + ) + return {"loss": loss, "grad_logits": grad_logits} + + +def _reverse_kl_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=_run_rkl_fwd_fp32, + fwd_bwd=_run_rkl_fwd_bwd_fp32, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_eager), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_default), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_max), + fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_rkl_fwd_triton, + fwd_bwd=_run_rkl_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- z_loss + + +def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: + log_z = torch.logsumexp(logits.float(), dim=-1) + return (log_z * log_z).mean() + + +_z_loss_compiled_default = torch.compile(_z_loss_eager, mode="default", dynamic=False) +_z_loss_compiled_max = torch.compile(_z_loss_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_zl_fwd(inp: dict, fn) -> dict: + return {"loss": fn(inp["logits"])} + + +def _run_zl_fwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + return {"loss": _z_loss_eager(logits_fp32)} + + +def _run_zl_fwd_bwd(inp: dict, fn) -> dict: + loss = fn(inp["logits"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + + +def _run_zl_fwd_bwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_(True) + loss = _z_loss_eager(logits_fp32) + loss.backward() + return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} + + +def _run_zl_fwd_triton(inp: dict) -> dict: + loss, _ = triton_z_loss_forward_backward(inp["logits"], loss_mask=None, grad_output=None) + return {"loss": loss} + + +def _run_zl_fwd_bwd_triton(inp: dict) -> dict: + loss, grad_logits = triton_z_loss_forward_backward(inp["logits"], loss_mask=None, grad_output=1.0) + return {"loss": loss, "grad_logits": grad_logits} + + +def _z_loss_variants() -> list[Variant]: + variants = [ + Variant(name="fp32_reference", fwd=_run_zl_fwd_fp32, fwd_bwd=_run_zl_fwd_bwd_fp32, is_reference=True), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_zl_fwd(inp, _z_loss_eager), + fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_default), + fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_max), + fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append(Variant(name="fast_llm_triton", fwd=_run_zl_fwd_triton, fwd_bwd=_run_zl_fwd_bwd_triton)) + return variants + + +# --------------------------------------------------------------------------- cases + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + +def _label_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: + """fwd+bwd: read 
logits, read labels (int32), write grad_logits.""" + elem = _bytes_per_elem(dtype) + return 2 * tokens * vocab * elem + tokens * 4 + + +def _dist_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: + """fwd+bwd: read logits, read target_logits, write grad_logits.""" + elem = _bytes_per_elem(dtype) + return 3 * tokens * vocab * elem + + +def _entropy_loss_flops(tokens: int, vocab: int) -> int: + # fwd ≈ 3*vocab per token (max, sum_exp, CE); bwd ≈ vocab. Total ≈ 4*vocab. + return 4 * tokens * vocab + + +def _label_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name(kernel_name, (tokens, vocab), dtype), + make_inputs=(lambda t=tokens, v=vocab, d=dtype: _make_label_inputs(t, v, d)), + expected_bytes=_label_loss_bytes(tokens, vocab, dtype), + expected_flops=_entropy_loss_flops(tokens, vocab), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, vocab in _SHAPES + ] + + +def _dist_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name(kernel_name, (tokens, vocab), dtype), + make_inputs=(lambda t=tokens, v=vocab, d=dtype: _make_distribution_inputs(t, v, d)), + expected_bytes=_dist_loss_bytes(tokens, vocab, dtype), + expected_flops=_entropy_loss_flops(tokens, vocab), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, vocab in _SHAPES + ] + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark( + "entropy_loss: cross_entropy (labels)", + _label_cases("cross_entropy_labels", dtypes), + _ce_labels_variants(), + verbose=verbose, + ) + run_benchmark( + "entropy_loss: cross_entropy (logits)", + _dist_cases("cross_entropy_logits", dtypes), + _ce_dist_variants(), + verbose=verbose, + ) + run_benchmark( + "entropy_loss: reverse_kl (logits)", + _dist_cases("reverse_kl_logits", dtypes), + _reverse_kl_variants(), + verbose=verbose, + ) + run_benchmark("entropy_loss: z_loss", _label_cases("z_loss", dtypes), _z_loss_variants(), verbose=verbose) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py new file mode 100644 index 000000000..1e2df9831 --- /dev/null +++ b/tools/benchmark/bench_mlp_activation.py @@ -0,0 +1,180 @@ +""" +Benchmark the fused MLP activation kernel. + +The Triton kernel (`triton_mlp_activation_autograd`) fuses the element-wise +activation and (for gated models) the gated multiply into a single pass. For +gated SiLU the fwd input is (tokens, 2*ffn_dim) — [gate_proj, up_proj] +concatenated — and the output is (tokens, ffn_dim). + +Comparisons: +- fp32_reference: torch_mlp_activation in fp32 with autograd +- pytorch_eager: torch_mlp_activation in compute dtype +- pytorch_compiled / pytorch_compiled_max: torch.compile of the above +- fast_llm_triton: triton_mlp_activation_autograd + +Shapes fix tokens=8192 and sweep ffn_dim across typical MLP widths. +""" + +import torch + +from fast_llm.functional.config import ActivationType, TritonConfig +from fast_llm.functional.triton.mlp import ( + torch_mlp_activation, + triton_mlp_activation_autograd, + triton_mlp_activation_forward, +) +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# (tokens, ffn_dim) — input tensor has shape (tokens, 2*ffn_dim) for gated. 
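+# E.g. (8192, 4096) in bf16 is a 128 MiB input (8192 x 8192) and a 64 MiB output per pass.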
+_SHAPES = [ + (8192, 4096), # 7B/13B models + (8192, 8192), # large + (8192, 14336), # 70B models + (4096, 28672), # MoE up-projection +] +_ACTIVATION = ActivationType.silu # standard for Llama-style gated models +_DEFAULT_DTYPES = (torch.bfloat16,) + + +# --------------------------------------------------------------------------- inputs + + +def _make_mlp_inputs(tokens: int, ffn_dim: int, dtype: torch.dtype) -> dict: + return { + "input_": torch.randn(tokens, 2 * ffn_dim, dtype=dtype, device=device(), requires_grad=True), + "grad_output": torch.randn(tokens, ffn_dim, dtype=dtype, device=device()), + "gated": True, + "activation_type": _ACTIVATION, + } + + +# --------------------------------------------------------------------------- forward wrappers + + +def _pytorch_fwd(input_: torch.Tensor, gated: bool, activation_type: ActivationType) -> torch.Tensor: + return torch_mlp_activation(input_, gated, activation_type) + + +_pytorch_compiled_default = torch.compile(_pytorch_fwd, mode="default", dynamic=False) +_pytorch_compiled_max = torch.compile(_pytorch_fwd, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["input_"], inp["gated"], inp["activation_type"])} + + +def _run_fwd_fp32(inp: dict) -> dict: + return {"output": _pytorch_fwd(inp["input_"].float(), inp["gated"], inp["activation_type"])} + + +def _run_fwd_triton(inp: dict) -> dict: + output, _ = triton_mlp_activation_forward(inp["input_"], inp["gated"], inp["activation_type"]) + return {"output": output} + + +# --------------------------------------------------------------------------- fwd+bwd wrappers + + +def _run_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["input_"], inp["gated"], inp["activation_type"]) + output.backward(inp["grad_output"]) + return {"output": output.detach(), "grad_input": inp["input_"].grad} + + +def _run_fwd_bwd_fp32(inp: dict) -> dict: + input_fp32 = inp["input_"].float().detach().requires_grad_(True) + output = _pytorch_fwd(input_fp32, inp["gated"], inp["activation_type"]) + output.backward(inp["grad_output"].float()) + return {"output": output.detach(), "grad_input": input_fp32.grad} + + +def _run_fwd_bwd_triton(inp: dict) -> dict: + output = triton_mlp_activation_autograd(inp["input_"], inp["gated"], inp["activation_type"]) + output.backward(inp["grad_output"]) + return {"output": output.detach(), "grad_input": inp["input_"].grad} + + +# --------------------------------------------------------------------------- variants + + +def _mlp_activation_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=_run_fwd_fp32, + fwd_bwd=_run_fwd_bwd_fp32, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_fwd(inp, _pytorch_fwd), + fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_fwd), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_fwd(inp, _pytorch_compiled_default), + fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_fwd(inp, _pytorch_compiled_max), + fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_fwd_triton, + fwd_bwd=_run_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- cases + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + 
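+# Illustrative sketch (not used by the benchmark): the eager computation the fused
+# kernel performs for gated SiLU, assuming the [gate_proj, up_proj] concatenation
+# order described in the module docstring; the authoritative version is
+# `torch_mlp_activation` imported above.
+def _gated_silu_sketch(input_: torch.Tensor) -> torch.Tensor:
+    gate, up = input_.chunk(2, dim=-1)  # (tokens, 2*ffn_dim) -> 2 x (tokens, ffn_dim)
+    return torch.nn.functional.silu(gate) * up  # elementwise silu(gate) * up -> (tokens, ffn_dim)
+
+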
+def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: + """fwd: read input (2*ffn_dim) + write output (ffn_dim). + bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). + Total: 8 × tokens × ffn_dim × elem_size.""" + return 8 * tokens * ffn_dim * _bytes_per_elem(dtype) + + +def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: + # gated silu: fwd ≈ 6 FLOPs/elem, bwd ≈ 8 FLOPs/elem, total ≈ 14 per output element. + return 14 * tokens * ffn_dim + + +def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("mlp_activation", (tokens, ffn_dim), dtype), + make_inputs=(lambda t=tokens, f=ffn_dim, d=dtype: _make_mlp_inputs(t, f, d)), + expected_bytes=_mlp_activation_bytes(tokens, ffn_dim, dtype), + expected_flops=_mlp_activation_flops(tokens, ffn_dim), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, ffn_dim in _SHAPES + ] + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark( + "mlp_activation (gated silu)", _mlp_activation_cases(dtypes), _mlp_activation_variants(), verbose=verbose + ) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py new file mode 100644 index 000000000..066c53676 --- /dev/null +++ b/tools/benchmark/bench_normalization.py @@ -0,0 +1,346 @@ +""" +Benchmark normalization kernels: LayerNorm and RMSNorm. + +Both are fwd+bwd kernels. The Triton implementation in +`fast_llm/functional/triton/normalization.py` handles both flavors via the +`bias` argument (LayerNorm when given, RMSNorm when None) and writes parameter +gradients to Fast-LLM's `grad_buffer` attribute rather than autograd's `.grad`. + +Comparisons: +- fp32_reference: torch.{layer,rms}_norm in fp32 (eager) +- pytorch_eager: torch.{layer,rms}_norm in the case dtype +- pytorch_compiled / pytorch_compiled_max: torch.compile of the above +- apex_fused: Apex fused_layer_norm_cuda (all widths, layer+rms norm) +- apex_fast: Apex fast_layer_norm contrib (layer norm only, restricted widths) +- fast_llm_triton: triton_normalization_autograd +""" + +import torch + +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.normalization import triton_normalization_autograd +from fast_llm.layers.common.normalization.normalization import ( + FastLayerNorm, + FusedLayerNorm, + FusedRMSNorm, + _fast_normalization_available, + _fused_normalization_available, +) +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# Activation shape (batch*seq, hidden). Numel fixed at 32M to mimic a constant +# training memory budget across model widths; hidden swept from 1K to 16K covers +# small models through Llama-405B / wide-MoE territory. +_SHAPES = [ + (32768, 1024), + (16384, 2048), + (8192, 4096), + (4096, 8192), + (2048, 16384), +] +_DEFAULT_DTYPES = (torch.bfloat16,) +_EPS = 1e-5 + + +# --------------------------------------------------------------------------- input setup + + +def _setup_param(tensor: torch.Tensor) -> torch.Tensor: + """Triton's normalization backward writes weight/bias gradients to a + `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`. 
+ Wire up the buffer + zero-flag the kernel expects.""" + tensor.grad_buffer = torch.zeros_like(tensor) + tensor.param_grad_is_zero = True + return tensor + + +def _to_fp32_input(tensor: torch.Tensor) -> torch.Tensor: + return tensor.float().detach().requires_grad_() + + +def _to_fp32_param(tensor: torch.Tensor) -> torch.Tensor: + return _setup_param(tensor.float().detach().requires_grad_()) + + +def _make_layer_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: + return { + "input_": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), + "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), + "bias": _setup_param(torch.zeros(cols, dtype=dtype, device=device(), requires_grad=True)), + "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), + } + + +def _make_rms_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: + return { + "input_": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), + "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), + "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), + } + + +def _layer_norm_inputs_fp32(inp: dict) -> dict: + return { + "input_": _to_fp32_input(inp["input_"]), + "weight": _to_fp32_param(inp["weight"]), + "bias": _to_fp32_param(inp["bias"]), + "grad_output": inp["grad_output"].float(), + } + + +def _rms_norm_inputs_fp32(inp: dict) -> dict: + return { + "input_": _to_fp32_input(inp["input_"]), + "weight": _to_fp32_param(inp["weight"]), + "grad_output": inp["grad_output"].float(), + } + + +# --------------------------------------------------------------------------- forward functions + + +def _layer_norm_eager(input_, weight, bias): + return torch.layer_norm(input_, weight.shape, weight, bias, _EPS) + + +def _rms_norm_eager(input_, weight): + return torch.rms_norm(input_, weight.shape, weight, _EPS) + + +def _layer_norm_triton(input_, weight, bias): + return triton_normalization_autograd(input_, weight, bias, _EPS, True, False) + + +def _rms_norm_triton(input_, weight): + return triton_normalization_autograd(input_, weight, None, _EPS, True, False) + + +def _layer_norm_apex_fused(input_, weight, bias): + return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) + + +def _layer_norm_apex_fast(input_, weight, bias): + return FastLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) + + +def _rms_norm_apex_fused(input_, weight): + return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS) + + +_layer_compiled_default = torch.compile(_layer_norm_eager, mode="default", dynamic=False) +_layer_compiled_max = torch.compile(_layer_norm_eager, mode="max-autotune-no-cudagraphs", dynamic=False) +_rms_compiled_default = torch.compile(_rms_norm_eager, mode="default", dynamic=False) +_rms_compiled_max = torch.compile(_rms_norm_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +# --------------------------------------------------------------------------- variant wrappers + + +def _param_grad(param: torch.Tensor) -> torch.Tensor: + """Pull the parameter gradient from wherever the kernel wrote it. 
+ Triton writes to `grad_buffer`; autograd writes to `.grad`.""" + return param.grad if param.grad is not None else param.grad_buffer + + +def _run_layer_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["input_"], inp["weight"], inp["bias"])} + + +def _run_layer_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["input_"], inp["weight"], inp["bias"]) + output.backward(inp["grad_output"]) + return { + "grad_input": inp["input_"].grad, + "grad_weight": _param_grad(inp["weight"]), + "grad_bias": _param_grad(inp["bias"]), + } + + +def _run_rms_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["input_"], inp["weight"])} + + +def _run_rms_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["input_"], inp["weight"]) + output.backward(inp["grad_output"]) + return { + "grad_input": inp["input_"].grad, + "grad_weight": _param_grad(inp["weight"]), + } + + +# --------------------------------------------------------------------------- variants + + +def _layer_norm_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=lambda inp: _run_layer_fwd(_layer_norm_inputs_fp32(inp), _layer_norm_eager), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(_layer_norm_inputs_fp32(inp), _layer_norm_eager), + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_eager), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_layer_fwd(inp, _layer_compiled_default), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_layer_fwd(inp, _layer_compiled_max), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_compiled_max), + ), + ] + if _fused_normalization_available: + variants.append( + Variant( + name="apex_fused", + fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_apex_fused), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_apex_fused), + ) + ) + if _fast_normalization_available: + # apex_fast only supports widths in _PERSIST_LN_SIZES; all shapes in _SHAPES qualify. 
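+        # A width outside that set would likely make this variant fail at call time;
+        # the runner catches per-variant exceptions and reports them in the error
+        # column rather than aborting the sweep.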
+ variants.append( + Variant( + name="apex_fast", + fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_apex_fast), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_apex_fast), + ) + ) + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_triton), + fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_triton), + ) + ) + return variants + + +def _rms_norm_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=lambda inp: _run_rms_fwd(_rms_norm_inputs_fp32(inp), _rms_norm_eager), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(_rms_norm_inputs_fp32(inp), _rms_norm_eager), + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_eager), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_rms_fwd(inp, _rms_compiled_default), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_rms_fwd(inp, _rms_compiled_max), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_compiled_max), + ), + ] + if _fused_normalization_available: + variants.append( + Variant( + name="apex_fused", + fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_apex_fused), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_apex_fused), + ) + ) + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_triton), + fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_triton), + ) + ) + return variants + + +# --------------------------------------------------------------------------- cases + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + +def _layer_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: + """Approximate fwd+bwd memory traffic for LayerNorm. + fwd reads input + weight + bias and writes output (also stores inv_var). + bwd reads grad_output, output, weight, bias, inv_var; writes grad_input, + grad_weight, grad_bias. Activation tensors dominate.""" + elem = _bytes_per_elem(dtype) + activations = 4 * rows * cols * elem # fwd in/out + bwd grad_in/out + parameters = 6 * cols * elem # weight, bias × (read + grad write) twice + return activations + parameters + + +def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: + elem = _bytes_per_elem(dtype) + activations = 4 * rows * cols * elem + parameters = 3 * cols * elem + return activations + parameters + + +def _layer_norm_flops(rows: int, cols: int) -> int: + """Approximate fwd+bwd FLOPs for LayerNorm. + fwd: mean (1), variance (2), normalize (2), scale+shift (2) ≈ 7 per element. 
+ bwd: ~2x fwd.""" + return 21 * rows * cols + + +def _rms_norm_flops(rows: int, cols: int) -> int: + """Same idea as LayerNorm but no mean subtraction or bias.""" + return 15 * rows * cols + + +def _layer_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("layer_norm", shape, dtype), + make_inputs=(lambda s=shape, d=dtype: _make_layer_norm_inputs(s[0], s[1], d)), + expected_bytes=_layer_norm_bytes(shape[0], shape[1], dtype), + expected_flops=_layer_norm_flops(shape[0], shape[1]), + compute_dtype=dtype, + ) + for dtype in dtypes + for shape in _SHAPES + ] + + +def _rms_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("rms_norm", shape, dtype), + make_inputs=(lambda s=shape, d=dtype: _make_rms_norm_inputs(s[0], s[1], d)), + expected_bytes=_rms_norm_bytes(shape[0], shape[1], dtype), + expected_flops=_rms_norm_flops(shape[0], shape[1]), + compute_dtype=dtype, + ) + for dtype in dtypes + for shape in _SHAPES + ] + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark("normalization: layer_norm", _layer_norm_cases(dtypes), _layer_norm_variants(), verbose=verbose) + run_benchmark("normalization: rms_norm", _rms_norm_cases(dtypes), _rms_norm_variants(), verbose=verbose) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py new file mode 100644 index 000000000..b2ea1f00e --- /dev/null +++ b/tools/benchmark/bench_pointwise.py @@ -0,0 +1,141 @@ +""" +Benchmark pointwise kernels: copy, fill, add. + +These kernels are pure bandwidth-bound: runtime is dominated by reading inputs +and writing outputs, so GB/s and %-of-peak-BW are the headline metrics. The +Triton kernels live in `fast_llm/functional/triton/pointwise.py` and are +documented as being ~2x faster than the PyTorch equivalent on A100. +""" + +import torch + +from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill +from tools.benchmark.runner import Case, run_benchmark +from tools.benchmark.utils import case_name, device, standard_fwd_variants + +# Sizes span from L2-resident to comfortably HBM-bound, in 4× steps so the +# regime transitions (L2 → HBM, mid-HBM → saturated-HBM) are visible. +_SIZES_NUMEL = [ + 1 << 20, # 1M — 2 MiB bf16 (L2-resident on most GPUs) + 1 << 22, # 4M — 8 MiB bf16 (L2 boundary) + 1 << 24, # 16M — 32 MiB bf16 (HBM) + 1 << 26, # 64M — 128 MiB bf16 (HBM) + 1 << 28, # 256M — 512 MiB bf16 (large HBM, near-saturated) +] +_DEFAULT_DTYPES = (torch.bfloat16,) + + +# --------------------------------------------------------------------------- copy + + +def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + return out.copy_(input_) + + +def _make_copy_inputs(numel: int, dtype: torch.dtype) -> dict: + input_ = torch.randn(numel, dtype=dtype, device=device()) + out = torch.empty_like(input_) + return {"input_": input_, "out": out} + + +def _copy_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("copy", (numel,), dtype), + make_inputs=(lambda n=numel, d=dtype: _make_copy_inputs(n, d)), + # Read input + write output. 
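+            # E.g. at 256M bf16 elements that is 1 GiB of traffic, ~0.3 ms at the
+            # H100's 3.35 TB/s nominal peak.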
+ expected_bytes=2 * numel * torch.tensor([], dtype=dtype).element_size(), + ) + for dtype in dtypes + for numel in _SIZES_NUMEL + ] + + +_COPY_VARIANTS = standard_fwd_variants( + eager_fn=_copy_eager, + triton_fn=triton_copy, + unpack=lambda inp: (inp["input_"], inp["out"]), +) + + +# --------------------------------------------------------------------------- fill + + +def _fill_eager(input_: torch.Tensor, value: float) -> torch.Tensor: + return input_.fill_(value) + + +def _make_fill_inputs(numel: int, dtype: torch.dtype) -> dict: + return {"input_": torch.empty(numel, dtype=dtype, device=device()), "value": 1.5} + + +def _fill_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("fill", (numel,), dtype), + make_inputs=(lambda n=numel, d=dtype: _make_fill_inputs(n, d)), + # Write only. + expected_bytes=numel * torch.tensor([], dtype=dtype).element_size(), + ) + for dtype in dtypes + for numel in _SIZES_NUMEL + ] + + +_FILL_VARIANTS = standard_fwd_variants( + eager_fn=_fill_eager, + triton_fn=triton_fill, + unpack=lambda inp: (inp["input_"], inp["value"]), +) + + +# --------------------------------------------------------------------------- add + + +def _add_eager(input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + return torch.add(input_, other, out=out) + + +def _make_add_inputs(numel: int, dtype: torch.dtype) -> dict: + return { + "input_": torch.randn(numel, dtype=dtype, device=device()), + "other": torch.randn(numel, dtype=dtype, device=device()), + "out": torch.empty(numel, dtype=dtype, device=device()), + } + + +def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("add", (numel,), dtype), + make_inputs=(lambda n=numel, d=dtype: _make_add_inputs(n, d)), + # Read 2 inputs + write 1 output. + expected_bytes=3 * numel * torch.tensor([], dtype=dtype).element_size(), + # One fp add per element. + expected_flops=numel, + compute_dtype=dtype, + ) + for dtype in dtypes + for numel in _SIZES_NUMEL + ] + + +_ADD_VARIANTS = standard_fwd_variants( + eager_fn=_add_eager, + triton_fn=triton_add, + unpack=lambda inp: (inp["input_"], inp["other"], inp["out"]), +) + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark("pointwise: copy", _copy_cases(dtypes), _COPY_VARIANTS, verbose=verbose) + run_benchmark("pointwise: fill", _fill_cases(dtypes), _FILL_VARIANTS, verbose=verbose) + run_benchmark("pointwise: add", _add_cases(dtypes), _ADD_VARIANTS, verbose=verbose) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py new file mode 100644 index 000000000..4ca56662f --- /dev/null +++ b/tools/benchmark/bench_rotary.py @@ -0,0 +1,117 @@ +""" +Benchmark rotary position embeddings. + +The Triton kernel (`triton_rotary_`) operates in-place on (tokens, num_heads, +head_size) tensors, loading pre-computed (cos, sin) frequencies from +(tokens, 2*rotary_dim). The backward is an identical rotation call with +conjugated frequencies — same cost — so only fwd is benchmarked. 
+ +Shapes sweep (tokens, num_heads, head_size) across typical attention configs: +- 32 heads × 128 → 7B/13B models +- 64 heads × 128 → 70B / MoE models +- 8 heads × 128 → GQA key-value heads (Llama 3) +""" + +import torch + +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.rotary import triton_rotary_ +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# (tokens, num_heads, head_size) — tokens = batch * seq_len +_SHAPES = [ + (4096, 32, 128), # 7B/13B, 4K context + (8192, 32, 128), # 7B/13B, 8K context + (4096, 64, 128), # 70B / MoE, 4K context + (4096, 8, 128), # GQA key-value heads, 4K context +] +_DEFAULT_DTYPES = (torch.bfloat16,) + + +def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: + rotary_dim = head_size // 2 + return { + "input_": torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), + "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), + } + + +def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: + """Non-in-place full rotary (rotary_dim = head_size / 2).""" + rotary_dim = frequencies.shape[-1] // 2 + freq_re = frequencies[:, :rotary_dim].unsqueeze(1) # (tokens, 1, rotary_dim) + freq_im = frequencies[:, rotary_dim:].unsqueeze(1) + x_re, x_im = input_.chunk(2, dim=-1) + out_re = x_re * freq_re - x_im * freq_im + out_im = x_im * freq_re + x_re * freq_im + return torch.cat([out_re, out_im], dim=-1) + + +_rotary_compiled_default = torch.compile(_rotary_eager, mode="default", dynamic=False) +_rotary_compiled_max = torch.compile(_rotary_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: + elem = torch.tensor([], dtype=dtype).element_size() + # Read + write input tensor; frequencies are float32. + return 2 * tokens * num_heads * head_size * elem + tokens * head_size * 4 + + +def _rotary_flops(tokens: int, num_heads: int, head_size: int) -> int: + # 6 FLOPs per (re, im) element pair: 4 muls + 2 add/sub. + return 6 * tokens * num_heads * (head_size // 2) + + +def _rotary_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=lambda inp: {"output": _rotary_eager(inp["input_"].float(), inp["frequencies"])}, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: {"output": _rotary_eager(inp["input_"], inp["frequencies"])}, + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: {"output": _rotary_compiled_default(inp["input_"], inp["frequencies"])}, + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: {"output": _rotary_compiled_max(inp["input_"], inp["frequencies"])}, + ), + ] + if TritonConfig.enabled(): + # triton_rotary_ is in-place; clone so the benchmark input stays intact. 
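+        # The clone runs inside the timed call, so its extra read+write traffic is
+        # charged to this variant; the bare kernel is slightly faster than reported.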
+ variants.append( + Variant( + name="fast_llm_triton", + fwd=lambda inp: {"output": triton_rotary_(inp["input_"].clone(), inp["frequencies"])}, + ) + ) + return variants + + +def _rotary_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("rotary", (tokens, num_heads, head_size), dtype), + make_inputs=(lambda t=tokens, h=num_heads, s=head_size, d=dtype: _make_rotary_inputs(t, h, s, d)), + expected_bytes=_rotary_bytes(tokens, num_heads, head_size, dtype), + expected_flops=_rotary_flops(tokens, num_heads, head_size), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, num_heads, head_size in _SHAPES + ] + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark("rotary", _rotary_cases(dtypes), _rotary_variants(), verbose=verbose) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/gpu_specs.py b/tools/benchmark/gpu_specs.py new file mode 100644 index 000000000..563348dcd --- /dev/null +++ b/tools/benchmark/gpu_specs.py @@ -0,0 +1,99 @@ +""" +Peak memory bandwidth and compute for high-end datacenter GPUs, used to report +%-of-peak. Values are vendor-published nominal dense peaks (no 2:4 sparsity); +real-world achievable is typically ~80-90% of nominal for bandwidth and +~70-85% for compute. +""" + +import dataclasses + +import torch + + +@dataclasses.dataclass(frozen=True) +class GpuSpec: + name: str + peak_bandwidth_gbps: float + peak_tflops_fp32: float + peak_tflops_tf32: float + peak_tflops_fp16: float + peak_tflops_bf16: float + + def peak_tflops(self, dtype: torch.dtype) -> float | None: + if dtype == torch.float32: + return self.peak_tflops_fp32 + if dtype == torch.float16: + return self.peak_tflops_fp16 + if dtype == torch.bfloat16: + return self.peak_tflops_bf16 + return None + + +# Name-substring match → spec. First match wins. Dense (non-sparse) peaks. +_GPU_SPECS: list[tuple[str, GpuSpec]] = [ + ( + "B200", + GpuSpec( + name="NVIDIA B200 SXM", + peak_bandwidth_gbps=8000.0, + peak_tflops_fp32=80.0, + peak_tflops_tf32=1100.0, + peak_tflops_fp16=2250.0, + peak_tflops_bf16=2250.0, + ), + ), + ( + "B100", + GpuSpec( + name="NVIDIA B100 SXM", + peak_bandwidth_gbps=8000.0, + peak_tflops_fp32=60.0, + peak_tflops_tf32=880.0, + peak_tflops_fp16=1800.0, + peak_tflops_bf16=1800.0, + ), + ), + ( + "H200", + GpuSpec( + name="NVIDIA H200 SXM", + peak_bandwidth_gbps=4800.0, + peak_tflops_fp32=67.0, + peak_tflops_tf32=494.0, + peak_tflops_fp16=989.0, + peak_tflops_bf16=989.0, + ), + ), + ( + "H100", + GpuSpec( + name="NVIDIA H100 SXM", + peak_bandwidth_gbps=3350.0, + peak_tflops_fp32=67.0, + peak_tflops_tf32=494.0, + peak_tflops_fp16=989.0, + peak_tflops_bf16=989.0, + ), + ), + ( + "A100", + GpuSpec( + name="NVIDIA A100 80GB SXM", + peak_bandwidth_gbps=2039.0, + peak_tflops_fp32=19.5, + peak_tflops_tf32=156.0, + peak_tflops_fp16=312.0, + peak_tflops_bf16=312.0, + ), + ), +] + + +def detect_gpu_spec() -> GpuSpec | None: + if not torch.cuda.is_available(): + return None + name = torch.cuda.get_device_name() + for needle, spec in _GPU_SPECS: + if needle in name: + return spec + return None diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py new file mode 100644 index 000000000..2836787f6 --- /dev/null +++ b/tools/benchmark/runner.py @@ -0,0 +1,650 @@ +""" +Core benchmarking infrastructure for Fast-LLM Triton kernels. 
+ +Each benchmark file defines a list of `Case` objects (input shape/dtype +sweep) and a list of `Variant` objects (implementations to compare — e.g. +pytorch eager, pytorch compiled, Triton). The runner invokes each variant +on each case, measures timing (median + mean + percentiles via CUDA events), +measures peak/final memory, and compares outputs against an fp32 reference +using RMS error. Results are printed as a table per case. +""" + +import dataclasses +import gc +import math +import statistics +import time +from collections.abc import Callable +from typing import Any + +import torch + +from fast_llm.utils import header +from tools.benchmark.gpu_specs import GpuSpec, detect_gpu_spec + +# Before each timed CUDA-graph-backed call we must mark a new step so the graph +# system knows the previous rep's output buffers are no longer live. Without +# this, max-autotune compiled functions raise "tensor output of CUDAGraphs has +# been overwritten by a subsequent run" on the second call. +_cudagraph_mark_step_begin: Callable[[], None] | None = getattr( + getattr(torch, "compiler", None), "cudagraph_mark_step_begin", None +) + + +def _guarded(fn: Callable[[], Any]) -> Callable[[], Any]: + """Wrap fn so cudagraph_mark_step_begin() is called before each invocation. + This tells the CUDA graph system that previous outputs are no longer live + and can be overwritten, preventing 'overwritten by subsequent run' errors + when a max-autotune compiled function is called more than once. + When CUDA graphs are not in use the wrapper is a no-op pass-through.""" + if _cudagraph_mark_step_begin is None: + return fn + + def _wrapped() -> Any: + _cudagraph_mark_step_begin() + return fn() + + return _wrapped + + +Inputs = dict[str, Any] +VariantFn = Callable[[Inputs], Any] + + +@dataclasses.dataclass +class Variant: + """A single implementation being compared. Provide `fwd` for forward-only + timing. Provide `fwd_bwd` for forward+backward timing; when both are set, + backward-only time is reported as `fwd_bwd - fwd`.""" + + name: str + fwd: VariantFn | None = None + fwd_bwd: VariantFn | None = None + # The fp32 reference variant. Exactly one per benchmark; its outputs are + # the ground truth for RMS-error comparison. + is_reference: bool = False + + +@dataclasses.dataclass +class Case: + """A single input configuration for the kernel under test. `make_inputs` + builds fresh input tensors on demand. It is called once per variant per + mode, after a global seed reset, so every variant sees identical inputs.""" + + name: str + make_inputs: Callable[[], Inputs] + # Minimum bytes read+written by the op. Used for GB/s + %BW. Optional. + expected_bytes: int | None = None + # Minimum floating-point ops performed by the op. Used for TFLOP/s + %FLOPs. Optional. + expected_flops: int | None = None + # For %FLOPs: which peak column to use (dtype of the hot inputs). 
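+    # E.g. torch.bfloat16 selects the GpuSpec bf16 peak; leave None for pure
+    # bandwidth ops (copy, fill) where a FLOP metric is not meaningful.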
+ compute_dtype: torch.dtype | None = None + + +def _seeded_inputs(case: Case, seed: int = 0) -> Inputs: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + return case.make_inputs() + + +@dataclasses.dataclass +class TimingStats: + median_ms: float + mean_ms: float + min_ms: float + max_ms: float + std_ms: float + n_reps: int + + +@dataclasses.dataclass +class MemoryStats: + peak_mib: float + final_mib: float + delta_peak_mib: float + + +@dataclasses.dataclass +class VariantResult: + variant_name: str + fwd_timing: TimingStats | None = None + fwd_bwd_timing: TimingStats | None = None + memory: MemoryStats | None = None + rms_errors: dict[str, float] | None = None # output name → RMS rel error vs reference + error: str | None = None # If the variant failed, the error message + + +# --------------------------------------------------------------------------- timing + + +def _make_cache_flusher(size_bytes: int = 256 * 1024 * 1024) -> Callable[[], None]: + """Allocate a scratch buffer larger than any GPU L2 and zero it between reps + to invalidate cached values (avoids over-optimistic timings).""" + if not torch.cuda.is_available(): + return lambda: None + buffer = torch.empty(size_bytes // 4, dtype=torch.int32, device="cuda") + + def flush() -> None: + buffer.zero_() + + return flush + + +def bench_fn( + fn: Callable[[], Any], + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, + max_reps: int = 10_000, +) -> TimingStats: + """Benchmark `fn` — it should be a no-arg callable that invokes the kernel + being timed (close over inputs). Returns timing statistics in ms. + + Mirrors `triton.testing.do_bench` logic but returns raw per-rep list so we + can compute {median, mean, min, max, std} from one set of runs. + """ + if not torch.cuda.is_available(): + # CPU / Triton interpret: single timed run with wall clock. + fn() # warmup + start = time.perf_counter() + fn() + elapsed_ms = (time.perf_counter() - start) * 1000 + return TimingStats(elapsed_ms, elapsed_ms, elapsed_ms, elapsed_ms, 0.0, 1) + + flush = _make_cache_flusher() + + # Warmup to JIT-compile, autotune, etc. + torch.cuda.synchronize() + warmup_start = torch.cuda.Event(enable_timing=True) + warmup_end = torch.cuda.Event(enable_timing=True) + warmup_start.record() + fn() + warmup_end.record() + torch.cuda.synchronize() + one_rep_ms = warmup_start.elapsed_time(warmup_end) + + # Additional warmup to stabilize (covers autotune misses on first call) + n_warmup = max(1, int(warmup_ms / max(one_rep_ms, 0.01))) + for _ in range(n_warmup): + fn() + torch.cuda.synchronize() + + # Re-estimate after warmup (autotune usually settles to a faster config). 
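+    # The post-warmup estimate sets the rep count below: rep_ms / one_rep_ms,
+    # clamped to [min_reps, max_reps].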
+ post_start = torch.cuda.Event(enable_timing=True) + post_end = torch.cuda.Event(enable_timing=True) + post_start.record() + fn() + post_end.record() + torch.cuda.synchronize() + one_rep_ms = max(post_start.elapsed_time(post_end), 0.001) + + n_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms))) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] + for i in range(n_reps): + flush() + start_events[i].record() + fn() + end_events[i].record() + torch.cuda.synchronize() + + times = [start_events[i].elapsed_time(end_events[i]) for i in range(n_reps)] + return TimingStats( + median_ms=statistics.median(times), + mean_ms=statistics.fmean(times), + min_ms=min(times), + max_ms=max(times), + std_ms=statistics.pstdev(times) if len(times) > 1 else 0.0, + n_reps=n_reps, + ) + + +# --------------------------------------------------------------------------- memory + + +def measure_memory(fn: Callable[[], Any]) -> MemoryStats: + """Run `fn` once and capture peak and final device memory. Must be called + on a fresh GPU state (the caller resets stats before constructing inputs).""" + if not torch.cuda.is_available(): + return MemoryStats(0.0, 0.0, 0.0) + torch.cuda.synchronize() + baseline = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + result = fn() + torch.cuda.synchronize() + peak = torch.cuda.max_memory_allocated() + final = torch.cuda.memory_allocated() + # Hold onto the result until after the measurement so it stays in `final`. + del result + return MemoryStats( + peak_mib=peak / 1024 / 1024, + final_mib=final / 1024 / 1024, + delta_peak_mib=(peak - baseline) / 1024 / 1024, + ) + + +# --------------------------------------------------------------------------- correctness + + +def rms_relative_error(candidate: torch.Tensor, reference: torch.Tensor) -> float: + """Root-mean-squared error of `candidate - reference`, normalized by the + RMS of `reference`. Both are cast to fp32 before comparison.""" + cand = candidate.detach().float() + ref = reference.detach().float() + diff_rms = (cand - ref).pow(2).mean().sqrt().item() + ref_rms = ref.pow(2).mean().sqrt().item() + return diff_rms / max(ref_rms, 1e-30) + + +def _as_output_dict(output: Any) -> dict[str, torch.Tensor]: + """Normalize a variant's output into a {name: tensor} dict for comparison.""" + if isinstance(output, torch.Tensor): + return {"out": output} + if isinstance(output, dict): + return {k: v for k, v in output.items() if isinstance(v, torch.Tensor)} + if isinstance(output, (tuple, list)): + return {f"out{i}": v for i, v in enumerate(output) if isinstance(v, torch.Tensor)} + raise TypeError(f"Cannot extract tensors from variant output of type {type(output).__name__}") + + +# --------------------------------------------------------------------------- runner + + +def _run_one_variant( + variant: Variant, + case: Case, + reference_outputs: dict[str, dict[str, torch.Tensor]] | None, + warmup_ms: float, + rep_ms: float, +) -> VariantResult: + result = VariantResult(variant_name=variant.name) + try: + # --- correctness + memory: one fresh invocation per mode + # fwd mode + if variant.fwd is not None: + inputs = _seeded_inputs(case) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _fwd_once() -> Any: + return variant.fwd(inputs) + + _guarded_fwd = _guarded(_fwd_once) + + # First: correctness. Run once, capture output for comparison. 
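+            # This first call also triggers any one-off compilation (Triton JIT,
+            # torch.compile tracing) before the timing runs below.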
+ fwd_output = _guarded_fwd() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if reference_outputs is not None and not variant.is_reference: + ref_fwd = reference_outputs.get("fwd") + if ref_fwd is not None: + cand = _as_output_dict(fwd_output) + rms = {name: rms_relative_error(cand[name], ref_fwd[name]) for name in ref_fwd if name in cand} + result.rms_errors = (result.rms_errors or {}) | {f"fwd.{k}": v for k, v in rms.items()} + del fwd_output + + # Timing: reuse the same input tensors, fn closes over them. + result.fwd_timing = bench_fn(_guarded_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms) + del inputs + + # fwd+bwd mode + if variant.fwd_bwd is not None: + inputs = _seeded_inputs(case) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _fwd_bwd_once() -> Any: + return variant.fwd_bwd(inputs) + + _guarded_fwd_bwd = _guarded(_fwd_bwd_once) + + fwd_bwd_output = _guarded_fwd_bwd() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if reference_outputs is not None and not variant.is_reference: + ref_fb = reference_outputs.get("fwd_bwd") + if ref_fb is not None: + cand = _as_output_dict(fwd_bwd_output) + rms = {name: rms_relative_error(cand[name], ref_fb[name]) for name in ref_fb if name in cand} + result.rms_errors = (result.rms_errors or {}) | {f"fb.{k}": v for k, v in rms.items()} + del fwd_bwd_output + + # Memory measurement: one fresh call on fresh inputs. + fresh_inputs = _seeded_inputs(case) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + result.memory = measure_memory(_guarded(lambda: variant.fwd_bwd(fresh_inputs))) + del fresh_inputs + + # Timing. + result.fwd_bwd_timing = bench_fn(_guarded_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms) + del inputs + elif variant.fwd is not None and result.memory is None: + # No backward — measure fwd-mode memory. + fresh_inputs = _seeded_inputs(case) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + result.memory = measure_memory(_guarded(lambda: variant.fwd(fresh_inputs))) + del fresh_inputs + except Exception as exc: # noqa: BLE001 + result.error = f"{type(exc).__name__}: {exc}" + return result + + +def _collect_reference_outputs( + variant: Variant, + case: Case, +) -> dict[str, dict[str, torch.Tensor]]: + out: dict[str, dict[str, torch.Tensor]] = {} + if variant.fwd is not None: + out["fwd"] = _as_output_dict(variant.fwd(_seeded_inputs(case))) + if variant.fwd_bwd is not None: + out["fwd_bwd"] = _as_output_dict(variant.fwd_bwd(_seeded_inputs(case))) + if torch.cuda.is_available(): + torch.cuda.synchronize() + # Detach+clone to guard against in-place mutation by later variants + return {mode: {k: v.detach().clone() for k, v in tensors.items()} for mode, tensors in out.items()} + + +# --------------------------------------------------------------------------- table + + +def _column_decimals(values: list[float | None]) -> int: + """Number of decimal places to give the smallest non-zero value in a column + at least 4 significant digits. Capped at 6 so one tiny value doesn't bloat + the whole column (e.g. 0.00001 alongside 100 would otherwise force 8 decimals).""" + nonzero = [abs(v) for v in values if v is not None and v != 0] + if not nonzero: + return 0 + min_magnitude = math.floor(math.log10(min(nonzero))) + return min(6, max(0, 3 - min_magnitude)) + + +def _format_aligned(values: list[float | None]) -> list[str]: + """Format a column with the same number of decimals for every entry, so + decimal points line up. 
Zeros get the trailing zeros too (e.g. '0.0000').""" + decimals = _column_decimals(values) + out: list[str] = [] + for value in values: + if value is None: + out.append("—") + else: + out.append(f"{value:.{decimals}f}") + return out + + +def _pick_unit(max_value: float, table: list[tuple[str, float]]) -> tuple[str, float]: + """Given a magnitude-ordered list of (unit_label, scale_to_unit) pairs and + the column's max absolute value, return the unit where max_value*scale is + in [1, 1000) when possible. `table` must be ordered by ascending magnitude + (largest unit last).""" + chosen_label, chosen_scale = table[0] + for label, scale in table: + if max_value * scale >= 1: + chosen_label, chosen_scale = label, scale + else: + break + return chosen_label, chosen_scale + + +# Each table is ordered ascending (small unit → large unit). `scale` converts +# from the canonical storage unit (ms / bytes-per-second / flops-per-second / +# MiB) into the display unit. +_TIME_UNITS = [("ns", 1e6), ("us", 1e3), ("ms", 1.0), ("s", 1e-3)] +_BANDWIDTH_UNITS = [("B/s", 1.0), ("KB/s", 1e-3), ("MB/s", 1e-6), ("GB/s", 1e-9), ("TB/s", 1e-12)] +_FLOPS_UNITS = [ + ("FLOP/s", 1.0), + ("KFLOP/s", 1e-3), + ("MFLOP/s", 1e-6), + ("GFLOP/s", 1e-9), + ("TFLOP/s", 1e-12), + ("PFLOP/s", 1e-15), +] +_MEMORY_UNITS = [("KiB", 1024.0), ("MiB", 1.0), ("GiB", 1 / 1024), ("TiB", 1 / 1024 / 1024)] + + +def _unit_column( + prefix: str, canonical_values: list[float | None], units: list[tuple[str, float]] +) -> tuple[str, list[str]]: + """Pick the best display unit for a column's magnitude and format with + aligned decimals. Header is ' '.""" + non_none = [abs(v) for v in canonical_values if v is not None] + max_value = max(non_none, default=0.0) + if max_value > 0: + label, scale = _pick_unit(max_value, units) + else: + # All values are zero / None. Fall back to the canonical unit (scale=1.0) + # so e.g. memory defaults to MiB rather than the middle of the table. + label, scale = next(((l, s) for (l, s) in units if s == 1.0), units[0]) + scaled = [v * scale if v is not None else None for v in canonical_values] + header = f"{prefix} {label}" if prefix else label + return header, _format_aligned(scaled) + + +def _percent_column(values: list[float | None]) -> list[str]: + """Format a column of ratios as aligned percentages.""" + scaled = [v * 100 if v is not None else None for v in values] + formatted = _format_aligned(scaled) + return [f if f == "—" else f"{f}%" for f in formatted] + + +def _rms_column(values: list[float | None]) -> list[str]: + """Align RMS errors in scientific notation with a shared exponent-free width.""" + decimals = 3 # 4 sig figs + out: list[str] = [] + for value in values: + if value is None: + out.append("—") + elif value == 0.0: + out.append(f"{0.0:.{decimals}e}") + else: + out.append(f"{value:.{decimals}e}") + return out + + +def _simplify_rms_key(key: str, all_keys: list[str]) -> str: + """Turn internal keys like 'fwd.out' / 'fb.loss' into concise display labels. 
+ + Rules: + - strip the mode prefix ('fwd.'/'fb.') when all keys share the same mode + - rename 'fb' → 'bwd' for display when it survives + - drop the trailing '.out' / standalone 'out' (the placeholder key used + when a variant returns a single unnamed tensor) + """ + mode, _, tensor = key.partition(".") + all_modes = {k.partition(".")[0] for k in all_keys} + if len(all_modes) <= 1: + remainder = tensor + else: + pretty = "bwd" if mode == "fb" else mode + remainder = f"{pretty}.{tensor}" if tensor else pretty + if remainder == "out": + return "" + return remainder.removesuffix(".out") + + +def _rms_header(key: str, all_keys: list[str]) -> str: + simplified = _simplify_rms_key(key, all_keys) + return f"rel_rms({simplified})" if simplified else "rel_rms" + + +def _render_table( + case: Case, + results: list[VariantResult], + gpu_spec: GpuSpec | None, + has_fwd: bool, + has_fwd_bwd: bool, + rms_keys: list[str], + verbose: bool, +) -> str: + # First column header carries the case name so the per-case label and the + # variant-name column are merged into one (avoids a redundant title line). + columns: list[tuple[str, list[str]]] = [(case.name, [r.variant_name for r in results])] + + def _add(header: str, values: list[str]) -> None: + columns.append((header, values)) + + if has_fwd: + _add(*_unit_column("fwd", [r.fwd_timing.median_ms if r.fwd_timing else None for r in results], _TIME_UNITS)) + if verbose: + _add( + *_unit_column( + "fwd mean", [r.fwd_timing.mean_ms if r.fwd_timing else None for r in results], _TIME_UNITS + ) + ) + _add( + *_unit_column("fwd min", [r.fwd_timing.min_ms if r.fwd_timing else None for r in results], _TIME_UNITS) + ) + _add( + *_unit_column("fwd max", [r.fwd_timing.max_ms if r.fwd_timing else None for r in results], _TIME_UNITS) + ) + + if has_fwd_bwd: + # Backward-only derived: fwd+bwd − fwd. 
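+ # (Derived, not measured directly: median(fwd+bwd) minus median(fwd) is an
+ # approximation of the backward time, and can come out slightly negative for
+ # very fast kernels when run-to-run noise exceeds the true backward cost.)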
+ bwd_values: list[float | None] = [] + total_values: list[float | None] = [] + for r in results: + if r.fwd_bwd_timing is None: + bwd_values.append(None) + total_values.append(None) + continue + total = r.fwd_bwd_timing.median_ms + bwd_values.append(total - r.fwd_timing.median_ms if r.fwd_timing else None) + total_values.append(total) + _add(*_unit_column("bwd", bwd_values, _TIME_UNITS)) + _add(*_unit_column("total", total_values, _TIME_UNITS)) + + def _time_for_throughput(r: VariantResult) -> float | None: + if r.fwd_bwd_timing is not None: + return r.fwd_bwd_timing.median_ms + if r.fwd_timing is not None: + return r.fwd_timing.median_ms + return None + + if case.expected_bytes is not None: + bandwidths: list[float | None] = [] + for r in results: + t_ms = _time_for_throughput(r) + bandwidths.append(case.expected_bytes / (t_ms / 1000) if t_ms is not None else None) + header, values = _unit_column("", bandwidths, _BANDWIDTH_UNITS) + _add(header, values) + if gpu_spec is not None: + peak_bytes_per_s = gpu_spec.peak_bandwidth_gbps * 1e9 + pct = [bw / peak_bytes_per_s if bw is not None else None for bw in bandwidths] + _add("%BW", _percent_column(pct)) + + if case.expected_flops is not None: + flop_rates: list[float | None] = [] + for r in results: + t_ms = _time_for_throughput(r) + flop_rates.append(case.expected_flops / (t_ms / 1000) if t_ms is not None else None) + header, values = _unit_column("", flop_rates, _FLOPS_UNITS) + _add(header, values) + peak_tflops = gpu_spec.peak_tflops(case.compute_dtype) if gpu_spec and case.compute_dtype else None + if peak_tflops is not None: + peak_flops_per_s = peak_tflops * 1e12 + pct = [fr / peak_flops_per_s if fr is not None else None for fr in flop_rates] + _add("%FLOPs", _percent_column(pct)) + + peak_mib = [r.memory.peak_mib if r.memory else None for r in results] + delta_mib = [r.memory.delta_peak_mib if r.memory else None for r in results] + _add(*_unit_column("peak", peak_mib, _MEMORY_UNITS)) + _add(*_unit_column("Δpeak", delta_mib, _MEMORY_UNITS)) + + for key in rms_keys: + _add(_rms_header(key, rms_keys), _rms_column([(r.rms_errors or {}).get(key) for r in results])) + + _add("error", [r.error or "" for r in results]) + # Drop the error column if nothing failed + if not any(r.error for r in results): + columns.pop() + + widths = [max(len(header), *(len(v) for v in values)) for header, values in columns] + sep = " " + + # First column (case name + variant names) is text — left-justify. All other + # columns are numeric — right-justify so decimal points line up across rows. + def _justify(text: str, width: int, column_index: int) -> str: + return text.ljust(width) if column_index == 0 else text.rjust(width) + + header_line = sep.join(_justify(h, w, i) for i, ((h, _), w) in enumerate(zip(columns, widths))) + divider = sep.join("-" * w for w in widths) + body_lines = [] + for row in range(len(results)): + body_lines.append( + sep.join(_justify(values[row], w, i) for i, ((_, values), w) in enumerate(zip(columns, widths))) + ) + return "\n".join([header_line, divider, *body_lines]) + + +# --------------------------------------------------------------------------- orchestration + + +def run_benchmark( + benchmark_name: str, + cases: list[Case], + variants: list[Variant], + *, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + verbose: bool = False, + print_fn: Callable[[str], None] = print, +) -> list[tuple[Case, list[VariantResult]]]: + """Run all (case, variant) combinations and print one table per case. 
+ + Exactly one variant should have `is_reference=True` — its outputs are the + ground truth for RMS-error comparisons. That variant should compute in + fp32, eager, using the most straightforward reference implementation.""" + reference = [v for v in variants if v.is_reference] + if len(reference) != 1: + raise ValueError( + f"Expected exactly one reference variant (is_reference=True), got {len(reference)}. " + f"Variants: {[v.name for v in variants]}" + ) + gpu_spec = detect_gpu_spec() + print_fn(header(benchmark_name)) + if gpu_spec is not None: + print_fn(f"gpu: {gpu_spec.name} (peak BW {gpu_spec.peak_bandwidth_gbps:.0f} GB/s)") + else: + print_fn("gpu: unknown (no %-of-peak columns)") + print_fn("") + + all_results: list[tuple[Case, list[VariantResult]]] = [] + for case in cases: + ref_outputs = _collect_reference_outputs(reference[0], case) + + results = [] + has_fwd = False + has_fwd_bwd = False + rms_keys_seen: list[str] = [] + for variant in variants: + r = _run_one_variant(variant, case, ref_outputs, warmup_ms=warmup_ms, rep_ms=rep_ms) + results.append(r) + has_fwd = has_fwd or r.fwd_timing is not None + has_fwd_bwd = has_fwd_bwd or r.fwd_bwd_timing is not None + for k in r.rms_errors or {}: + if k not in rms_keys_seen: + rms_keys_seen.append(k) + + print_fn( + _render_table( + case, + results, + gpu_spec=gpu_spec, + has_fwd=has_fwd, + has_fwd_bwd=has_fwd_bwd, + rms_keys=rms_keys_seen, + verbose=verbose, + ) + ) + print_fn("") + all_results.append((case, results)) + return all_results diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py new file mode 100644 index 000000000..1a5ab417c --- /dev/null +++ b/tools/benchmark/utils.py @@ -0,0 +1,85 @@ +""" +Convenience helpers for writing kernel benchmark files. Reduces the boilerplate +of building cases and variants so each `bench_*.py` can stay focused on +kernel-specific logic (input construction, expected_bytes/flops, special variants). +""" + +from collections.abc import Callable + +import torch + +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.functional.config import TritonConfig +from tools.benchmark.runner import Inputs, Variant + +# --------------------------------------------------------------------------- formatting + + +def format_size(n: int) -> str: + """Format an int with the largest binary prefix that divides it exactly: 1048576 → '1 Mi'.""" + for unit, factor in (("Gi", 1 << 30), ("Mi", 1 << 20), ("Ki", 1 << 10)): + if n >= factor and n % factor == 0: + return f"{n // factor} {unit}" + return str(n) + + +def format_shape(shape: tuple[int, ...]) -> str: + """Format a shape tuple with human-readable sizes per dim: (16777216,) → '(16 Mi,)'.""" + joined = ", ".join(format_size(n) for n in shape) + return f"({joined},)" if len(shape) == 1 else f"({joined})" + + +def case_name(kernel: str, shape: tuple[int, ...], dtype: torch.dtype) -> str: + """Build the standard case header: `[copy] (16 Mi,) bf16`.""" + return f"[{kernel}] {format_shape(shape)} {DataType.from_torch(dtype).short}" + + +def device() -> str: + """The device benchmarks should target. 
Falls back to CPU when CUDA is missing + so non-Triton variants can still run for local smoke testing.""" + return "cuda" if torch.cuda.is_available() else "cpu" + + +# --------------------------------------------------------------------------- variant builders + + +def standard_fwd_variants( + eager_fn: Callable, + triton_fn: Callable | None, + unpack: Callable[[Inputs], tuple], +) -> list[Variant]: + """Build the canonical 5-variant set for a forward-only kernel. + + Generates: fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, + and (if `TritonConfig.enabled()`) fast_llm_triton. + + `eager_fn` is the plain PyTorch implementation taking positional tensor args. + `triton_fn` is the Fast-LLM Triton wrapper; pass `None` if the kernel has no + Triton variant. Both are invoked with `unpack(inputs)` unpacked positionally; + `triton_fn` is called with an extra `use_triton=True` kwarg. + + The fp32 reference upcasts every floating-point tensor in the unpacked + arguments to fp32 (non-tensor / non-float arguments are passed through). + """ + + def _fp32_unpack(inputs: Inputs) -> tuple: + return tuple( + arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg for arg in unpack(inputs) + ) + + compiled_default = torch.compile(eager_fn, mode="default", dynamic=False) + compiled_max = torch.compile(eager_fn, mode="max-autotune-no-cudagraphs", dynamic=False) + + variants = [ + Variant( + name="fp32_reference", + fwd=lambda inp: eager_fn(*_fp32_unpack(inp)), + is_reference=True, + ), + Variant(name="pytorch_eager", fwd=lambda inp: eager_fn(*unpack(inp))), + Variant(name="pytorch_compiled", fwd=lambda inp: compiled_default(*unpack(inp))), + Variant(name="pytorch_compiled_max", fwd=lambda inp: compiled_max(*unpack(inp))), + ] + if triton_fn is not None and TritonConfig.enabled(): + variants.append(Variant(name="fast_llm_triton", fwd=lambda inp: triton_fn(*unpack(inp), use_triton=True))) + return variants From 13678d520d315edf0b08f5d67b2c72c0f50c87c2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 02:58:27 -0400 Subject: [PATCH 02/41] Add sparse_linear benchmark and fix sparse matmul kernel correctness Adds tools/benchmark/bench_sparse_linear.py covering the two sparse GEMM kernels in MoE FFN layers (output_sparse / up-proj and input_inner_sparse / down-proj), comparing fast_llm_triton against a PyTorch loop reference and torch.compile. Fixes three kernel correctness bugs surfaced by the new benchmark: 1. Phantom blocks left output uninitialized. Both output_sparse_matmul_kernel and input_inner_sparse_matmul_kernel short-circuited (`return`) on blocks past the last expert, leaving those rows of the caller-allocated output buffer with whatever garbage happened to be at that GPU address. Production silently discarded the garbage at the scatter-back boundary, but it is a latent footgun for any caller that reads the full output tensor (and what made the benchmark comparison nondeterministic). The skipped blocks now write zeros instead of returning, so the output is fully defined regardless of how the caller allocated it. 2. Inner-loop tl.dot accumulated in bfloat16. The first tl.dot in each of the three kernels already specified `out_dtype=tl.float32`, but the loop bodies did not, so accumulation past the first tile silently fell back to bf16 Tensor Core accumulation. Added `out_dtype=tl.float32` to every tl.dot call. 3. Backward used wrong argument order for input_row_sparse_matmul. 
The grad_rhs path in OutputSparseLinear and InputSparseLinear had lhs/grad_out swapped relative to what the kernel expects. After these fixes every rel_rms in the sparse_linear benchmark goes to 0 across all six (op, shape) configurations. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/sparse_linear.py | 40 ++- tools/benchmark/bench_sparse_linear.py | 304 ++++++++++++++++++++ 2 files changed, 331 insertions(+), 13 deletions(-) create mode 100644 tools/benchmark/bench_sparse_linear.py diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index 15af789d7..601ae0fa5 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -231,24 +231,30 @@ def output_sparse_matmul_kernel( sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa - if sparse_index == sparse_dim: - return - col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim # Pointers row_range = tl_arange(0, block_size_row)[:, None] col_range = tl_arange(0, block_size_col)[None, :] + out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col + + if sparse_index == sparse_dim: + # Phantom block: row_offset is past the last expert. Write zeros so the + # output is fully defined regardless of the caller's allocation. + if not accumulate: + tl.store(out_ptr, tl.zeros((block_size_row, block_size_col), dtype=out_ptr.dtype.element_ty)) + return + col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + (col_dense_offset + col_range) * rhs_stride_col - out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col # Matrix multiplication out = tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) for k in range(1, inner_dim // block_size_inner): lhs_ptr += block_size_inner * lhs_stride_inner rhs_ptr += block_size_inner * rhs_stride_inner - out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr)) + out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) if accumulate: out += tl.load(out_ptr) @@ -350,31 +356,37 @@ def input_inner_sparse_matmul_kernel( # Grid offsets row_offset = pid_row * block_size_row + col_offset = pid_col * block_size_col sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa - if sparse_index == sparse_dim: - return - inner_dense_offset = sparse_index * inner_sparse_dim - col_offset = pid_col * block_size_col # Pointers row_range = tl_arange(0, block_size_row)[:, None] col_range = tl_arange(0, block_size_col)[None, :] + out_ptr += (row_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col + + if sparse_index == sparse_dim: + # Phantom block: row_offset is past the last expert. Write zeros so the + # output is fully defined regardless of the caller's allocation. 
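+ # (When accumulating, the buffer already holds a valid partial result for
+ # these rows, so leaving it untouched is the correct "add zero".)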
+ if not accumulate: + tl.store(out_ptr, tl.zeros((block_size_row, block_size_col), dtype=out_ptr.dtype.element_ty)) + return + inner_dense_offset = sparse_index * inner_sparse_dim + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += (inner_dense_offset + inner_range[:, None]) * rhs_stride_inner + ( col_offset + col_range ) * rhs_stride_col - out_ptr += (row_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col # Matrix multiplication out = tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) for k in range(1, inner_sparse_dim // block_size_inner): lhs_ptr += block_size_inner * lhs_stride_inner rhs_ptr += block_size_inner * rhs_stride_inner - out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr)) + out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) if accumulate: out += tl.load(out_ptr) @@ -497,6 +509,7 @@ def input_row_sparse_matmul_kernel( out = tl.dot( tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0), tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0), + out_dtype=tl.float32, ) for i in range(1, tl.cdiv(inner_end - inner_offset, block_size_inner)): inner_range += block_size_inner @@ -504,6 +517,7 @@ def input_row_sparse_matmul_kernel( out += tl.dot( tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0), tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0), + out_dtype=tl.float32, ) if accumulate: @@ -578,7 +592,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N grad_out = grad_out.contiguous() lhs, rhs = ctx.saved_tensors grad_lhs = input_inner_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map) - grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map).t() + grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map).t() return grad_lhs, grad_rhs, None @@ -597,7 +611,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N grad_out = grad_out.contiguous() lhs, rhs = ctx.saved_tensors grad_lhs = output_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map) - grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map) + grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map) return grad_lhs, grad_rhs, None diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py new file mode 100644 index 000000000..82f0afd1d --- /dev/null +++ b/tools/benchmark/bench_sparse_linear.py @@ -0,0 +1,304 @@ +""" +Benchmark MoE sparse grouped GEMM kernels. + +Two operations are benchmarked, corresponding to the two linear layers in a MoE FFN: + +output_sparse (layer 1 / up-proj): + out[i, :] = lhs[i, :] @ rhs[:, expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert] + lhs: (sparse_tokens, hidden), rhs: (hidden, ffn_per_expert × num_experts) + Each token's output columns come from its assigned expert's slice of rhs. + OutputSparseLinear.apply handles fwd+bwd. + +input_inner_sparse (layer 2 / down-proj): + out[i, :] = lhs[i, :] @ rhs[expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert, :] + lhs: (sparse_tokens, ffn_per_expert), rhs: (ffn_per_expert × num_experts, hidden) + Each token's inner dimension comes from its assigned expert's slice of rhs. + InputSparseLinear.apply handles fwd+bwd. 
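+
+Illustrative example (assumed routing, not one of the configs below): with two
+experts and ffn_per_expert = F, if expert 0 owns sparse rows [0, 2) and expert 1
+owns rows [2, 4), then output_sparse computes
+    out[0:2] = lhs[0:2] @ rhs[:, 0:F]      out[2:4] = lhs[2:4] @ rhs[:, F:2F]
+and input_inner_sparse computes
+    out[0:2] = lhs[0:2] @ rhs[0:F]         out[2:4] = lhs[2:4] @ rhs[F:2F]
+which is what the pytorch_loop references below do, one expert at a time.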
+ +Comparisons: +- pytorch_loop: loop over experts with torch.mm per expert (the obvious PyTorch approach) +- pytorch_compiled: torch.compile of the loop +- fast_llm_triton: OutputSparseLinear / InputSparseLinear autograd functions + +Shapes: (tokens, top_k, num_experts, hidden, ffn_per_expert) matching MoE FFN configs. +""" + +import torch + +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map +from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# (tokens, top_k, num_experts, hidden, ffn_per_expert) +_SHAPES = [ + (4096, 2, 8, 4096, 14336), # Mixtral-8x7B: 8 experts, ffn=14336 + (4096, 2, 64, 4096, 1792), # fine-grained MoE: 64 experts, same total capacity + (4096, 2, 8, 8192, 28672), # large hidden / wide FFN +] +_DEFAULT_DTYPES = (torch.bfloat16,) + + +def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: + top_experts = torch.randint(0, num_experts, (tokens, top_k), device=device()) + return get_sparse_map(top_experts, num_experts) + + +def _zero_padded_rows(tensor: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + for e in range(sparse_map.num_experts): + pad_start = int(sparse_map.expert_pad_begins[e]) + pad_end = int(sparse_map.expert_ends[e]) + if pad_end > pad_start: + tensor[pad_start:pad_end] = 0 + return tensor + + +def _make_output_sparse_inputs( + tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype +) -> dict: + sparse_map = _make_sparse_map(tokens, top_k, num_experts) + lhs_data = _zero_padded_rows(torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map) + rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) + # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. + if TritonConfig.enabled(): + _w_lhs = lhs_data.detach().requires_grad_(True) + _w_rhs = rhs_data.detach().requires_grad_(True) + _w_out = OutputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) + _w_out.backward(torch.ones_like(_w_out)) + del _w_lhs, _w_rhs, _w_out + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": sparse_map, + "ffn_per_expert": ffn_per_expert, + } + + +def _make_input_inner_sparse_inputs( + tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype +) -> dict: + sparse_map = _make_sparse_map(tokens, top_k, num_experts) + lhs_data = _zero_padded_rows( + torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map + ) + rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) + # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. 
+ if TritonConfig.enabled(): + _w_lhs = lhs_data.detach().requires_grad_(True) + _w_rhs = rhs_data.detach().requires_grad_(True) + _w_out = InputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) + _w_out.backward(torch.ones_like(_w_out)) + del _w_lhs, _w_rhs, _w_out + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": sparse_map, + "ffn_per_expert": ffn_per_expert, + } + + +# --------------------------------------------------------------------------- output_sparse references + + +def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + ffn_per_expert = rhs.shape[1] // sparse_map.num_experts + out = lhs.new_zeros(sparse_map.num_rows, ffn_per_expert) + for e in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[e - 1]) if e > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[e]) + if row_end > row_begin: + col_begin = e * ffn_per_expert + out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[:, col_begin : col_begin + ffn_per_expert] + return out + + +_output_sparse_compiled = torch.compile(_output_sparse_loop, mode="default", dynamic=False) + + +def _run_output_sparse_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["lhs"], inp["rhs"], inp["sparse_map"])} + + +def _run_output_sparse_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) + output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} + + +def _run_output_sparse_fwd_triton(inp: dict) -> dict: + return {"output": OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} + + +def _run_output_sparse_fwd_bwd_triton(inp: dict) -> dict: + output = OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) + output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} + + +def _output_sparse_variants() -> list[Variant]: + variants = [ + Variant( + name="pytorch_loop", + fwd=lambda inp: _run_output_sparse_fwd(inp, _output_sparse_loop), + fwd_bwd=lambda inp: _run_output_sparse_fwd_bwd(inp, _output_sparse_loop), + is_reference=True, + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_output_sparse_fwd(inp, _output_sparse_compiled), + fwd_bwd=lambda inp: _run_output_sparse_fwd_bwd(inp, _output_sparse_compiled), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_output_sparse_fwd_triton, + fwd_bwd=_run_output_sparse_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- input_inner_sparse references + + +def _input_inner_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + ffn_per_expert = rhs.shape[0] // sparse_map.num_experts + out = lhs.new_zeros(sparse_map.num_rows, rhs.shape[1]) + for e in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[e - 1]) if e > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[e]) + if row_end > row_begin: + inner_begin = e * ffn_per_expert + out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[inner_begin : inner_begin + ffn_per_expert] + return out + + +_input_inner_sparse_compiled = torch.compile(_input_inner_sparse_loop, mode="default", dynamic=False) + + +def _run_input_inner_sparse_fwd(inp: dict, fn) 
-> dict: + return {"output": fn(inp["lhs"], inp["rhs"], inp["sparse_map"])} + + +def _run_input_inner_sparse_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) + output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} + + +def _run_input_inner_sparse_fwd_triton(inp: dict) -> dict: + return {"output": InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} + + +def _run_input_inner_sparse_fwd_bwd_triton(inp: dict) -> dict: + output = InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) + output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} + + +def _input_inner_sparse_variants() -> list[Variant]: + variants = [ + Variant( + name="pytorch_loop", + fwd=lambda inp: _run_input_inner_sparse_fwd(inp, _input_inner_sparse_loop), + fwd_bwd=lambda inp: _run_input_inner_sparse_fwd_bwd(inp, _input_inner_sparse_loop), + is_reference=True, + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_input_inner_sparse_fwd(inp, _input_inner_sparse_compiled), + fwd_bwd=lambda inp: _run_input_inner_sparse_fwd_bwd(inp, _input_inner_sparse_compiled), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_input_inner_sparse_fwd_triton, + fwd_bwd=_run_input_inner_sparse_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- cases / bytes / flops + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + +def _sparse_linear_bytes( + sparse_tokens: int, hidden: int, ffn_per_expert: int, num_experts: int, dtype: torch.dtype +) -> int: + elem = _bytes_per_elem(dtype) + # fwd: read lhs + read rhs_full + write output + # bwd: read grad_output + read rhs_full + write grad_lhs + read lhs + read grad_output + write grad_rhs + # Simplification: 3× lhs traffic + 3× rhs traffic + 2× output traffic + lhs_bytes = sparse_tokens * hidden * elem + rhs_bytes = hidden * ffn_per_expert * num_experts * elem + out_bytes = sparse_tokens * ffn_per_expert * elem + return 3 * lhs_bytes + 3 * rhs_bytes + 2 * out_bytes + + +def _sparse_linear_flops(sparse_tokens_unpadded: int, hidden: int, ffn_per_expert: int) -> int: + # fwd + bwd ≈ 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad) + return 3 * 2 * sparse_tokens_unpadded * hidden * ffn_per_expert + + +def _output_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("output_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), + make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, f=ffn_per_expert, d=dtype: ( + _make_output_sparse_inputs(t, k, n, h, f, d) + ), + expected_bytes=_sparse_linear_bytes(tokens * top_k, hidden, ffn_per_expert, num_experts, dtype), + expected_flops=_sparse_linear_flops(tokens * top_k, hidden, ffn_per_expert), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, top_k, num_experts, hidden, ffn_per_expert in _SHAPES + ] + + +def _input_inner_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("input_inner_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), + make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, 
f=ffn_per_expert, d=dtype: ( + _make_input_inner_sparse_inputs(t, k, n, h, f, d) + ), + expected_bytes=_sparse_linear_bytes(tokens * top_k, ffn_per_expert, hidden, num_experts, dtype), + expected_flops=_sparse_linear_flops(tokens * top_k, ffn_per_expert, hidden), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, top_k, num_experts, hidden, ffn_per_expert in _SHAPES + ] + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark( + "sparse_linear: output_sparse (layer 1 / up-proj)", + _output_sparse_cases(dtypes), + _output_sparse_variants(), + verbose=verbose, + ) + run_benchmark( + "sparse_linear: input_inner_sparse (layer 2 / down-proj)", + _input_inner_sparse_cases(dtypes), + _input_inner_sparse_variants(), + verbose=verbose, + ) + + +if __name__ == "__main__": + run() From 3a84431e80c2cbeef00d0ac851a113f1a6127fcc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 03:07:30 -0400 Subject: [PATCH 03/41] Add GRPO loss and sparse copy benchmarks - bench_grpo_loss.py: fused triton_grpo_loss_forward_backward kernel vs PyTorch eager / compiled, swept over vocab=32K/64K/128K (matches bench_entropy_loss shapes since GRPO is structurally similar). - bench_sparse_copy.py: MoE token dispatch (dense->sparse) and combine (sparse->dense) via copy_dense_to_sparse_autograd / copy_sparse_to_dense_autograd, against PyTorch index-scatter/gather references. Shapes match Mixtral-8x7B and fine-grained MoE. Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_grpo_loss.py | 192 +++++++++++++++++++ tools/benchmark/bench_sparse_copy.py | 268 +++++++++++++++++++++++++++ 2 files changed, 460 insertions(+) create mode 100644 tools/benchmark/bench_grpo_loss.py create mode 100644 tools/benchmark/bench_sparse_copy.py diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py new file mode 100644 index 000000000..a60e1751c --- /dev/null +++ b/tools/benchmark/bench_grpo_loss.py @@ -0,0 +1,192 @@ +""" +Benchmark the fused GRPO loss kernel. + +GRPO (Group Relative Policy Optimization) loss computes a clipped importance-weighted +policy gradient per token: loss = -min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv), +where ratio = exp(log_prob_new - log_prob_old). + +The Triton kernel fuses softmax, log-prob extraction, ratio computation, clipping, and +the backward gradient into a single pass over logits — same structure as the cross_entropy +kernel. + +Comparisons: +- fp32_reference: PyTorch GRPO in fp32 +- pytorch_eager: PyTorch GRPO in compute dtype +- pytorch_compiled / pytorch_compiled_max: torch.compile of the above +- fast_llm_triton: triton_grpo_loss_forward_backward + +Shapes match bench_entropy_loss: tokens=4096, vocab swept over 32K/64K/128K. 
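+
+Worked example (illustrative numbers, epsilon_low = epsilon_high = 0.2): with
+advantage 1.0 and ratio = exp(new - old) = 1.5, the clipped ratio is 1.2 and the
+per-token loss is -min(1.5 * 1.0, 1.2 * 1.0) = -1.2, so the clip caps how much a
+single token can reward a large policy shift. With advantage -1.0 the min picks
+the unclipped term (-1.5 vs -1.2), so unfavorable updates are never masked by the clip.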
+""" + +import torch + +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +_SHAPES = [ + (4096, 32768), + (4096, 65536), + (4096, 131072), +] +_DEFAULT_DTYPES = (torch.bfloat16,) +_EPSILON_LOW = 0.2 +_EPSILON_HIGH = 0.2 + + +def _make_grpo_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: + return { + "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), + "labels": torch.randint(0, vocab, (tokens,), dtype=torch.long, device=device()), + "advantages": torch.randn(tokens, dtype=torch.float32, device=device()), + "old_log_probs": torch.randn(tokens, dtype=torch.float32, device=device()) - 5.0, + } + + +def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor): + log_probs = logits.float().log_softmax(-1) + new_log_probs = log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1) + new_log_probs = torch.where(labels >= 0, new_log_probs, torch.zeros_like(new_log_probs)) + ratio = (new_log_probs - old_log_probs).exp() + clipped_ratio = ratio.clamp(1.0 - _EPSILON_LOW, 1.0 + _EPSILON_HIGH) + per_token_loss = torch.where( + labels >= 0, + -torch.minimum(ratio * advantages, clipped_ratio * advantages), + torch.zeros_like(ratio), + ) + return per_token_loss.mean() + + +_grpo_compiled_default = torch.compile(_grpo_eager, mode="default", dynamic=False) +_grpo_compiled_max = torch.compile(_grpo_eager, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_fwd(inp: dict, fn) -> dict: + return {"loss": fn(inp["logits"], inp["labels"], inp["advantages"], inp["old_log_probs"])} + + +def _run_fwd_fp32(inp: dict) -> dict: + return { + "loss": _grpo_eager( + inp["logits"].float().detach().requires_grad_(), + inp["labels"], + inp["advantages"], + inp["old_log_probs"], + ) + } + + +def _run_fwd_bwd(inp: dict, fn) -> dict: + loss = fn(inp["logits"], inp["labels"], inp["advantages"], inp["old_log_probs"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + + +def _run_fwd_bwd_fp32(inp: dict) -> dict: + logits_fp32 = inp["logits"].float().detach().requires_grad_() + loss = _grpo_eager(logits_fp32, inp["labels"], inp["advantages"], inp["old_log_probs"]) + loss.backward() + return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} + + +def _run_fwd_triton(inp: dict) -> dict: + loss, _, _ = triton_grpo_loss_forward_backward( + inp["logits"], + inp["labels"], + inp["advantages"], + inp["old_log_probs"], + grad_output=None, + epsilon_low=_EPSILON_LOW, + epsilon_high=_EPSILON_HIGH, + ) + return {"loss": loss} + + +def _run_fwd_bwd_triton(inp: dict) -> dict: + loss, grad_logits, _ = triton_grpo_loss_forward_backward( + inp["logits"], + inp["labels"], + inp["advantages"], + inp["old_log_probs"], + grad_output=1.0, + epsilon_low=_EPSILON_LOW, + epsilon_high=_EPSILON_HIGH, + ) + return {"loss": loss, "grad_logits": grad_logits} + + +def _grpo_variants() -> list[Variant]: + variants = [ + Variant( + name="fp32_reference", + fwd=_run_fwd_fp32, + fwd_bwd=_run_fwd_bwd_fp32, + is_reference=True, + ), + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_fwd(inp, _grpo_eager), + fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_eager), + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_fwd(inp, _grpo_compiled_default), + fwd_bwd=lambda inp: 
_run_fwd_bwd(inp, _grpo_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_fwd(inp, _grpo_compiled_max), + fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_fwd_triton, + fwd_bwd=_run_fwd_bwd_triton, + ) + ) + return variants + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + +def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: + elem = _bytes_per_elem(dtype) + # fwd: read logits + bwd: read logits + write grad_logits + logit_traffic = 3 * tokens * vocab * elem + # labels (int64), advantages (fp32), old_log_probs (fp32) + scalar_traffic = tokens * (8 + 4 + 4) + return logit_traffic + scalar_traffic + + +def _grpo_flops(tokens: int, vocab: int) -> int: + # Similar to cross_entropy labels: softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element + return 14 * tokens * vocab + + +def _grpo_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("grpo_loss", (tokens, vocab), dtype), + make_inputs=lambda t=tokens, v=vocab, d=dtype: _make_grpo_inputs(t, v, d), + expected_bytes=_grpo_bytes(tokens, vocab, dtype), + expected_flops=_grpo_flops(tokens, vocab), + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, vocab in _SHAPES + ] + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark("grpo_loss", _grpo_cases(dtypes), _grpo_variants(), verbose=verbose) + + +if __name__ == "__main__": + run() diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py new file mode 100644 index 000000000..9a84adf9f --- /dev/null +++ b/tools/benchmark/bench_sparse_copy.py @@ -0,0 +1,268 @@ +""" +Benchmark MoE token dispatch and combine (sparse copy) kernels. + +Two operations are benchmarked separately: + +dispatch (dense → sparse): + Each token is copied to top_k expert slots in the sparse buffer. + copy_dense_to_sparse_autograd handles fwd+bwd (bwd = sparse-to-dense, no scores). + +combine (sparse → dense): + Expert outputs are gathered and weighted by routing scores back to token space. + copy_sparse_to_dense_autograd handles fwd+bwd (bwd = dense-to-sparse + score grad). + +Comparisons: +- pytorch_eager: index-based scatter/gather in compute dtype +- pytorch_compiled / pytorch_compiled_max: torch.compile of the above +- fast_llm_triton: copy_dense_to_sparse_autograd / copy_sparse_to_dense_autograd + +Shapes: (tokens, top_k, num_experts, hidden_size) matching Mixtral-8x7B and fine-grained MoE. +The SparseMap is pre-computed once per case (routing structure, not data). 
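+
+SparseMap fields relied on below (as this file uses them): num_rows (padded sparse
+rows), num_rows_dense (dense tokens), num_experts_per_token (top_k), sparse_rows
+(per-token indices of its top_k slots in the sparse buffer), and expert_pad_begins /
+expert_ends (per-expert ranges whose tail rows are padding).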
+""" + +import torch + +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.sparse_copy import ( + SparseMap, + copy_dense_to_sparse_autograd, + copy_sparse_to_dense_autograd, + get_sparse_map, +) +from tools.benchmark.runner import Case, Variant, run_benchmark +from tools.benchmark.utils import case_name, device + +# (tokens, top_k, num_experts, hidden_size) +_SHAPES = [ + (4096, 2, 8, 4096), # Mixtral-8x7B-like + (4096, 2, 64, 4096), # fine-grained MoE + (4096, 2, 8, 8192), # wide hidden +] +_DEFAULT_DTYPES = (torch.bfloat16,) + + +def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: + top_experts = torch.randint(0, num_experts, (tokens, top_k), device=device()) + return get_sparse_map(top_experts, num_experts) + + +def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: + sparse_map = _make_sparse_map(tokens, top_k, num_experts) + return { + "dense_input": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), + "sparse_map": sparse_map, + } + + +def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: + sparse_map = _make_sparse_map(tokens, top_k, num_experts) + return { + "sparse_input": torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device(), requires_grad=True), + "scores": torch.softmax(torch.randn(tokens, top_k, dtype=dtype, device=device()), dim=-1).requires_grad_(True), + "sparse_map": sparse_map, + } + + +# --------------------------------------------------------------------------- dispatch + + +def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + out = dense_input.new_zeros(sparse_map.num_rows, dense_input.shape[1]) + sparse_rows = sparse_map.sparse_rows.long() + for k in range(sparse_map.num_experts_per_token): + out[sparse_rows[:, k]] = dense_input + return out + + +_dispatch_compiled_default = torch.compile(_dispatch_pytorch, mode="default", dynamic=False) +_dispatch_compiled_max = torch.compile(_dispatch_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_dispatch_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["dense_input"], inp["sparse_map"])} + + +def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["dense_input"], inp["sparse_map"]) + output.backward(torch.ones_like(output)) + return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} + + +def _run_dispatch_fwd_triton(inp: dict) -> dict: + return {"output": copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"])} + + +def _run_dispatch_fwd_bwd_triton(inp: dict) -> dict: + output = copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]) + output.backward(torch.ones_like(output)) + return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} + + +def _dispatch_variants() -> list[Variant]: + variants = [ + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_pytorch), + fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_pytorch), + is_reference=True, + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_compiled_default), + fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_compiled_max), + fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_compiled_max), + ), + ] + if 
TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_dispatch_fwd_triton, + fwd_bwd=_run_dispatch_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- combine + + +def _combine_pytorch(sparse_input: torch.Tensor, scores: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + out = sparse_input.new_zeros(sparse_map.num_rows_dense, sparse_input.shape[1]) + sparse_rows = sparse_map.sparse_rows.long() + for k in range(sparse_map.num_experts_per_token): + out = out + sparse_input[sparse_rows[:, k]] * scores[:, k : k + 1] + return out + + +_combine_compiled_default = torch.compile(_combine_pytorch, mode="default", dynamic=False) +_combine_compiled_max = torch.compile(_combine_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) + + +def _run_combine_fwd(inp: dict, fn) -> dict: + return {"output": fn(inp["sparse_input"], inp["scores"], inp["sparse_map"])} + + +def _run_combine_fwd_bwd(inp: dict, fn) -> dict: + output = fn(inp["sparse_input"], inp["scores"], inp["sparse_map"]) + output.backward(torch.ones_like(output)) + return { + "output": output.detach(), + "grad_sparse": inp["sparse_input"].grad, + "grad_scores": inp["scores"].grad, + } + + +def _run_combine_fwd_triton(inp: dict) -> dict: + return {"output": copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"])} + + +def _run_combine_fwd_bwd_triton(inp: dict) -> dict: + output = copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"]) + output.backward(torch.ones_like(output)) + return { + "output": output.detach(), + "grad_sparse": inp["sparse_input"].grad, + "grad_scores": inp["scores"].grad, + } + + +def _combine_variants() -> list[Variant]: + variants = [ + Variant( + name="pytorch_eager", + fwd=lambda inp: _run_combine_fwd(inp, _combine_pytorch), + fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_pytorch), + is_reference=True, + ), + Variant( + name="pytorch_compiled", + fwd=lambda inp: _run_combine_fwd(inp, _combine_compiled_default), + fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_compiled_default), + ), + Variant( + name="pytorch_compiled_max", + fwd=lambda inp: _run_combine_fwd(inp, _combine_compiled_max), + fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_compiled_max), + ), + ] + if TritonConfig.enabled(): + variants.append( + Variant( + name="fast_llm_triton", + fwd=_run_combine_fwd_triton, + fwd_bwd=_run_combine_fwd_bwd_triton, + ) + ) + return variants + + +# --------------------------------------------------------------------------- cases / bytes + + +def _bytes_per_elem(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() + + +def _dispatch_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: + elem = _bytes_per_elem(dtype) + # fwd: read dense (tokens×h) + write sparse (top_k×tokens×h) + # bwd: read sparse grad + write dense grad → same traffic reversed + return 2 * (1 + top_k) * tokens * hidden * elem + + +def _combine_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: + elem = _bytes_per_elem(dtype) + sparse_rows = top_k * tokens + # fwd: read sparse (sparse×h) + read scores (tokens×top_k) + write dense (tokens×h) + # bwd: read dense grad + read scores + write sparse grad + write score grad + return 2 * (sparse_rows + tokens) * hidden * elem + 4 * tokens * top_k * elem + + +def _dispatch_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + 
name=case_name("dispatch", (tokens, top_k, num_experts, hidden), dtype), + make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, d=dtype: _make_dispatch_inputs( + t, k, n, h, d + ), + expected_bytes=_dispatch_bytes(tokens, top_k, hidden, dtype), + expected_flops=0, + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, top_k, num_experts, hidden in _SHAPES + ] + + +def _combine_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: + return [ + Case( + name=case_name("combine", (tokens, top_k, num_experts, hidden), dtype), + make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, d=dtype: _make_combine_inputs( + t, k, n, h, d + ), + expected_bytes=_combine_bytes(tokens, top_k, hidden, dtype), + expected_flops=0, + compute_dtype=dtype, + ) + for dtype in dtypes + for tokens, top_k, num_experts, hidden in _SHAPES + ] + + +# --------------------------------------------------------------------------- entry point + + +def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + run_benchmark("sparse_copy: dispatch", _dispatch_cases(dtypes), _dispatch_variants(), verbose=verbose) + run_benchmark("sparse_copy: combine", _combine_cases(dtypes), _combine_variants(), verbose=verbose) + + +if __name__ == "__main__": + run() From 2c619aa98ae1ee0b3f147c05d3f8673d27903240 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 04:20:16 -0400 Subject: [PATCH 04/41] Fix sparse_copy benchmark: zero phantom rows before correctness comparison Dispatch output and combine grad_sparse both have phantom rows (padding within expert ranges and the static tail beyond expert_ends[-1]) that copy_dense_to_sparse never writes. The PyTorch reference uses new_zeros while production code uses new_empty, so comparing the full tensors produced inf rel_rms. _zero_phantom_rows zeros those ranges in all variants before the runner compares them. 
Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_sparse_copy.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 9a84adf9f..5ebe117f9 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -63,6 +63,18 @@ def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, } +def _zero_phantom_rows(tensor: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + for e in range(sparse_map.num_experts): + pad_begin = int(sparse_map.expert_pad_begins[e]) + pad_end = int(sparse_map.expert_ends[e]) + if pad_end > pad_begin: + tensor[pad_begin:pad_end] = 0 + tail_begin = int(sparse_map.expert_ends[-1]) + if tensor.shape[0] > tail_begin: + tensor[tail_begin:] = 0 + return tensor + + # --------------------------------------------------------------------------- dispatch @@ -79,23 +91,27 @@ def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch def _run_dispatch_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["dense_input"], inp["sparse_map"])} + return {"output": _zero_phantom_rows(fn(inp["dense_input"], inp["sparse_map"]), inp["sparse_map"])} def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["dense_input"], inp["sparse_map"]) output.backward(torch.ones_like(output)) - return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} + return {"output": _zero_phantom_rows(output.detach(), inp["sparse_map"]), "grad_dense": inp["dense_input"].grad} def _run_dispatch_fwd_triton(inp: dict) -> dict: - return {"output": copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"])} + return { + "output": _zero_phantom_rows( + copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]), inp["sparse_map"] + ) + } def _run_dispatch_fwd_bwd_triton(inp: dict) -> dict: output = copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]) output.backward(torch.ones_like(output)) - return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} + return {"output": _zero_phantom_rows(output.detach(), inp["sparse_map"]), "grad_dense": inp["dense_input"].grad} def _dispatch_variants() -> list[Variant]: @@ -152,7 +168,7 @@ def _run_combine_fwd_bwd(inp: dict, fn) -> dict: output.backward(torch.ones_like(output)) return { "output": output.detach(), - "grad_sparse": inp["sparse_input"].grad, + "grad_sparse": _zero_phantom_rows(inp["sparse_input"].grad, inp["sparse_map"]), "grad_scores": inp["scores"].grad, } @@ -166,7 +182,7 @@ def _run_combine_fwd_bwd_triton(inp: dict) -> dict: output.backward(torch.ones_like(output)) return { "output": output.detach(), - "grad_sparse": inp["sparse_input"].grad, + "grad_sparse": _zero_phantom_rows(inp["sparse_input"].grad, inp["sparse_map"]), "grad_scores": inp["scores"].grad, } From 194976468be011ebd2760e5eb7297347e5c34fa8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 18:48:37 -0400 Subject: [PATCH 05/41] Fix sparse_copy benchmark correctness without polluting timing Add output_postprocess to Variant: a callable applied only during the accuracy check (not the timing loop), so phantom-row masking doesn't inflate measured latency. Precompute a boolean phantom_mask per case; _dispatch_postprocess and _combine_postprocess use masked_fill_ (one GPU op) to zero phantom rows before RMS comparison. 
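
A minimal sketch of the intended hook (assumed wiring; the actual change applies it
where the runner normalizes variant outputs for the RMS comparison, never in the
timing loop):

    cand = _as_output_dict(fwd_output)
    if variant.output_postprocess is not None:
        cand = variant.output_postprocess(cand, inputs)
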
Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_sparse_copy.py | 59 +++++++++++++++++----------- tools/benchmark/runner.py | 9 +++++ 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 5ebe117f9..92a797c9b 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -46,11 +46,28 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: return get_sparse_map(top_experts, num_experts) +def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: + # Boolean mask shape (num_rows, 1): True for phantom rows (within-expert padding + # and the static tail beyond expert_ends[-1]). Precomputed once per case and + # used with masked_fill_ in output_postprocess — never inside the timed path. + mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device()) + for e in range(sparse_map.num_experts): + pad_begin = int(sparse_map.expert_pad_begins[e]) + pad_end = int(sparse_map.expert_ends[e]) + if pad_end > pad_begin: + mask[pad_begin:pad_end] = True + tail_begin = int(sparse_map.expert_ends[-1]) + if sparse_map.num_rows > tail_begin: + mask[tail_begin:] = True + return mask + + def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: sparse_map = _make_sparse_map(tokens, top_k, num_experts) return { "dense_input": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), "sparse_map": sparse_map, + "phantom_mask": _make_phantom_mask(sparse_map), } @@ -60,21 +77,10 @@ def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, "sparse_input": torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device(), requires_grad=True), "scores": torch.softmax(torch.randn(tokens, top_k, dtype=dtype, device=device()), dim=-1).requires_grad_(True), "sparse_map": sparse_map, + "phantom_mask": _make_phantom_mask(sparse_map), } -def _zero_phantom_rows(tensor: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: - for e in range(sparse_map.num_experts): - pad_begin = int(sparse_map.expert_pad_begins[e]) - pad_end = int(sparse_map.expert_ends[e]) - if pad_end > pad_begin: - tensor[pad_begin:pad_end] = 0 - tail_begin = int(sparse_map.expert_ends[-1]) - if tensor.shape[0] > tail_begin: - tensor[tail_begin:] = 0 - return tensor - - # --------------------------------------------------------------------------- dispatch @@ -91,27 +97,28 @@ def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch def _run_dispatch_fwd(inp: dict, fn) -> dict: - return {"output": _zero_phantom_rows(fn(inp["dense_input"], inp["sparse_map"]), inp["sparse_map"])} + return {"output": fn(inp["dense_input"], inp["sparse_map"])} def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["dense_input"], inp["sparse_map"]) output.backward(torch.ones_like(output)) - return {"output": _zero_phantom_rows(output.detach(), inp["sparse_map"]), "grad_dense": inp["dense_input"].grad} + return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} def _run_dispatch_fwd_triton(inp: dict) -> dict: - return { - "output": _zero_phantom_rows( - copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]), inp["sparse_map"] - ) - } + return {"output": copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"])} def _run_dispatch_fwd_bwd_triton(inp: dict) -> dict: output = copy_dense_to_sparse_autograd(inp["dense_input"], 
inp["sparse_map"]) output.backward(torch.ones_like(output)) - return {"output": _zero_phantom_rows(output.detach(), inp["sparse_map"]), "grad_dense": inp["dense_input"].grad} + return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} + + +def _dispatch_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, torch.Tensor]: + out["output"].masked_fill_(inp["phantom_mask"], 0) + return out def _dispatch_variants() -> list[Variant]: @@ -139,6 +146,7 @@ def _dispatch_variants() -> list[Variant]: name="fast_llm_triton", fwd=_run_dispatch_fwd_triton, fwd_bwd=_run_dispatch_fwd_bwd_triton, + output_postprocess=_dispatch_postprocess, ) ) return variants @@ -168,7 +176,7 @@ def _run_combine_fwd_bwd(inp: dict, fn) -> dict: output.backward(torch.ones_like(output)) return { "output": output.detach(), - "grad_sparse": _zero_phantom_rows(inp["sparse_input"].grad, inp["sparse_map"]), + "grad_sparse": inp["sparse_input"].grad, "grad_scores": inp["scores"].grad, } @@ -182,11 +190,17 @@ def _run_combine_fwd_bwd_triton(inp: dict) -> dict: output.backward(torch.ones_like(output)) return { "output": output.detach(), - "grad_sparse": _zero_phantom_rows(inp["sparse_input"].grad, inp["sparse_map"]), + "grad_sparse": inp["sparse_input"].grad, "grad_scores": inp["scores"].grad, } +def _combine_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, torch.Tensor]: + if "grad_sparse" in out: + out["grad_sparse"].masked_fill_(inp["phantom_mask"], 0) + return out + + def _combine_variants() -> list[Variant]: variants = [ Variant( @@ -212,6 +226,7 @@ def _combine_variants() -> list[Variant]: name="fast_llm_triton", fwd=_run_combine_fwd_triton, fwd_bwd=_run_combine_fwd_bwd_triton, + output_postprocess=_combine_postprocess, ) ) return variants diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index 2836787f6..a629684c4 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -63,6 +63,11 @@ class Variant: # The fp32 reference variant. Exactly one per benchmark; its outputs are # the ground truth for RMS-error comparison. is_reference: bool = False + # Applied to the output dict during the correctness check only — never during + # timing. Receives {name: tensor} and the full inputs dict; returns the + # (possibly modified) dict. Use this to mask out don't-care regions so they + # don't inflate RMS errors (e.g. uninitialized phantom rows in sparse buffers). 
+ output_postprocess: Callable[[dict[str, torch.Tensor], Inputs], dict[str, torch.Tensor]] | None = None @dataclasses.dataclass @@ -283,6 +288,8 @@ def _fwd_once() -> Any: ref_fwd = reference_outputs.get("fwd") if ref_fwd is not None: cand = _as_output_dict(fwd_output) + if variant.output_postprocess is not None: + cand = variant.output_postprocess(cand, inputs) rms = {name: rms_relative_error(cand[name], ref_fwd[name]) for name in ref_fwd if name in cand} result.rms_errors = (result.rms_errors or {}) | {f"fwd.{k}": v for k, v in rms.items()} del fwd_output @@ -311,6 +318,8 @@ def _fwd_bwd_once() -> Any: ref_fb = reference_outputs.get("fwd_bwd") if ref_fb is not None: cand = _as_output_dict(fwd_bwd_output) + if variant.output_postprocess is not None: + cand = variant.output_postprocess(cand, inputs) rms = {name: rms_relative_error(cand[name], ref_fb[name]) for name in ref_fb if name in cand} result.rms_errors = (result.rms_errors or {}) | {f"fb.{k}": v for k, v in rms.items()} del fwd_bwd_output From e3b43378783b680ce5aa39685e7591cd3cd23538 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 19:16:17 -0400 Subject: [PATCH 06/41] Fix benchmark timing loops: precompute backward gradients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch.ones_like / torch.zeros_like inside timed fwd_bwd functions allocate a new GPU tensor on every timing rep, polluting measurements. bench_sparse_copy: add backward_grad to dispatch/combine make_inputs (shapes sparse_map.num_rows×hidden and tokens×hidden respectively) and use inp["backward_grad"] in all four fwd_bwd functions. bench_sparse_linear: same fix — precompute backward_grad in both make_inputs functions and remove the _zero_padded_rows(...ones_like...) call from all four fwd_bwd functions. The zeroing was never needed: pytorch_loop uses new_zeros so phantom rows have no autograd edge, and the Triton backward already bounds itself to expert_pad_begins. 
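The pattern in both files, roughly (a self-contained toy sketch: kernel, x and the shapes are placeholders, the real shapes are listed above):

    import torch

    def kernel(x):                             # stand-in for any autograd-capable op
        return x * 2

    x = torch.randn(4, 8, requires_grad=True)
    backward_grad = torch.ones(4, 8)           # built once in make_inputs, outside the timing loop

    def fwd_bwd_alloc():                       # before: allocates a new gradient tensor every rep
        out = kernel(x)
        out.backward(torch.ones_like(out))

    def fwd_bwd_precomputed():                 # after: reuses the precomputed gradient
        out = kernel(x)
        out.backward(backward_grad)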
Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_sparse_copy.py | 10 ++++++---- tools/benchmark/bench_sparse_linear.py | 16 ++++++++++------ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 92a797c9b..aada0f649 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -68,6 +68,7 @@ def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int "dense_input": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), "sparse_map": sparse_map, "phantom_mask": _make_phantom_mask(sparse_map), + "backward_grad": torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), } @@ -78,6 +79,7 @@ def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, "scores": torch.softmax(torch.randn(tokens, top_k, dtype=dtype, device=device()), dim=-1).requires_grad_(True), "sparse_map": sparse_map, "phantom_mask": _make_phantom_mask(sparse_map), + "backward_grad": torch.ones(tokens, hidden, dtype=dtype, device=device()), } @@ -102,7 +104,7 @@ def _run_dispatch_fwd(inp: dict, fn) -> dict: def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["dense_input"], inp["sparse_map"]) - output.backward(torch.ones_like(output)) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} @@ -112,7 +114,7 @@ def _run_dispatch_fwd_triton(inp: dict) -> dict: def _run_dispatch_fwd_bwd_triton(inp: dict) -> dict: output = copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]) - output.backward(torch.ones_like(output)) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} @@ -173,7 +175,7 @@ def _run_combine_fwd(inp: dict, fn) -> dict: def _run_combine_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["sparse_input"], inp["scores"], inp["sparse_map"]) - output.backward(torch.ones_like(output)) + output.backward(inp["backward_grad"]) return { "output": output.detach(), "grad_sparse": inp["sparse_input"].grad, @@ -187,7 +189,7 @@ def _run_combine_fwd_triton(inp: dict) -> dict: def _run_combine_fwd_bwd_triton(inp: dict) -> dict: output = copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"]) - output.backward(torch.ones_like(output)) + output.backward(inp["backward_grad"]) return { "output": output.detach(), "grad_sparse": inp["sparse_input"].grad, diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 82f0afd1d..bbacfd0b5 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -60,18 +60,20 @@ def _make_output_sparse_inputs( sparse_map = _make_sparse_map(tokens, top_k, num_experts) lhs_data = _zero_padded_rows(torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map) rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) + backward_grad = torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. 
if TritonConfig.enabled(): _w_lhs = lhs_data.detach().requires_grad_(True) _w_rhs = rhs_data.detach().requires_grad_(True) _w_out = OutputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) - _w_out.backward(torch.ones_like(_w_out)) + _w_out.backward(backward_grad) del _w_lhs, _w_rhs, _w_out return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), "sparse_map": sparse_map, "ffn_per_expert": ffn_per_expert, + "backward_grad": backward_grad, } @@ -83,18 +85,20 @@ def _make_input_inner_sparse_inputs( torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map ) rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) + backward_grad = torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()) # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. if TritonConfig.enabled(): _w_lhs = lhs_data.detach().requires_grad_(True) _w_rhs = rhs_data.detach().requires_grad_(True) _w_out = InputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) - _w_out.backward(torch.ones_like(_w_out)) + _w_out.backward(backward_grad) del _w_lhs, _w_rhs, _w_out return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), "sparse_map": sparse_map, "ffn_per_expert": ffn_per_expert, + "backward_grad": backward_grad, } @@ -122,7 +126,7 @@ def _run_output_sparse_fwd(inp: dict, fn) -> dict: def _run_output_sparse_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} @@ -132,7 +136,7 @@ def _run_output_sparse_fwd_triton(inp: dict) -> dict: def _run_output_sparse_fwd_bwd_triton(inp: dict) -> dict: output = OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} @@ -185,7 +189,7 @@ def _run_input_inner_sparse_fwd(inp: dict, fn) -> dict: def _run_input_inner_sparse_fwd_bwd(inp: dict, fn) -> dict: output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} @@ -195,7 +199,7 @@ def _run_input_inner_sparse_fwd_triton(inp: dict) -> dict: def _run_input_inner_sparse_fwd_bwd_triton(inp: dict) -> dict: output = InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(_zero_padded_rows(torch.ones_like(output), inp["sparse_map"])) + output.backward(inp["backward_grad"]) return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} From d521cf7f6cfa3fcb2474b6ecda4a44a16f14bfde Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 19:32:08 -0400 Subject: [PATCH 07/41] Fix sparse_linear backward_grad: zero phantom rows before precomputing The Triton sparse-linear backward reads phantom row positions in the upstream gradient tensor. Passing all-ones for phantom rows gives a different grad_lhs than the PyTorch reference (which has no autograd edge for phantom rows), causing rel_rms ~0.28. 
Fix: apply _zero_padded_rows to backward_grad during make_inputs (outside the timing loop), matching the lhs_data preprocessing pattern. Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_sparse_linear.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index bbacfd0b5..81af3f60b 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -60,7 +60,9 @@ def _make_output_sparse_inputs( sparse_map = _make_sparse_map(tokens, top_k, num_experts) lhs_data = _zero_padded_rows(torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map) rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) - backward_grad = torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) + backward_grad = _zero_padded_rows( + torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map + ) # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. if TritonConfig.enabled(): _w_lhs = lhs_data.detach().requires_grad_(True) @@ -85,7 +87,9 @@ def _make_input_inner_sparse_inputs( torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map ) rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) - backward_grad = torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()) + backward_grad = _zero_padded_rows( + torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map + ) # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. if TritonConfig.enabled(): _w_lhs = lhs_data.detach().requires_grad_(True) From d431d6602fa8ab6f8fe41fc5bedd7c7620e43b51 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 28 Apr 2026 22:11:50 -0400 Subject: [PATCH 08/41] Restructure rotary kernel: full-head contiguous load + autotune head_block_size MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Load re and im halves together as a single (head_block_size, head_size) contiguous block per head rather than two separate strided half-head loads. Use a sign-flip formula (out = x*cos + sign*x_partner*sin) to avoid splitting the loaded tensor. The partner load (swapped halves via tl.where) hits L2 after the primary load. Add @triton.autotune over head_block_size ∈ {1,2,4,8,16} × num_warps ∈ {4,8} to let the tuner find the optimal block size per shape. 
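A quick numerical sanity check of the sign-flip formulation against the classic two-half rotation (a standalone toy check with made-up sizes, not part of the kernel itself):

    import torch

    d = 4                                      # rotary_dim
    x = torch.randn(2 * d)                     # one head laid out as [re | im]
    cos, sin = torch.randn(d), torch.randn(d)
    # classic two-half form: out_re = re*cos - im*sin, out_im = im*cos + re*sin
    classic = torch.cat([x[:d] * cos - x[d:] * sin, x[d:] * cos + x[:d] * sin])
    # sign-flip form: out = x*cos + sign*partner*sin, with partner = halves swapped
    partner = torch.cat([x[d:], x[:d]])
    sign = torch.cat([-torch.ones(d), torch.ones(d)])
    flipped = x * cos.repeat(2) + sign * partner * sin.repeat(2)
    assert torch.allclose(classic, flipped)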
Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/rotary.py | 95 +++++++++++++++++----------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index f07046a52..80fa05de4 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,11 +1,30 @@ +import os + import torch -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div - +autotune_configs = ( + TritonConfig({"head_block_size": 16}, num_warps=4), + TritonConfig({"head_block_size": 8}, num_warps=4), + TritonConfig({"head_block_size": 4}, num_warps=4), + TritonConfig({"head_block_size": 2}, num_warps=4), + TritonConfig({"head_block_size": 1}, num_warps=4), + TritonConfig({"head_block_size": 16}, num_warps=8), + TritonConfig({"head_block_size": 8}, num_warps=8), + TritonConfig({"head_block_size": 4}, num_warps=8), +) + +if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): + autotune_configs = (autotune_configs[0],) + + +@triton_autotune( + configs=autotune_configs, + key=["rotary_dim", "num_heads", "seq_len"], +) @triton_jit() def triton_rotary_kernel( input_ptr, @@ -16,49 +35,57 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, - head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, + head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index - pid_1 = tl.program_id(axis=1) # Head index + pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - offsets = tl_arange(0, rotary_block_size) - head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] - input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] - input_re_ptr = input_ptr + input_offsets - input_im_ptr = input_re_ptr + rotary_dim + # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] + col_offsets = tl_arange(0, 2 * rotary_block_size) + head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) + base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] + + # Load full head as one contiguous block per head (re and im halves together) + if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: + x = tl.load(input_ptr + input_offsets).to(tl.float32) + else: + mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) + x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. + # These are the same cache lines as x, so expect L2 hits after the x load above. 
+ partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) + partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - input_re = tl.load(input_re_ptr).to(tl.float32) - input_im = tl.load(input_im_ptr).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) else: - mask = (offsets[None, :] < rotary_dim) & (head_offsets < num_heads) - input_re = tl.load(input_re_ptr, mask=mask).to(tl.float32) - input_im = tl.load(input_im_ptr, mask=mask).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) - # Computing frequencies here is faster but hurts precision, so we load pre-computed ones instead. - frequencies_offsets = 2 * rotary_dim * position_id + offsets - frequencies_re_ptr = frequencies_ptr + frequencies_offsets - frequencies_im_ptr = frequencies_re_ptr + rotary_dim - frequencies_re = tl.load(frequencies_re_ptr) - frequencies_im = tl.load(frequencies_im_ptr) + # Frequencies: same index for both halves (cos/sin repeat for re and im columns) + freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) + freq_base = frequencies_ptr + 2 * rotary_dim * position_id + freq_re = tl.load(freq_base + freq_col) + freq_im = tl.load(freq_base + rotary_dim + freq_col) + # out[e] = x[e]*cos ± x_partner[e]*sin + # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) + # bwd: conjugate rotation flips the sign if backward: - out_re = input_re * frequencies_re + input_im * frequencies_im - out_im = input_im * frequencies_re - input_re * frequencies_im + sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) else: - out_re = input_re * frequencies_re - input_im * frequencies_im - out_im = input_im * frequencies_re + input_re * frequencies_im + sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) + + out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_re_ptr, out_re) - tl.store(input_im_ptr, out_im) + tl.store(input_ptr + input_offsets, out) else: - tl.store(input_re_ptr, out_re, mask=mask) # noqa - tl.store(input_im_ptr, out_im, mask=mask) + tl.store(input_ptr + input_offsets, out, mask=mask) def triton_rotary_( @@ -67,12 +94,9 @@ def triton_rotary_( is_key_value: bool = False, backward: bool = False, ) -> torch.Tensor: - # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. - # TODO: Improve block size heuristics. out = input_ if input_.stride(-1) != 1: - # TODO: Make a transposed version to avoid contiguous call in key backward. 
input_ = input_.contiguous() if input_.ndim == 3: input_ = input_.unsqueeze(0) @@ -82,12 +106,10 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) - head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) - if head_block_size > num_heads: - head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( + grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) + triton_rotary_kernel[grid]( input_, frequencies, input_.stride(0), @@ -96,7 +118,6 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, - head_block_size, seq_len, backward, # noqa ) From ae0ae12300f4bde937b26b69521391a4ec6f02f3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:02:41 -0400 Subject: [PATCH 09/41] Fix Triton rotary kernel: remove autotune, simplify to two-half load @triton.autotune is incompatible with in-place kernels: the autotuner runs all configs sequentially on the same tensor, so each trial rotates the tensor again, producing garbage results in the benchmark. Replace with a fixed head_block_size computed from POINTWISE_BLOCK_SIZE and simplify the kernel to load re and im halves separately, holding both in registers to compute out_re and out_im without a partner load. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/rotary.py | 84 ++++++++++------------------ 1 file changed, 31 insertions(+), 53 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 80fa05de4..0298fb15f 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,30 +1,11 @@ -import os - import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div -autotune_configs = ( - TritonConfig({"head_block_size": 16}, num_warps=4), - TritonConfig({"head_block_size": 8}, num_warps=4), - TritonConfig({"head_block_size": 4}, num_warps=4), - TritonConfig({"head_block_size": 2}, num_warps=4), - TritonConfig({"head_block_size": 1}, num_warps=4), - TritonConfig({"head_block_size": 16}, num_warps=8), - TritonConfig({"head_block_size": 8}, num_warps=8), - TritonConfig({"head_block_size": 4}, num_warps=8), -) - -if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): - autotune_configs = (autotune_configs[0],) - - -@triton_autotune( - configs=autotune_configs, - key=["rotary_dim", "num_heads", "seq_len"], -) + @triton_jit() def triton_rotary_kernel( input_ptr, @@ -35,57 +16,51 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, + head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, - head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? 
pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] - col_offsets = tl_arange(0, 2 * rotary_block_size) + col_offsets = tl_arange(0, rotary_block_size) head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id - input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - # Load full head as one contiguous block per head (re and im halves together) - if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x = tl.load(input_ptr + input_offsets).to(tl.float32) - else: - mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) - x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Load re and im halves separately so both are in registers simultaneously, + # avoiding a partner load (reading the same cache lines again with shuffled indices). + re_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] + im_offsets = re_offsets + rotary_dim - # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. - # These are the same cache lines as x, so expect L2 hits after the x load above. - partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) - partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) + x_re = tl.load(input_ptr + re_offsets).to(tl.float32) + x_im = tl.load(input_ptr + im_offsets).to(tl.float32) else: - x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + mask = (col_offsets[None, :] < rotary_dim) & (head_row[:, None] < num_heads) + x_re = tl.load(input_ptr + re_offsets, mask=mask).to(tl.float32) + x_im = tl.load(input_ptr + im_offsets, mask=mask).to(tl.float32) - # Frequencies: same index for both halves (cos/sin repeat for re and im columns) - freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + freq_col) - freq_im = tl.load(freq_base + rotary_dim + freq_col) + freq_re = tl.load(freq_base + col_offsets) + freq_im = tl.load(freq_base + rotary_dim + col_offsets) - # out[e] = x[e]*cos ± x_partner[e]*sin - # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) - # bwd: conjugate rotation flips the sign + # fwd: out_re = cos*re - sin*im, out_im = cos*im + sin*re + # bwd: conjugate rotation, sin signs flipped if backward: - sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) + out_re = x_re * freq_re[None, :] + x_im * freq_im[None, :] + out_im = x_im * freq_re[None, :] - x_re * freq_im[None, :] else: - sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) - - out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] + out_re = x_re * freq_re[None, :] - x_im * freq_im[None, :] + out_im = x_im * freq_re[None, :] + x_re * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + input_offsets, out) + tl.store(input_ptr + re_offsets, out_re) + tl.store(input_ptr + im_offsets, out_im) else: - tl.store(input_ptr + input_offsets, out, mask=mask) + tl.store(input_ptr + re_offsets, out_re, mask=mask) + tl.store(input_ptr + im_offsets, 
out_im, mask=mask) def triton_rotary_( @@ -106,10 +81,12 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) + head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) + if head_block_size > num_heads: + head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) - triton_rotary_kernel[grid]( + triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, frequencies, input_.stride(0), @@ -118,6 +95,7 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, + head_block_size, seq_len, backward, # noqa ) From cc56b9221f86d4d8ef1feb6477d8207d24b2e175 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:08:46 -0400 Subject: [PATCH 10/41] Revert "Fix Triton rotary kernel: remove autotune, simplify to two-half load" This reverts commit 178903ff503ad6aa5cdeffd515ee8b49ac4b4d4b. --- fast_llm/functional/triton/rotary.py | 84 ++++++++++++++++++---------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 0298fb15f..80fa05de4 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,11 +1,30 @@ +import os + import torch -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div - +autotune_configs = ( + TritonConfig({"head_block_size": 16}, num_warps=4), + TritonConfig({"head_block_size": 8}, num_warps=4), + TritonConfig({"head_block_size": 4}, num_warps=4), + TritonConfig({"head_block_size": 2}, num_warps=4), + TritonConfig({"head_block_size": 1}, num_warps=4), + TritonConfig({"head_block_size": 16}, num_warps=8), + TritonConfig({"head_block_size": 8}, num_warps=8), + TritonConfig({"head_block_size": 4}, num_warps=8), +) + +if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): + autotune_configs = (autotune_configs[0],) + + +@triton_autotune( + configs=autotune_configs, + key=["rotary_dim", "num_heads", "seq_len"], +) @triton_jit() def triton_rotary_kernel( input_ptr, @@ -16,51 +35,57 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, - head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, + head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? 
pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - col_offsets = tl_arange(0, rotary_block_size) + # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] + col_offsets = tl_arange(0, 2 * rotary_block_size) head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - # Load re and im halves separately so both are in registers simultaneously, - # avoiding a partner load (reading the same cache lines again with shuffled indices). - re_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - im_offsets = re_offsets + rotary_dim + # Load full head as one contiguous block per head (re and im halves together) + if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: + x = tl.load(input_ptr + input_offsets).to(tl.float32) + else: + mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) + x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. + # These are the same cache lines as x, so expect L2 hits after the x load above. + partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) + partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_re = tl.load(input_ptr + re_offsets).to(tl.float32) - x_im = tl.load(input_ptr + im_offsets).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) else: - mask = (col_offsets[None, :] < rotary_dim) & (head_row[:, None] < num_heads) - x_re = tl.load(input_ptr + re_offsets, mask=mask).to(tl.float32) - x_im = tl.load(input_ptr + im_offsets, mask=mask).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + # Frequencies: same index for both halves (cos/sin repeat for re and im columns) + freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + col_offsets) - freq_im = tl.load(freq_base + rotary_dim + col_offsets) + freq_re = tl.load(freq_base + freq_col) + freq_im = tl.load(freq_base + rotary_dim + freq_col) - # fwd: out_re = cos*re - sin*im, out_im = cos*im + sin*re - # bwd: conjugate rotation, sin signs flipped + # out[e] = x[e]*cos ± x_partner[e]*sin + # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) + # bwd: conjugate rotation flips the sign if backward: - out_re = x_re * freq_re[None, :] + x_im * freq_im[None, :] - out_im = x_im * freq_re[None, :] - x_re * freq_im[None, :] + sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) else: - out_re = x_re * freq_re[None, :] - x_im * freq_im[None, :] - out_im = x_im * freq_re[None, :] + x_re * freq_im[None, :] + sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) + + out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + re_offsets, out_re) - tl.store(input_ptr + im_offsets, out_im) + tl.store(input_ptr + input_offsets, out) else: - tl.store(input_ptr + re_offsets, out_re, mask=mask) - tl.store(input_ptr + im_offsets, out_im, mask=mask) + tl.store(input_ptr + 
input_offsets, out, mask=mask) def triton_rotary_( @@ -81,12 +106,10 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) - head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) - if head_block_size > num_heads: - head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( + grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) + triton_rotary_kernel[grid]( input_, frequencies, input_.stride(0), @@ -95,7 +118,6 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, - head_block_size, seq_len, backward, # noqa ) From 77e6b07952e1986974af12c5f40a8c330253d32a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:08:46 -0400 Subject: [PATCH 11/41] Revert "Restructure rotary kernel: full-head contiguous load + autotune head_block_size" This reverts commit c21ef1279f4dfbe5964574c73aae379baff25a26. --- fast_llm/functional/triton/rotary.py | 95 +++++++++++----------------- 1 file changed, 37 insertions(+), 58 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 80fa05de4..f07046a52 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,30 +1,11 @@ -import os - import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div -autotune_configs = ( - TritonConfig({"head_block_size": 16}, num_warps=4), - TritonConfig({"head_block_size": 8}, num_warps=4), - TritonConfig({"head_block_size": 4}, num_warps=4), - TritonConfig({"head_block_size": 2}, num_warps=4), - TritonConfig({"head_block_size": 1}, num_warps=4), - TritonConfig({"head_block_size": 16}, num_warps=8), - TritonConfig({"head_block_size": 8}, num_warps=8), - TritonConfig({"head_block_size": 4}, num_warps=8), -) - -if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): - autotune_configs = (autotune_configs[0],) - - -@triton_autotune( - configs=autotune_configs, - key=["rotary_dim", "num_heads", "seq_len"], -) + @triton_jit() def triton_rotary_kernel( input_ptr, @@ -35,57 +16,49 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, + head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, - head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? 
pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index - pid_1 = tl.program_id(axis=1) # Head block index + pid_1 = tl.program_id(axis=1) # Head index position_id = pid_0 % seq_len - # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] - col_offsets = tl_arange(0, 2 * rotary_block_size) - head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) - base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id - input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - - # Load full head as one contiguous block per head (re and im halves together) - if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x = tl.load(input_ptr + input_offsets).to(tl.float32) - else: - mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) - x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + offsets = tl_arange(0, rotary_block_size) + head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] + input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] + input_re_ptr = input_ptr + input_offsets + input_im_ptr = input_re_ptr + rotary_dim - # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. - # These are the same cache lines as x, so expect L2 hits after the x load above. - partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) - partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) + input_re = tl.load(input_re_ptr).to(tl.float32) + input_im = tl.load(input_im_ptr).to(tl.float32) else: - x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + mask = (offsets[None, :] < rotary_dim) & (head_offsets < num_heads) + input_re = tl.load(input_re_ptr, mask=mask).to(tl.float32) + input_im = tl.load(input_im_ptr, mask=mask).to(tl.float32) - # Frequencies: same index for both halves (cos/sin repeat for re and im columns) - freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) - freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + freq_col) - freq_im = tl.load(freq_base + rotary_dim + freq_col) + # Computing frequencies here is faster but hurts precision, so we load pre-computed ones instead. 
+ frequencies_offsets = 2 * rotary_dim * position_id + offsets + frequencies_re_ptr = frequencies_ptr + frequencies_offsets + frequencies_im_ptr = frequencies_re_ptr + rotary_dim + frequencies_re = tl.load(frequencies_re_ptr) + frequencies_im = tl.load(frequencies_im_ptr) - # out[e] = x[e]*cos ± x_partner[e]*sin - # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) - # bwd: conjugate rotation flips the sign if backward: - sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) + out_re = input_re * frequencies_re + input_im * frequencies_im + out_im = input_im * frequencies_re - input_re * frequencies_im else: - sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) - - out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] + out_re = input_re * frequencies_re - input_im * frequencies_im + out_im = input_im * frequencies_re + input_re * frequencies_im if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + input_offsets, out) + tl.store(input_re_ptr, out_re) + tl.store(input_im_ptr, out_im) else: - tl.store(input_ptr + input_offsets, out, mask=mask) + tl.store(input_re_ptr, out_re, mask=mask) # noqa + tl.store(input_im_ptr, out_im, mask=mask) def triton_rotary_( @@ -94,9 +67,12 @@ def triton_rotary_( is_key_value: bool = False, backward: bool = False, ) -> torch.Tensor: + # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. + # TODO: Improve block size heuristics. out = input_ if input_.stride(-1) != 1: + # TODO: Make a transposed version to avoid contiguous call in key backward. input_ = input_.contiguous() if input_.ndim == 3: input_ = input_.unsqueeze(0) @@ -106,10 +82,12 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) + head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) + if head_block_size > num_heads: + head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) - triton_rotary_kernel[grid]( + triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, frequencies, input_.stride(0), @@ -118,6 +96,7 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, + head_block_size, seq_len, backward, # noqa ) From 4e2b996bae14e5917a415deb79c036890a3781c3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:11:22 -0400 Subject: [PATCH 12/41] Revert "Revert "Restructure rotary kernel: full-head contiguous load + autotune head_block_size"" This reverts commit 0be67b12a73e86c41199a2ba3f78d836e66f8069. 
--- fast_llm/functional/triton/rotary.py | 95 +++++++++++++++++----------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index f07046a52..80fa05de4 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,11 +1,30 @@ +import os + import torch -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div - +autotune_configs = ( + TritonConfig({"head_block_size": 16}, num_warps=4), + TritonConfig({"head_block_size": 8}, num_warps=4), + TritonConfig({"head_block_size": 4}, num_warps=4), + TritonConfig({"head_block_size": 2}, num_warps=4), + TritonConfig({"head_block_size": 1}, num_warps=4), + TritonConfig({"head_block_size": 16}, num_warps=8), + TritonConfig({"head_block_size": 8}, num_warps=8), + TritonConfig({"head_block_size": 4}, num_warps=8), +) + +if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): + autotune_configs = (autotune_configs[0],) + + +@triton_autotune( + configs=autotune_configs, + key=["rotary_dim", "num_heads", "seq_len"], +) @triton_jit() def triton_rotary_kernel( input_ptr, @@ -16,49 +35,57 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, - head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, + head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index - pid_1 = tl.program_id(axis=1) # Head index + pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - offsets = tl_arange(0, rotary_block_size) - head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] - input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] - input_re_ptr = input_ptr + input_offsets - input_im_ptr = input_re_ptr + rotary_dim + # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] + col_offsets = tl_arange(0, 2 * rotary_block_size) + head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) + base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] + + # Load full head as one contiguous block per head (re and im halves together) + if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: + x = tl.load(input_ptr + input_offsets).to(tl.float32) + else: + mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) + x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. + # These are the same cache lines as x, so expect L2 hits after the x load above. 
+ partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) + partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - input_re = tl.load(input_re_ptr).to(tl.float32) - input_im = tl.load(input_im_ptr).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) else: - mask = (offsets[None, :] < rotary_dim) & (head_offsets < num_heads) - input_re = tl.load(input_re_ptr, mask=mask).to(tl.float32) - input_im = tl.load(input_im_ptr, mask=mask).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) - # Computing frequencies here is faster but hurts precision, so we load pre-computed ones instead. - frequencies_offsets = 2 * rotary_dim * position_id + offsets - frequencies_re_ptr = frequencies_ptr + frequencies_offsets - frequencies_im_ptr = frequencies_re_ptr + rotary_dim - frequencies_re = tl.load(frequencies_re_ptr) - frequencies_im = tl.load(frequencies_im_ptr) + # Frequencies: same index for both halves (cos/sin repeat for re and im columns) + freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) + freq_base = frequencies_ptr + 2 * rotary_dim * position_id + freq_re = tl.load(freq_base + freq_col) + freq_im = tl.load(freq_base + rotary_dim + freq_col) + # out[e] = x[e]*cos ± x_partner[e]*sin + # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) + # bwd: conjugate rotation flips the sign if backward: - out_re = input_re * frequencies_re + input_im * frequencies_im - out_im = input_im * frequencies_re - input_re * frequencies_im + sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) else: - out_re = input_re * frequencies_re - input_im * frequencies_im - out_im = input_im * frequencies_re + input_re * frequencies_im + sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) + + out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_re_ptr, out_re) - tl.store(input_im_ptr, out_im) + tl.store(input_ptr + input_offsets, out) else: - tl.store(input_re_ptr, out_re, mask=mask) # noqa - tl.store(input_im_ptr, out_im, mask=mask) + tl.store(input_ptr + input_offsets, out, mask=mask) def triton_rotary_( @@ -67,12 +94,9 @@ def triton_rotary_( is_key_value: bool = False, backward: bool = False, ) -> torch.Tensor: - # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. - # TODO: Improve block size heuristics. out = input_ if input_.stride(-1) != 1: - # TODO: Make a transposed version to avoid contiguous call in key backward. 
input_ = input_.contiguous() if input_.ndim == 3: input_ = input_.unsqueeze(0) @@ -82,12 +106,10 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) - head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) - if head_block_size > num_heads: - head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( + grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) + triton_rotary_kernel[grid]( input_, frequencies, input_.stride(0), @@ -96,7 +118,6 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, - head_block_size, seq_len, backward, # noqa ) From 091ae5fb512f07901d82ae644326cf8f96e3d6c3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:11:22 -0400 Subject: [PATCH 13/41] Revert "Revert "Fix Triton rotary kernel: remove autotune, simplify to two-half load"" This reverts commit fa22f72b9e2d175fba76d98f35b74f7c1da0c0a6. --- fast_llm/functional/triton/rotary.py | 84 ++++++++++------------------ 1 file changed, 31 insertions(+), 53 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 80fa05de4..0298fb15f 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,30 +1,11 @@ -import os - import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div -autotune_configs = ( - TritonConfig({"head_block_size": 16}, num_warps=4), - TritonConfig({"head_block_size": 8}, num_warps=4), - TritonConfig({"head_block_size": 4}, num_warps=4), - TritonConfig({"head_block_size": 2}, num_warps=4), - TritonConfig({"head_block_size": 1}, num_warps=4), - TritonConfig({"head_block_size": 16}, num_warps=8), - TritonConfig({"head_block_size": 8}, num_warps=8), - TritonConfig({"head_block_size": 4}, num_warps=8), -) - -if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): - autotune_configs = (autotune_configs[0],) - - -@triton_autotune( - configs=autotune_configs, - key=["rotary_dim", "num_heads", "seq_len"], -) + @triton_jit() def triton_rotary_kernel( input_ptr, @@ -35,57 +16,51 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, + head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, - head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? 
pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] - col_offsets = tl_arange(0, 2 * rotary_block_size) + col_offsets = tl_arange(0, rotary_block_size) head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id - input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - # Load full head as one contiguous block per head (re and im halves together) - if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x = tl.load(input_ptr + input_offsets).to(tl.float32) - else: - mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) - x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Load re and im halves separately so both are in registers simultaneously, + # avoiding a partner load (reading the same cache lines again with shuffled indices). + re_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] + im_offsets = re_offsets + rotary_dim - # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. - # These are the same cache lines as x, so expect L2 hits after the x load above. - partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) - partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) + x_re = tl.load(input_ptr + re_offsets).to(tl.float32) + x_im = tl.load(input_ptr + im_offsets).to(tl.float32) else: - x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + mask = (col_offsets[None, :] < rotary_dim) & (head_row[:, None] < num_heads) + x_re = tl.load(input_ptr + re_offsets, mask=mask).to(tl.float32) + x_im = tl.load(input_ptr + im_offsets, mask=mask).to(tl.float32) - # Frequencies: same index for both halves (cos/sin repeat for re and im columns) - freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + freq_col) - freq_im = tl.load(freq_base + rotary_dim + freq_col) + freq_re = tl.load(freq_base + col_offsets) + freq_im = tl.load(freq_base + rotary_dim + col_offsets) - # out[e] = x[e]*cos ± x_partner[e]*sin - # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) - # bwd: conjugate rotation flips the sign + # fwd: out_re = cos*re - sin*im, out_im = cos*im + sin*re + # bwd: conjugate rotation, sin signs flipped if backward: - sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) + out_re = x_re * freq_re[None, :] + x_im * freq_im[None, :] + out_im = x_im * freq_re[None, :] - x_re * freq_im[None, :] else: - sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) - - out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] + out_re = x_re * freq_re[None, :] - x_im * freq_im[None, :] + out_im = x_im * freq_re[None, :] + x_re * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + input_offsets, out) + tl.store(input_ptr + re_offsets, out_re) + tl.store(input_ptr + im_offsets, out_im) else: - tl.store(input_ptr + input_offsets, out, mask=mask) + tl.store(input_ptr + re_offsets, out_re, mask=mask) + tl.store(input_ptr + im_offsets, 
out_im, mask=mask) def triton_rotary_( @@ -106,10 +81,12 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) + head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) + if head_block_size > num_heads: + head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) - triton_rotary_kernel[grid]( + triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, frequencies, input_.stride(0), @@ -118,6 +95,7 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, + head_block_size, seq_len, backward, # noqa ) From 3b35693a8a6c53735e63ff68389bf73617bd8486 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:12:18 -0400 Subject: [PATCH 14/41] Restore autotuned rotary kernel, add benchmark warmup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The autotune is incompatible with in-place kernels during benchmarking because the timed variant calls the kernel on the same tensor multiple times — each call rotates it again, making the correctness check see a corrupted result. Fix by warming up autotune in make_inputs on a throwaway tensor, so the cached winning config is used for all timed runs. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/rotary.py | 84 ++++++++++++++++++---------- tools/benchmark/bench_rotary.py | 6 +- 2 files changed, 58 insertions(+), 32 deletions(-) diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 0298fb15f..80fa05de4 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,11 +1,30 @@ +import os + import torch -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div - +autotune_configs = ( + TritonConfig({"head_block_size": 16}, num_warps=4), + TritonConfig({"head_block_size": 8}, num_warps=4), + TritonConfig({"head_block_size": 4}, num_warps=4), + TritonConfig({"head_block_size": 2}, num_warps=4), + TritonConfig({"head_block_size": 1}, num_warps=4), + TritonConfig({"head_block_size": 16}, num_warps=8), + TritonConfig({"head_block_size": 8}, num_warps=8), + TritonConfig({"head_block_size": 4}, num_warps=8), +) + +if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): + autotune_configs = (autotune_configs[0],) + + +@triton_autotune( + configs=autotune_configs, + key=["rotary_dim", "num_heads", "seq_len"], +) @triton_jit() def triton_rotary_kernel( input_ptr, @@ -16,51 +35,57 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, - head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, + head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? 
pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index pid_1 = tl.program_id(axis=1) # Head block index position_id = pid_0 % seq_len - col_offsets = tl_arange(0, rotary_block_size) + # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] + col_offsets = tl_arange(0, 2 * rotary_block_size) head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - # Load re and im halves separately so both are in registers simultaneously, - # avoiding a partner load (reading the same cache lines again with shuffled indices). - re_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - im_offsets = re_offsets + rotary_dim + # Load full head as one contiguous block per head (re and im halves together) + if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: + x = tl.load(input_ptr + input_offsets).to(tl.float32) + else: + mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) + x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. + # These are the same cache lines as x, so expect L2 hits after the x load above. + partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) + partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_re = tl.load(input_ptr + re_offsets).to(tl.float32) - x_im = tl.load(input_ptr + im_offsets).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) else: - mask = (col_offsets[None, :] < rotary_dim) & (head_row[:, None] < num_heads) - x_re = tl.load(input_ptr + re_offsets, mask=mask).to(tl.float32) - x_im = tl.load(input_ptr + im_offsets, mask=mask).to(tl.float32) + x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + # Frequencies: same index for both halves (cos/sin repeat for re and im columns) + freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + col_offsets) - freq_im = tl.load(freq_base + rotary_dim + col_offsets) + freq_re = tl.load(freq_base + freq_col) + freq_im = tl.load(freq_base + rotary_dim + freq_col) - # fwd: out_re = cos*re - sin*im, out_im = cos*im + sin*re - # bwd: conjugate rotation, sin signs flipped + # out[e] = x[e]*cos ± x_partner[e]*sin + # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) + # bwd: conjugate rotation flips the sign if backward: - out_re = x_re * freq_re[None, :] + x_im * freq_im[None, :] - out_im = x_im * freq_re[None, :] - x_re * freq_im[None, :] + sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) else: - out_re = x_re * freq_re[None, :] - x_im * freq_im[None, :] - out_im = x_im * freq_re[None, :] + x_re * freq_im[None, :] + sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) + + out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + re_offsets, out_re) - tl.store(input_ptr + im_offsets, out_im) + tl.store(input_ptr + input_offsets, out) else: - tl.store(input_ptr + re_offsets, out_re, mask=mask) - tl.store(input_ptr + im_offsets, out_im, mask=mask) + tl.store(input_ptr + 
input_offsets, out, mask=mask) def triton_rotary_( @@ -81,12 +106,10 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) - head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) - if head_block_size > num_heads: - head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( + grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) + triton_rotary_kernel[grid]( input_, frequencies, input_.stride(0), @@ -95,7 +118,6 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, - head_block_size, seq_len, backward, # noqa ) diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 4ca56662f..0f733d7d3 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -31,9 +31,13 @@ def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: rotary_dim = head_size // 2 + frequencies = torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()) + # Warm up Triton autotuning so the timed runs aren't dominated by JIT/autotune overhead. + if TritonConfig.enabled(): + triton_rotary_(torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), frequencies) return { "input_": torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), - "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), + "frequencies": frequencies, } From 7b6ce6fa01e6ef84be12af870eeb00c455e9801a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 00:18:20 -0400 Subject: [PATCH 15/41] Drop autotune from rotary kernel, add torch.compile inspection script Restore the original no-autotune kernel (benchmarking showed autotune neither helps nor can be benchmarked correctly for in-place kernels). Add tools/inspect_rotary_compile.py to dump the Triton code that torch.compile generates so we can compare it to the hand-written kernel. 
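As a cross-check that does not require the helper script, recent PyTorch builds can log the generated code directly (assuming torch._logging.set_logs and the TORCH_LOGS env var are available in the installed 2.x version); a minimal sketch:

    import torch
    import torch._logging

    # Ask inductor to log the generated output code (including Triton kernels).
    torch._logging.set_logs(output_code=True)

    fn = torch.compile(lambda x: x * 2.0)
    fn(torch.randn(8, device="cuda"))  # generated source is written to the log

The shell equivalent would be TORCH_LOGS="output_code" python tools/inspect_rotary_compile.py.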
Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/functional/triton/rotary.py | 95 +++++++++++----------------- tools/benchmark/bench_rotary.py | 6 +- tools/inspect_rotary_compile.py | 63 ++++++++++++++++++ 3 files changed, 101 insertions(+), 63 deletions(-) create mode 100644 tools/inspect_rotary_compile.py diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 80fa05de4..f07046a52 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -1,30 +1,11 @@ -import os - import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.functional.utils import wrap_forward_backward from fast_llm.utils import div -autotune_configs = ( - TritonConfig({"head_block_size": 16}, num_warps=4), - TritonConfig({"head_block_size": 8}, num_warps=4), - TritonConfig({"head_block_size": 4}, num_warps=4), - TritonConfig({"head_block_size": 2}, num_warps=4), - TritonConfig({"head_block_size": 1}, num_warps=4), - TritonConfig({"head_block_size": 16}, num_warps=8), - TritonConfig({"head_block_size": 8}, num_warps=8), - TritonConfig({"head_block_size": 4}, num_warps=8), -) - -if os.environ.get("FAST_LLM_SKIP_TRITON_AUTOTUNE"): - autotune_configs = (autotune_configs[0],) - - -@triton_autotune( - configs=autotune_configs, - key=["rotary_dim", "num_heads", "seq_len"], -) + @triton_jit() def triton_rotary_kernel( input_ptr, @@ -35,57 +16,49 @@ def triton_rotary_kernel( rotary_dim: tl_constexpr, num_heads: tl_constexpr, rotary_block_size: tl_constexpr, + head_block_size: tl_constexpr, seq_len: tl_constexpr, backward: tl_constexpr, - head_block_size: tl_constexpr, # injected by autotune ): # TODO: Int64 ptr if needed? pid_0 = tl.program_id(axis=0) # Folded (batch * seq) index - pid_1 = tl.program_id(axis=1) # Head block index + pid_1 = tl.program_id(axis=1) # Head index position_id = pid_0 % seq_len - # Full-head column offsets: [0, 1, …, 2*rotary_dim-1] - col_offsets = tl_arange(0, 2 * rotary_block_size) - head_row = pid_1 * head_block_size + tl_arange(0, head_block_size) - base = stride_0 * (pid_0 // seq_len) + stride_1 * position_id - input_offsets = base + stride_2 * head_row[:, None] + col_offsets[None, :] - - # Load full head as one contiguous block per head (re and im halves together) - if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x = tl.load(input_ptr + input_offsets).to(tl.float32) - else: - mask = (col_offsets[None, :] < 2 * rotary_dim) & (head_row[:, None] < num_heads) - x = tl.load(input_ptr + input_offsets, mask=mask).to(tl.float32) + offsets = tl_arange(0, rotary_block_size) + head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] + input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] + input_re_ptr = input_ptr + input_offsets + input_im_ptr = input_re_ptr + rotary_dim - # Partner: x[e + rotary_dim] for re-columns, x[e - rotary_dim] for im-columns. - # These are the same cache lines as x, so expect L2 hits after the x load above. 
- partner_col = tl.where(col_offsets < rotary_dim, col_offsets + rotary_dim, col_offsets - rotary_dim) - partner_offsets = base + stride_2 * head_row[:, None] + partner_col[None, :] if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - x_partner = tl.load(input_ptr + partner_offsets).to(tl.float32) + input_re = tl.load(input_re_ptr).to(tl.float32) + input_im = tl.load(input_im_ptr).to(tl.float32) else: - x_partner = tl.load(input_ptr + partner_offsets, mask=mask).to(tl.float32) + mask = (offsets[None, :] < rotary_dim) & (head_offsets < num_heads) + input_re = tl.load(input_re_ptr, mask=mask).to(tl.float32) + input_im = tl.load(input_im_ptr, mask=mask).to(tl.float32) - # Frequencies: same index for both halves (cos/sin repeat for re and im columns) - freq_col = tl.where(col_offsets < rotary_dim, col_offsets, col_offsets - rotary_dim) - freq_base = frequencies_ptr + 2 * rotary_dim * position_id - freq_re = tl.load(freq_base + freq_col) - freq_im = tl.load(freq_base + rotary_dim + freq_col) + # Computing frequencies here is faster but hurts precision, so we load pre-computed ones instead. + frequencies_offsets = 2 * rotary_dim * position_id + offsets + frequencies_re_ptr = frequencies_ptr + frequencies_offsets + frequencies_im_ptr = frequencies_re_ptr + rotary_dim + frequencies_re = tl.load(frequencies_re_ptr) + frequencies_im = tl.load(frequencies_im_ptr) - # out[e] = x[e]*cos ± x_partner[e]*sin - # fwd: sign=-1 for re columns (cos*re - sin*im), +1 for im (cos*im + sin*re) - # bwd: conjugate rotation flips the sign if backward: - sign = tl.where(col_offsets < rotary_dim, 1.0, -1.0) + out_re = input_re * frequencies_re + input_im * frequencies_im + out_im = input_im * frequencies_re - input_re * frequencies_im else: - sign = tl.where(col_offsets < rotary_dim, -1.0, 1.0) - - out = x * freq_re[None, :] + sign[None, :] * x_partner * freq_im[None, :] + out_re = input_re * frequencies_re - input_im * frequencies_im + out_im = input_im * frequencies_re + input_re * frequencies_im if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0: - tl.store(input_ptr + input_offsets, out) + tl.store(input_re_ptr, out_re) + tl.store(input_im_ptr, out_im) else: - tl.store(input_ptr + input_offsets, out, mask=mask) + tl.store(input_re_ptr, out_re, mask=mask) # noqa + tl.store(input_im_ptr, out_im, mask=mask) def triton_rotary_( @@ -94,9 +67,12 @@ def triton_rotary_( is_key_value: bool = False, backward: bool = False, ) -> torch.Tensor: + # TODO: Improve assumptions. # TODO: Make a transposed version to avoid contiguous call in key backward. + # TODO: Improve block size heuristics. out = input_ if input_.stride(-1) != 1: + # TODO: Make a transposed version to avoid contiguous call in key backward. 
input_ = input_.contiguous() if input_.ndim == 3: input_ = input_.unsqueeze(0) @@ -106,10 +82,12 @@ def triton_rotary_( batch_size, seq_len, num_heads, head_size = input_.shape rotary_dim = div(head_size, 2) rotary_block_size = triton.next_power_of_2(rotary_dim) + head_block_size = triton.cdiv(TritonConfig.POINTWISE_BLOCK_SIZE, rotary_block_size) + if head_block_size > num_heads: + head_block_size = triton.next_power_of_2(num_heads) # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers - grid = lambda meta: (batch_size * seq_len, triton.cdiv(num_heads, meta["head_block_size"])) - triton_rotary_kernel[grid]( + triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, frequencies, input_.stride(0), @@ -118,6 +96,7 @@ def triton_rotary_( rotary_dim, num_heads, rotary_block_size, + head_block_size, seq_len, backward, # noqa ) diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 0f733d7d3..4ca56662f 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -31,13 +31,9 @@ def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: rotary_dim = head_size // 2 - frequencies = torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()) - # Warm up Triton autotuning so the timed runs aren't dominated by JIT/autotune overhead. - if TritonConfig.enabled(): - triton_rotary_(torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), frequencies) return { "input_": torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), - "frequencies": frequencies, + "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), } diff --git a/tools/inspect_rotary_compile.py b/tools/inspect_rotary_compile.py new file mode 100644 index 000000000..0493e6cbf --- /dev/null +++ b/tools/inspect_rotary_compile.py @@ -0,0 +1,63 @@ +""" +Dump the Triton kernel that torch.compile generates for the rotary embedding, +so we can compare it to our hand-written fast_llm kernel. + +Run on a GPU node: + python tools/inspect_rotary_compile.py +Output lands in /tmp/torchinductor_*/ (one subdir per compiled function). +This script also prints the path and first 200 lines of each .py file found. +""" + +import os + +import torch +import torch._inductor.config as inductor_config + +# Route torch.compile output to a known directory. +_OUT = "/tmp/torchinductor_rotary_inspect" +os.makedirs(_OUT, exist_ok=True) +os.environ["TORCHINDUCTOR_CACHE_DIR"] = _OUT + + +inductor_config.debug = True # writes generated Triton .py files alongside the cache + +tokens, num_heads, head_size = 4096, 32, 128 +rotary_dim = head_size // 2 +dtype = torch.bfloat16 +device = "cuda" + +input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device) +frequencies = torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device) + + +def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: + rotary_dim = frequencies.shape[-1] // 2 + freq_re = frequencies[:, :rotary_dim].unsqueeze(1) + freq_im = frequencies[:, rotary_dim:].unsqueeze(1) + x_re, x_im = input_.chunk(2, dim=-1) + out_re = x_re * freq_re - x_im * freq_im + out_im = x_im * freq_re + x_re * freq_im + return torch.cat([out_re, out_im], dim=-1) + + +compiled = torch.compile(_rotary_eager, mode="default", dynamic=False) + +# Trigger compilation. 
+out = compiled(input_, frequencies) +torch.cuda.synchronize() +print(f"Output shape: {out.shape}, dtype: {out.dtype}") +print(f"\nInductor cache / debug output dir: {_OUT}") + +# Find and print the generated Triton kernel files. +for root, dirs, files in os.walk(_OUT): + for fname in sorted(files): + if fname.endswith(".py"): + path = os.path.join(root, fname) + print(f"\n{'='*80}") + print(f"FILE: {path}") + print("=" * 80) + with open(path) as f: + lines = f.readlines() + print("".join(lines[:300])) + if len(lines) > 300: + print(f"... ({len(lines) - 300} more lines)") From cb3c161cee2b79f7facf22781a628f70fe00e6c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 19:54:57 -0400 Subject: [PATCH 16/41] Fix benchmark fairness: pre-allocate work buffer, reset between reps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In-place variants (fast_llm_triton) were calling .clone() inside the timed callable, paying ~1 full HBM read+write per rep — roughly doubling measured time relative to the actual kernel cost. Fix: pre-allocate a "work" buffer in _make_rotary_inputs and restore it between reps via reset_inputs (runner calls this outside the timed region). No in-place variant now has allocation cost in the hot path. Runner change: Variant gains reset_inputs field; bench_fn gains reset parameter that is called before flush() and before each rep's start event. Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_rotary.py | 8 +++++--- tools/benchmark/runner.py | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 4ca56662f..232d45fde 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -31,8 +31,10 @@ def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: rotary_dim = head_size // 2 + input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()) return { - "input_": torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()), + "input_": input_, + "work": input_.clone(), # pre-allocated work buffer for in-place variants "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), } @@ -84,11 +86,11 @@ def _rotary_variants() -> list[Variant]: ), ] if TritonConfig.enabled(): - # triton_rotary_ is in-place; clone so the benchmark input stays intact. variants.append( Variant( name="fast_llm_triton", - fwd=lambda inp: {"output": triton_rotary_(inp["input_"].clone(), inp["frequencies"])}, + fwd=lambda inp: {"output": triton_rotary_(inp["work"], inp["frequencies"])}, + reset_inputs=lambda inp: inp["work"].copy_(inp["input_"]), ) ) return variants diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index a629684c4..36a6f3eb6 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -68,6 +68,10 @@ class Variant: # (possibly modified) dict. Use this to mask out don't-care regions so they # don't inflate RMS errors (e.g. uninitialized phantom rows in sparse buffers). output_postprocess: Callable[[dict[str, torch.Tensor], Inputs], dict[str, torch.Tensor]] | None = None + # Called between timing reps (outside the timed region) to restore any + # input tensors the variant mutates in-place. Use this instead of cloning + # inside the timed callable so the mutation cost is not measured. 
+ reset_inputs: Callable[[Inputs], None] | None = None @dataclasses.dataclass @@ -138,6 +142,7 @@ def flush() -> None: def bench_fn( fn: Callable[[], Any], + reset: Callable[[], None] | None = None, warmup_ms: float = 25.0, rep_ms: float = 100.0, min_reps: int = 5, @@ -189,6 +194,8 @@ def bench_fn( start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] for i in range(n_reps): + if reset is not None: + reset() flush() start_events[i].record() fn() @@ -295,7 +302,8 @@ def _fwd_once() -> Any: del fwd_output # Timing: reuse the same input tensors, fn closes over them. - result.fwd_timing = bench_fn(_guarded_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms) + _reset_fwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None + result.fwd_timing = bench_fn(_guarded_fwd, reset=_reset_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms) del inputs # fwd+bwd mode @@ -333,7 +341,10 @@ def _fwd_bwd_once() -> Any: del fresh_inputs # Timing. - result.fwd_bwd_timing = bench_fn(_guarded_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms) + _reset_fwd_bwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None + result.fwd_bwd_timing = bench_fn( + _guarded_fwd_bwd, reset=_reset_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms + ) del inputs elif variant.fwd is not None and result.memory is None: # No backward — measure fwd-mode memory. From 057efa029a8d5c77162a692bd954801ad3df14d4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 20:38:25 -0400 Subject: [PATCH 17/41] Fix fwd_bwd benchmark bias: zero logits.grad between reps for PyTorch variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In entropy_loss and grpo_loss, PyTorch fwd_bwd variants call .backward() which accumulates into inp["logits"].grad. After rep 1, .backward() on rep 2+ adds into the existing grad tensor (1 extra read+write of the full logits tensor per rep). Triton variants compute grad_logits fresh each rep without touching inp["logits"].grad — no accumulation, no extra read. For 4096×131K logits this is ~2 GB extra HBM traffic per rep, biasing PyTorch ~33% slower than reality in fwd_bwd timing. Fix: add reset_inputs=_reset_logits_grad to all PyTorch fwd_bwd variants in entropy_loss (all 4 groups: ce_labels, ce_dist, reverse_kl, z_loss) and grpo_loss. fp32_reference is unaffected (it detaches into a local tensor; inp["logits"].grad is never set). Other benchmarks (normalization, mlp_activation, sparse_copy/linear) are symmetric: all variants use .backward(), so gradient accumulation affects them equally. 
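A minimal illustration of the accumulation bias described above, with toy shapes rather than the benchmark tensors:

    import torch

    logits = torch.randn(4, 8, requires_grad=True)
    first_rep_grad = None
    for rep in range(3):
        loss = logits.log_softmax(-1).mean()
        loss.backward()  # rep 2+ reads the existing .grad and adds into it
        if rep == 0:
            first_rep_grad = logits.grad.clone()
    assert torch.allclose(logits.grad, 3 * first_rep_grad)

    # The fix mirrors reset_inputs: clear the grad outside the timed region.
    logits.grad = None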
Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_entropy_loss.py | 16 ++++++++++++++++ tools/benchmark/bench_grpo_loss.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index 13fec495e..2eec74cb2 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -71,6 +71,10 @@ def _run_ce_labels_fwd_fp32(inp: dict) -> dict: return {"loss": _ce_labels_eager(logits_fp32, inp["labels"])} +def _reset_logits_grad(inp: dict) -> None: + inp["logits"].grad = None + + def _run_ce_labels_fwd_bwd(inp: dict, fn) -> dict: loss = fn(inp["logits"], inp["labels"]) loss.backward() @@ -108,16 +112,19 @@ def _ce_labels_variants() -> list[Variant]: name="pytorch_eager", fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_eager), fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_eager), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_default), fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_default), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_max), fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_max), + reset_inputs=_reset_logits_grad, ), ] if TritonConfig.enabled(): @@ -201,16 +208,19 @@ def _ce_dist_variants() -> list[Variant]: name="pytorch_eager", fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_eager), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_eager), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_default), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_default), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_max), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_max), + reset_inputs=_reset_logits_grad, ), ] if TritonConfig.enabled(): @@ -288,16 +298,19 @@ def _reverse_kl_variants() -> list[Variant]: name="pytorch_eager", fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_eager), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_eager), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_default), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_default), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_max), fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_max), + reset_inputs=_reset_logits_grad, ), ] if TritonConfig.enabled(): @@ -362,16 +375,19 @@ def _z_loss_variants() -> list[Variant]: name="pytorch_eager", fwd=lambda inp: _run_zl_fwd(inp, _z_loss_eager), fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_eager), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_default), fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_default), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_max), fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_max), + reset_inputs=_reset_logits_grad, ), ] if TritonConfig.enabled(): diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index a60e1751c..4ca0f3673 
100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -77,6 +77,10 @@ def _run_fwd_fp32(inp: dict) -> dict: } +def _reset_logits_grad(inp: dict) -> None: + inp["logits"].grad = None + + def _run_fwd_bwd(inp: dict, fn) -> dict: loss = fn(inp["logits"], inp["labels"], inp["advantages"], inp["old_log_probs"]) loss.backward() @@ -128,16 +132,19 @@ def _grpo_variants() -> list[Variant]: name="pytorch_eager", fwd=lambda inp: _run_fwd(inp, _grpo_eager), fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_eager), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", fwd=lambda inp: _run_fwd(inp, _grpo_compiled_default), fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_compiled_default), + reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", fwd=lambda inp: _run_fwd(inp, _grpo_compiled_max), fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_compiled_max), + reset_inputs=_reset_logits_grad, ), ] if TritonConfig.enabled(): From c854b8b90a88422bb43bb6d2a8d04e16e5ee2dd0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 20:49:41 -0400 Subject: [PATCH 18/41] Add fp32 references for sparse_copy and sparse_linear benchmarks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All other benchmarks (rotary, normalization, mlp_activation, entropy_loss, grpo_loss) use an fp32 reference as ground truth. sparse_copy and sparse_linear were using pytorch_eager / pytorch_loop in compute dtype (bf16) as is_reference=True — meaning Triton matching bf16 exactly gives zero RMS error even if both implementations share the same numeric drift. Add fp32_reference variants for all four cases: - bench_sparse_copy.py: dispatch and combine - bench_sparse_linear.py: output_sparse (layer 1) and input_inner_sparse (layer 2) Demote the existing pytorch_eager / pytorch_loop variants to regular comparison variants (no is_reference). Pattern matches the other benches: fp32 eager loop in float32, detaching inputs and recasting grad_output. 
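The failure mode being fixed, sketched with a stand-in op (softmax instead of the sparse kernels; shapes illustrative only):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4096, 1024)
    ref_fp32 = F.softmax(x.float(), dim=-1)       # fp32 ground truth
    impl_a = F.softmax(x.bfloat16(), dim=-1)      # "eager" in the compute dtype
    impl_b = F.softmax(x.bfloat16(), dim=-1)      # a kernel that happens to match impl_a bit-for-bit

    print((impl_a - impl_b).float().pow(2).mean().sqrt())    # 0.0: shared drift is invisible
    print((impl_a.float() - ref_fp32).pow(2).mean().sqrt())  # the error an fp32 reference exposes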
Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_sparse_copy.py | 44 ++++++++++++++++++++++++-- tools/benchmark/bench_sparse_linear.py | 42 ++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index aada0f649..77768faf5 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -108,6 +108,18 @@ def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} +def _run_dispatch_fwd_fp32(inp: dict) -> dict: + dense_fp32 = inp["dense_input"].float().detach().requires_grad_(True) + return {"output": _dispatch_pytorch(dense_fp32, inp["sparse_map"])} + + +def _run_dispatch_fwd_bwd_fp32(inp: dict) -> dict: + dense_fp32 = inp["dense_input"].float().detach().requires_grad_(True) + output = _dispatch_pytorch(dense_fp32, inp["sparse_map"]) + output.backward(inp["backward_grad"].float()) + return {"output": output.detach(), "grad_dense": dense_fp32.grad} + + def _run_dispatch_fwd_triton(inp: dict) -> dict: return {"output": copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"])} @@ -125,11 +137,16 @@ def _dispatch_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, def _dispatch_variants() -> list[Variant]: variants = [ + Variant( + name="fp32_reference", + fwd=_run_dispatch_fwd_fp32, + fwd_bwd=_run_dispatch_fwd_bwd_fp32, + is_reference=True, + ), Variant( name="pytorch_eager", fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_pytorch), fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_pytorch), - is_reference=True, ), Variant( name="pytorch_compiled", @@ -183,6 +200,24 @@ def _run_combine_fwd_bwd(inp: dict, fn) -> dict: } +def _run_combine_fwd_fp32(inp: dict) -> dict: + sparse_fp32 = inp["sparse_input"].float().detach().requires_grad_(True) + scores_fp32 = inp["scores"].float().detach().requires_grad_(True) + return {"output": _combine_pytorch(sparse_fp32, scores_fp32, inp["sparse_map"])} + + +def _run_combine_fwd_bwd_fp32(inp: dict) -> dict: + sparse_fp32 = inp["sparse_input"].float().detach().requires_grad_(True) + scores_fp32 = inp["scores"].float().detach().requires_grad_(True) + output = _combine_pytorch(sparse_fp32, scores_fp32, inp["sparse_map"]) + output.backward(inp["backward_grad"].float()) + return { + "output": output.detach(), + "grad_sparse": sparse_fp32.grad, + "grad_scores": scores_fp32.grad, + } + + def _run_combine_fwd_triton(inp: dict) -> dict: return {"output": copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"])} @@ -205,11 +240,16 @@ def _combine_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, t def _combine_variants() -> list[Variant]: variants = [ + Variant( + name="fp32_reference", + fwd=_run_combine_fwd_fp32, + fwd_bwd=_run_combine_fwd_bwd_fp32, + is_reference=True, + ), Variant( name="pytorch_eager", fwd=lambda inp: _run_combine_fwd(inp, _combine_pytorch), fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_pytorch), - is_reference=True, ), Variant( name="pytorch_compiled", diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 81af3f60b..2abe9d97d 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -134,6 +134,20 @@ def _run_output_sparse_fwd_bwd(inp: dict, fn) -> dict: return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def 
_run_output_sparse_fwd_fp32(inp: dict) -> dict: + lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) + return {"output": _output_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"])} + + +def _run_output_sparse_fwd_bwd_fp32(inp: dict) -> dict: + lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) + output = _output_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"]) + output.backward(inp["backward_grad"].float()) + return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} + + def _run_output_sparse_fwd_triton(inp: dict) -> dict: return {"output": OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} @@ -146,11 +160,16 @@ def _run_output_sparse_fwd_bwd_triton(inp: dict) -> dict: def _output_sparse_variants() -> list[Variant]: variants = [ + Variant( + name="fp32_reference", + fwd=_run_output_sparse_fwd_fp32, + fwd_bwd=_run_output_sparse_fwd_bwd_fp32, + is_reference=True, + ), Variant( name="pytorch_loop", fwd=lambda inp: _run_output_sparse_fwd(inp, _output_sparse_loop), fwd_bwd=lambda inp: _run_output_sparse_fwd_bwd(inp, _output_sparse_loop), - is_reference=True, ), Variant( name="pytorch_compiled", @@ -197,6 +216,20 @@ def _run_input_inner_sparse_fwd_bwd(inp: dict, fn) -> dict: return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def _run_input_inner_sparse_fwd_fp32(inp: dict) -> dict: + lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) + return {"output": _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"])} + + +def _run_input_inner_sparse_fwd_bwd_fp32(inp: dict) -> dict: + lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) + output = _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"]) + output.backward(inp["backward_grad"].float()) + return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} + + def _run_input_inner_sparse_fwd_triton(inp: dict) -> dict: return {"output": InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} @@ -209,11 +242,16 @@ def _run_input_inner_sparse_fwd_bwd_triton(inp: dict) -> dict: def _input_inner_sparse_variants() -> list[Variant]: variants = [ + Variant( + name="fp32_reference", + fwd=_run_input_inner_sparse_fwd_fp32, + fwd_bwd=_run_input_inner_sparse_fwd_bwd_fp32, + is_reference=True, + ), Variant( name="pytorch_loop", fwd=lambda inp: _run_input_inner_sparse_fwd(inp, _input_inner_sparse_loop), fwd_bwd=lambda inp: _run_input_inner_sparse_fwd_bwd(inp, _input_inner_sparse_loop), - is_reference=True, ), Variant( name="pytorch_compiled", From c1b3cc387304c70d1522a38bb8a4941e1468e40d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 20:57:33 -0400 Subject: [PATCH 19/41] Fix sparse warmup duplication; expose Apex availability flags publicly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue 3 — sparse_linear warmup called once per make_inputs invocation: make_inputs is called many times per case (per variant × per fwd/fwd_bwd/ memory pass). The Triton autotuning warmup only needs to fire once per shape. 
Add module-level sets _output_sparse_warmed_up and _input_inner_sparse_warmed_up; skip the warmup on subsequent calls with the same (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) key. Issue 4 — private import of _fused/_fast_normalization_available: bench_normalization.py imported names with a leading underscore from fast_llm.layers.common.normalization.normalization. Drop the underscores from the source (they are Apex availability flags, already used widely inside the same module) and update the benchmark import to match. Co-Authored-By: Claude Sonnet 4.6 --- .../common/normalization/normalization.py | 20 +++++++++---------- tools/benchmark/bench_normalization.py | 10 +++++----- tools/benchmark/bench_sparse_linear.py | 16 +++++++++++---- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 2858b9370..4f92223d2 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -22,16 +22,16 @@ try: import fused_layer_norm_cuda # noqa - _fused_normalization_available = torch.cuda.is_available() + fused_normalization_available = torch.cuda.is_available() except ImportError: - _fused_normalization_available = False + fused_normalization_available = False try: import fast_layer_norm # noqa - _fast_normalization_available = torch.cuda.is_available() + fast_normalization_available = torch.cuda.is_available() except ImportError: - _fast_normalization_available = False + fast_normalization_available = False try: @@ -79,7 +79,7 @@ class FastLayerNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float ) -> torch.Tensor: # noqa - assert _fast_normalization_available + assert fast_normalization_available Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES) output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps) ctx.save_for_backward(output, weight, bias, inv_var) @@ -110,7 +110,7 @@ class FusedLayerNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float ) -> torch.Tensor: # noqa - assert _fused_normalization_available + assert fused_normalization_available ctx.eps = eps ctx.normalized_shape = normalized_shape output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps) @@ -136,7 +136,7 @@ class FusedRMSNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float ) -> torch.Tensor: # noqa - assert _fused_normalization_available + assert fused_normalization_available ctx.eps = eps ctx.normalized_shape = normalized_shape output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps) @@ -185,7 +185,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | implementation = self._config.implementation if implementation == NormalizationImplementation.auto: if ( - _fast_normalization_available + fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._config.zero_centered ): @@ -193,7 +193,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using 
backup triton implementation.") implementation = NormalizationImplementation.triton - elif _fused_normalization_available: + elif fused_normalization_available: log_main_rank("Fast layer norm unavailable, using backup fused implementation.") implementation = NormalizationImplementation.fused else: @@ -261,7 +261,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | if implementation == NormalizationImplementation.auto: if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: implementation = NormalizationImplementation.triton - elif _fused_normalization_available: + elif fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") implementation = NormalizationImplementation.fused else: diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index 066c53676..0600cbb8e 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -23,8 +23,8 @@ FastLayerNorm, FusedLayerNorm, FusedRMSNorm, - _fast_normalization_available, - _fused_normalization_available, + fast_normalization_available, + fused_normalization_available, ) from tools.benchmark.runner import Case, Variant, run_benchmark from tools.benchmark.utils import case_name, device @@ -197,7 +197,7 @@ def _layer_norm_variants() -> list[Variant]: fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_compiled_max), ), ] - if _fused_normalization_available: + if fused_normalization_available: variants.append( Variant( name="apex_fused", @@ -205,7 +205,7 @@ def _layer_norm_variants() -> list[Variant]: fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_apex_fused), ) ) - if _fast_normalization_available: + if fast_normalization_available: # apex_fast only supports widths in _PERSIST_LN_SIZES; all shapes in _SHAPES qualify. variants.append( Variant( @@ -249,7 +249,7 @@ def _rms_norm_variants() -> list[Variant]: fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_compiled_max), ), ] - if _fused_normalization_available: + if fused_normalization_available: variants.append( Variant( name="apex_fused", diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 2abe9d97d..01e0c4997 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -39,6 +39,12 @@ ] _DEFAULT_DTYPES = (torch.bfloat16,) +# Triton autotuning warmup only needs to run once per shape. make_inputs is +# called multiple times per case (per variant, per fwd/fwd_bwd/memory pass), +# so cache which shapes have already been warmed up. +_output_sparse_warmed_up: set[tuple] = set() +_input_inner_sparse_warmed_up: set[tuple] = set() + def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: top_experts = torch.randint(0, num_experts, (tokens, top_k), device=device()) @@ -63,13 +69,14 @@ def _make_output_sparse_inputs( backward_grad = _zero_padded_rows( torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map ) - # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. 
- if TritonConfig.enabled(): + _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) + if TritonConfig.enabled() and _warmup_key not in _output_sparse_warmed_up: _w_lhs = lhs_data.detach().requires_grad_(True) _w_rhs = rhs_data.detach().requires_grad_(True) _w_out = OutputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) _w_out.backward(backward_grad) del _w_lhs, _w_rhs, _w_out + _output_sparse_warmed_up.add(_warmup_key) return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), @@ -90,13 +97,14 @@ def _make_input_inner_sparse_inputs( backward_grad = _zero_padded_rows( torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map ) - # Warm up Triton autotuning so the timed runs aren't dominated by JIT compilation. - if TritonConfig.enabled(): + _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) + if TritonConfig.enabled() and _warmup_key not in _input_inner_sparse_warmed_up: _w_lhs = lhs_data.detach().requires_grad_(True) _w_rhs = rhs_data.detach().requires_grad_(True) _w_out = InputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) _w_out.backward(backward_grad) del _w_lhs, _w_rhs, _w_out + _input_inner_sparse_warmed_up.add(_warmup_key) return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), From fdbf78bd3ac0b66cc359e51bdf2b707e43498ded Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 20:59:37 -0400 Subject: [PATCH 20/41] Fix E741 ambiguous variable name; document dead GRPO masks runner.py: rename l -> unit_label / unit_scale in _unit_column generator expression to avoid the E741 'ambiguous variable name' lint warning. bench_grpo_loss.py: add a comment explaining that the labels>=0 masks and clamp(min=0) mirror the production implementation's ignore_index=-100 handling; in this benchmark labels are always non-negative so the guards are unreachable. Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_grpo_loss.py | 2 ++ tools/benchmark/runner.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 4ca0f3673..cb02a1876 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -46,6 +46,8 @@ def _make_grpo_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor): log_probs = logits.float().log_softmax(-1) + # clamp + labels>=0 guards mirror production code that handles ignore_index=-100; + # labels here are always non-negative (randint), so the masks are dead in this benchmark. new_log_probs = log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1) new_log_probs = torch.where(labels >= 0, new_log_probs, torch.zeros_like(new_log_probs)) ratio = (new_log_probs - old_log_probs).exp() diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index 36a6f3eb6..e166dd04d 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -443,7 +443,9 @@ def _unit_column( else: # All values are zero / None. Fall back to the canonical unit (scale=1.0) # so e.g. memory defaults to MiB rather than the middle of the table. 
- label, scale = next(((l, s) for (l, s) in units if s == 1.0), units[0]) + label, scale = next( + ((unit_label, unit_scale) for (unit_label, unit_scale) in units if unit_scale == 1.0), units[0] + ) scaled = [v * scale if v is not None else None for v in canonical_values] header = f"{prefix} {label}" if prefix else label return header, _format_aligned(scaled) From 2505fb530c81ac3cf277a30737b58024dbb3ae58 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 21:18:02 -0400 Subject: [PATCH 21/41] Rename abbreviations and use functools.partial in benchmark suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - inp → inputs, elem → element_size, _bytes_per_elem → _bytes_per_element in all bench_*.py files - n_reps → num_reps, n_warmup → num_warmup, sep → separator, l/s → unit_label/unit_scale in runner.py - Replace lambda default-capture patterns (t=tokens, v=vocab, d=dtype…) with functools.partial in all eight bench_*.py files Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_entropy_loss.py | 184 ++++++++++++------------ tools/benchmark/bench_grpo_loss.py | 72 +++++----- tools/benchmark/bench_mlp_activation.py | 58 ++++---- tools/benchmark/bench_normalization.py | 120 ++++++++-------- tools/benchmark/bench_pointwise.py | 14 +- tools/benchmark/bench_rotary.py | 20 +-- tools/benchmark/bench_sparse_copy.py | 136 +++++++++--------- tools/benchmark/bench_sparse_linear.py | 120 ++++++++-------- tools/benchmark/runner.py | 26 ++-- 9 files changed, 381 insertions(+), 369 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index 2eec74cb2..b5a479018 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -15,6 +15,8 @@ Shapes fix tokens=4096, sweep vocab size from Llama-2 (32K) to Llama-3 (128K). 
""" +from functools import partial + import torch import torch.nn.functional as F @@ -62,40 +64,42 @@ def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor _ce_labels_compiled_max = torch.compile(_ce_labels_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_ce_labels_fwd(inp: dict, fn) -> dict: - return {"loss": fn(inp["logits"], inp["labels"])} +def _run_ce_labels_fwd(inputs: dict, fn) -> dict: + return {"loss": fn(inputs["logits"], inputs["labels"])} -def _run_ce_labels_fwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - return {"loss": _ce_labels_eager(logits_fp32, inp["labels"])} +def _run_ce_labels_fwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + return {"loss": _ce_labels_eager(logits_fp32, inputs["labels"])} -def _reset_logits_grad(inp: dict) -> None: - inp["logits"].grad = None +def _reset_logits_grad(inputs: dict) -> None: + inputs["logits"].grad = None -def _run_ce_labels_fwd_bwd(inp: dict, fn) -> dict: - loss = fn(inp["logits"], inp["labels"]) +def _run_ce_labels_fwd_bwd(inputs: dict, fn) -> dict: + loss = fn(inputs["logits"], inputs["labels"]) loss.backward() - return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} -def _run_ce_labels_fwd_bwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - loss = _ce_labels_eager(logits_fp32, inp["labels"]) +def _run_ce_labels_fwd_bwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + loss = _ce_labels_eager(logits_fp32, inputs["labels"]) loss.backward() return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} -def _run_ce_labels_fwd_triton(inp: dict) -> dict: - loss, _ = triton_entropy_loss_forward_backward(inp["logits"], inp["labels"], loss_mask=None, grad_output=None) +def _run_ce_labels_fwd_triton(inputs: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward( + inputs["logits"], inputs["labels"], loss_mask=None, grad_output=None + ) return {"loss": loss} -def _run_ce_labels_fwd_bwd_triton(inp: dict) -> dict: +def _run_ce_labels_fwd_bwd_triton(inputs: dict) -> dict: loss, grad_logits = triton_entropy_loss_forward_backward( - inp["logits"], inp["labels"], loss_mask=None, grad_output=1.0 + inputs["logits"], inputs["labels"], loss_mask=None, grad_output=1.0 ) return {"loss": loss, "grad_logits": grad_logits} @@ -110,20 +114,20 @@ def _ce_labels_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_eager), - fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_eager), + fwd=lambda inputs: _run_ce_labels_fwd(inputs, _ce_labels_eager), + fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_eager), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_default), - fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_default), + fwd=lambda inputs: _run_ce_labels_fwd(inputs, _ce_labels_compiled_default), + fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_compiled_default), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_ce_labels_fwd(inp, _ce_labels_compiled_max), - fwd_bwd=lambda inp: _run_ce_labels_fwd_bwd(inp, _ce_labels_compiled_max), + fwd=lambda inputs: 
_run_ce_labels_fwd(inputs, _ce_labels_compiled_max), + fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_compiled_max), reset_inputs=_reset_logits_grad, ), ] @@ -150,32 +154,32 @@ def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.T _ce_dist_compiled_max = torch.compile(_ce_dist_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_dist_fwd(inp: dict, fn) -> dict: - return {"loss": fn(inp["logits"], inp["target_logits"])} +def _run_dist_fwd(inputs: dict, fn) -> dict: + return {"loss": fn(inputs["logits"], inputs["target_logits"])} -def _run_ce_dist_fwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - return {"loss": _ce_dist_eager(logits_fp32, inp["target_logits"].float())} +def _run_ce_dist_fwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + return {"loss": _ce_dist_eager(logits_fp32, inputs["target_logits"].float())} -def _run_dist_fwd_bwd(inp: dict, fn) -> dict: - loss = fn(inp["logits"], inp["target_logits"]) +def _run_dist_fwd_bwd(inputs: dict, fn) -> dict: + loss = fn(inputs["logits"], inputs["target_logits"]) loss.backward() - return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} -def _run_ce_dist_fwd_bwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - loss = _ce_dist_eager(logits_fp32, inp["target_logits"].float()) +def _run_ce_dist_fwd_bwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + loss = _ce_dist_eager(logits_fp32, inputs["target_logits"].float()) loss.backward() return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} -def _run_ce_dist_fwd_triton(inp: dict) -> dict: +def _run_ce_dist_fwd_triton(inputs: dict) -> dict: loss, _ = triton_entropy_loss_forward_backward( - inp["logits"], - inp["target_logits"], + inputs["logits"], + inputs["target_logits"], loss_mask=None, grad_output=None, target_format=TargetFormat.logits, @@ -184,10 +188,10 @@ def _run_ce_dist_fwd_triton(inp: dict) -> dict: return {"loss": loss} -def _run_ce_dist_fwd_bwd_triton(inp: dict) -> dict: +def _run_ce_dist_fwd_bwd_triton(inputs: dict) -> dict: loss, grad_logits = triton_entropy_loss_forward_backward( - inp["logits"], - inp["target_logits"], + inputs["logits"], + inputs["target_logits"], loss_mask=None, grad_output=1.0, target_format=TargetFormat.logits, @@ -206,20 +210,20 @@ def _ce_dist_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_eager), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_eager), + fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_eager), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _ce_dist_eager), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_default), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_default), + fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_compiled_default), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _ce_dist_compiled_default), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_dist_fwd(inp, _ce_dist_compiled_max), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _ce_dist_compiled_max), + fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_compiled_max), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, 
_ce_dist_compiled_max), reset_inputs=_reset_logits_grad, ), ] @@ -250,22 +254,22 @@ def _reverse_kl_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torc _reverse_kl_compiled_max = torch.compile(_reverse_kl_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_rkl_fwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - return {"loss": _reverse_kl_eager(logits_fp32, inp["target_logits"].float())} +def _run_rkl_fwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + return {"loss": _reverse_kl_eager(logits_fp32, inputs["target_logits"].float())} -def _run_rkl_fwd_bwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) - loss = _reverse_kl_eager(logits_fp32, inp["target_logits"].float()) +def _run_rkl_fwd_bwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) + loss = _reverse_kl_eager(logits_fp32, inputs["target_logits"].float()) loss.backward() return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} -def _run_rkl_fwd_triton(inp: dict) -> dict: +def _run_rkl_fwd_triton(inputs: dict) -> dict: loss, _ = triton_entropy_loss_forward_backward( - inp["logits"], - inp["target_logits"], + inputs["logits"], + inputs["target_logits"], loss_mask=None, grad_output=None, target_format=TargetFormat.logits, @@ -274,10 +278,10 @@ def _run_rkl_fwd_triton(inp: dict) -> dict: return {"loss": loss} -def _run_rkl_fwd_bwd_triton(inp: dict) -> dict: +def _run_rkl_fwd_bwd_triton(inputs: dict) -> dict: loss, grad_logits = triton_entropy_loss_forward_backward( - inp["logits"], - inp["target_logits"], + inputs["logits"], + inputs["target_logits"], loss_mask=None, grad_output=1.0, target_format=TargetFormat.logits, @@ -296,20 +300,20 @@ def _reverse_kl_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_eager), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_eager), + fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_eager), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_eager), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_default), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_default), + fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_compiled_default), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_compiled_default), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_dist_fwd(inp, _reverse_kl_compiled_max), - fwd_bwd=lambda inp: _run_dist_fwd_bwd(inp, _reverse_kl_compiled_max), + fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_compiled_max), + fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_compiled_max), reset_inputs=_reset_logits_grad, ), ] @@ -336,35 +340,35 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: _z_loss_compiled_max = torch.compile(_z_loss_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_zl_fwd(inp: dict, fn) -> dict: - return {"loss": fn(inp["logits"])} +def _run_zl_fwd(inputs: dict, fn) -> dict: + return {"loss": fn(inputs["logits"])} -def _run_zl_fwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) +def _run_zl_fwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) return {"loss": 
_z_loss_eager(logits_fp32)} -def _run_zl_fwd_bwd(inp: dict, fn) -> dict: - loss = fn(inp["logits"]) +def _run_zl_fwd_bwd(inputs: dict, fn) -> dict: + loss = fn(inputs["logits"]) loss.backward() - return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} -def _run_zl_fwd_bwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_(True) +def _run_zl_fwd_bwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) loss = _z_loss_eager(logits_fp32) loss.backward() return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} -def _run_zl_fwd_triton(inp: dict) -> dict: - loss, _ = triton_z_loss_forward_backward(inp["logits"], loss_mask=None, grad_output=None) +def _run_zl_fwd_triton(inputs: dict) -> dict: + loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=None) return {"loss": loss} -def _run_zl_fwd_bwd_triton(inp: dict) -> dict: - loss, grad_logits = triton_z_loss_forward_backward(inp["logits"], loss_mask=None, grad_output=1.0) +def _run_zl_fwd_bwd_triton(inputs: dict) -> dict: + loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0) return {"loss": loss, "grad_logits": grad_logits} @@ -373,20 +377,20 @@ def _z_loss_variants() -> list[Variant]: Variant(name="fp32_reference", fwd=_run_zl_fwd_fp32, fwd_bwd=_run_zl_fwd_bwd_fp32, is_reference=True), Variant( name="pytorch_eager", - fwd=lambda inp: _run_zl_fwd(inp, _z_loss_eager), - fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_eager), + fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_eager), + fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_eager), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_default), - fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_default), + fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_compiled_default), + fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_compiled_default), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_zl_fwd(inp, _z_loss_compiled_max), - fwd_bwd=lambda inp: _run_zl_fwd_bwd(inp, _z_loss_compiled_max), + fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_compiled_max), + fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_compiled_max), reset_inputs=_reset_logits_grad, ), ] @@ -398,20 +402,20 @@ def _z_loss_variants() -> list[Variant]: # --------------------------------------------------------------------------- cases -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() def _label_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: """fwd+bwd: read logits, read labels (int32), write grad_logits.""" - elem = _bytes_per_elem(dtype) - return 2 * tokens * vocab * elem + tokens * 4 + element_size = _bytes_per_element(dtype) + return 2 * tokens * vocab * element_size + tokens * 4 def _dist_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: """fwd+bwd: read logits, read target_logits, write grad_logits.""" - elem = _bytes_per_elem(dtype) - return 3 * tokens * vocab * elem + element_size = _bytes_per_element(dtype) + return 3 * tokens * vocab * element_size def _entropy_loss_flops(tokens: int, vocab: int) -> int: @@ -423,7 +427,7 @@ def _label_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> 
list[Case return [ Case( name=case_name(kernel_name, (tokens, vocab), dtype), - make_inputs=(lambda t=tokens, v=vocab, d=dtype: _make_label_inputs(t, v, d)), + make_inputs=partial(_make_label_inputs, tokens, vocab, dtype), expected_bytes=_label_loss_bytes(tokens, vocab, dtype), expected_flops=_entropy_loss_flops(tokens, vocab), compute_dtype=dtype, @@ -437,7 +441,7 @@ def _dist_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case] return [ Case( name=case_name(kernel_name, (tokens, vocab), dtype), - make_inputs=(lambda t=tokens, v=vocab, d=dtype: _make_distribution_inputs(t, v, d)), + make_inputs=partial(_make_distribution_inputs, tokens, vocab, dtype), expected_bytes=_dist_loss_bytes(tokens, vocab, dtype), expected_flops=_entropy_loss_flops(tokens, vocab), compute_dtype=dtype, diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index cb02a1876..4dd5cdd5a 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -18,6 +18,8 @@ Shapes match bench_entropy_loss: tokens=4096, vocab swept over 32K/64K/128K. """ +from functools import partial + import torch from fast_llm.functional.config import TritonConfig @@ -64,44 +66,44 @@ def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Te _grpo_compiled_max = torch.compile(_grpo_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_fwd(inp: dict, fn) -> dict: - return {"loss": fn(inp["logits"], inp["labels"], inp["advantages"], inp["old_log_probs"])} +def _run_fwd(inputs: dict, fn) -> dict: + return {"loss": fn(inputs["logits"], inputs["labels"], inputs["advantages"], inputs["old_log_probs"])} -def _run_fwd_fp32(inp: dict) -> dict: +def _run_fwd_fp32(inputs: dict) -> dict: return { "loss": _grpo_eager( - inp["logits"].float().detach().requires_grad_(), - inp["labels"], - inp["advantages"], - inp["old_log_probs"], + inputs["logits"].float().detach().requires_grad_(), + inputs["labels"], + inputs["advantages"], + inputs["old_log_probs"], ) } -def _reset_logits_grad(inp: dict) -> None: - inp["logits"].grad = None +def _reset_logits_grad(inputs: dict) -> None: + inputs["logits"].grad = None -def _run_fwd_bwd(inp: dict, fn) -> dict: - loss = fn(inp["logits"], inp["labels"], inp["advantages"], inp["old_log_probs"]) +def _run_fwd_bwd(inputs: dict, fn) -> dict: + loss = fn(inputs["logits"], inputs["labels"], inputs["advantages"], inputs["old_log_probs"]) loss.backward() - return {"loss": loss.detach(), "grad_logits": inp["logits"].grad} + return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} -def _run_fwd_bwd_fp32(inp: dict) -> dict: - logits_fp32 = inp["logits"].float().detach().requires_grad_() - loss = _grpo_eager(logits_fp32, inp["labels"], inp["advantages"], inp["old_log_probs"]) +def _run_fwd_bwd_fp32(inputs: dict) -> dict: + logits_fp32 = inputs["logits"].float().detach().requires_grad_() + loss = _grpo_eager(logits_fp32, inputs["labels"], inputs["advantages"], inputs["old_log_probs"]) loss.backward() return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} -def _run_fwd_triton(inp: dict) -> dict: +def _run_fwd_triton(inputs: dict) -> dict: loss, _, _ = triton_grpo_loss_forward_backward( - inp["logits"], - inp["labels"], - inp["advantages"], - inp["old_log_probs"], + inputs["logits"], + inputs["labels"], + inputs["advantages"], + inputs["old_log_probs"], grad_output=None, epsilon_low=_EPSILON_LOW, epsilon_high=_EPSILON_HIGH, @@ -109,12 +111,12 @@ def _run_fwd_triton(inp: dict) -> dict: return {"loss": loss} 
-def _run_fwd_bwd_triton(inp: dict) -> dict: +def _run_fwd_bwd_triton(inputs: dict) -> dict: loss, grad_logits, _ = triton_grpo_loss_forward_backward( - inp["logits"], - inp["labels"], - inp["advantages"], - inp["old_log_probs"], + inputs["logits"], + inputs["labels"], + inputs["advantages"], + inputs["old_log_probs"], grad_output=1.0, epsilon_low=_EPSILON_LOW, epsilon_high=_EPSILON_HIGH, @@ -132,20 +134,20 @@ def _grpo_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_fwd(inp, _grpo_eager), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_eager), + fwd=lambda inputs: _run_fwd(inputs, _grpo_eager), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _grpo_eager), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_fwd(inp, _grpo_compiled_default), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_compiled_default), + fwd=lambda inputs: _run_fwd(inputs, _grpo_compiled_default), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _grpo_compiled_default), reset_inputs=_reset_logits_grad, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_fwd(inp, _grpo_compiled_max), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _grpo_compiled_max), + fwd=lambda inputs: _run_fwd(inputs, _grpo_compiled_max), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _grpo_compiled_max), reset_inputs=_reset_logits_grad, ), ] @@ -160,14 +162,14 @@ def _grpo_variants() -> list[Variant]: return variants -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - elem = _bytes_per_elem(dtype) + element_size = _bytes_per_element(dtype) # fwd: read logits + bwd: read logits + write grad_logits - logit_traffic = 3 * tokens * vocab * elem + logit_traffic = 3 * tokens * vocab * element_size # labels (int64), advantages (fp32), old_log_probs (fp32) scalar_traffic = tokens * (8 + 4 + 4) return logit_traffic + scalar_traffic @@ -182,7 +184,7 @@ def _grpo_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("grpo_loss", (tokens, vocab), dtype), - make_inputs=lambda t=tokens, v=vocab, d=dtype: _make_grpo_inputs(t, v, d), + make_inputs=partial(_make_grpo_inputs, tokens, vocab, dtype), expected_bytes=_grpo_bytes(tokens, vocab, dtype), expected_flops=_grpo_flops(tokens, vocab), compute_dtype=dtype, diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 1e2df9831..8d20423a3 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -15,6 +15,8 @@ Shapes fix tokens=8192 and sweep ffn_dim across typical MLP widths. 
""" +from functools import partial + import torch from fast_llm.functional.config import ActivationType, TritonConfig @@ -60,39 +62,39 @@ def _pytorch_fwd(input_: torch.Tensor, gated: bool, activation_type: ActivationT _pytorch_compiled_max = torch.compile(_pytorch_fwd, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["input_"], inp["gated"], inp["activation_type"])} +def _run_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["input_"], inputs["gated"], inputs["activation_type"])} -def _run_fwd_fp32(inp: dict) -> dict: - return {"output": _pytorch_fwd(inp["input_"].float(), inp["gated"], inp["activation_type"])} +def _run_fwd_fp32(inputs: dict) -> dict: + return {"output": _pytorch_fwd(inputs["input_"].float(), inputs["gated"], inputs["activation_type"])} -def _run_fwd_triton(inp: dict) -> dict: - output, _ = triton_mlp_activation_forward(inp["input_"], inp["gated"], inp["activation_type"]) +def _run_fwd_triton(inputs: dict) -> dict: + output, _ = triton_mlp_activation_forward(inputs["input_"], inputs["gated"], inputs["activation_type"]) return {"output": output} # --------------------------------------------------------------------------- fwd+bwd wrappers -def _run_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["input_"], inp["gated"], inp["activation_type"]) - output.backward(inp["grad_output"]) - return {"output": output.detach(), "grad_input": inp["input_"].grad} +def _run_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["input_"], inputs["gated"], inputs["activation_type"]) + output.backward(inputs["grad_output"]) + return {"output": output.detach(), "grad_input": inputs["input_"].grad} -def _run_fwd_bwd_fp32(inp: dict) -> dict: - input_fp32 = inp["input_"].float().detach().requires_grad_(True) - output = _pytorch_fwd(input_fp32, inp["gated"], inp["activation_type"]) - output.backward(inp["grad_output"].float()) +def _run_fwd_bwd_fp32(inputs: dict) -> dict: + input_fp32 = inputs["input_"].float().detach().requires_grad_(True) + output = _pytorch_fwd(input_fp32, inputs["gated"], inputs["activation_type"]) + output.backward(inputs["grad_output"].float()) return {"output": output.detach(), "grad_input": input_fp32.grad} -def _run_fwd_bwd_triton(inp: dict) -> dict: - output = triton_mlp_activation_autograd(inp["input_"], inp["gated"], inp["activation_type"]) - output.backward(inp["grad_output"]) - return {"output": output.detach(), "grad_input": inp["input_"].grad} +def _run_fwd_bwd_triton(inputs: dict) -> dict: + output = triton_mlp_activation_autograd(inputs["input_"], inputs["gated"], inputs["activation_type"]) + output.backward(inputs["grad_output"]) + return {"output": output.detach(), "grad_input": inputs["input_"].grad} # --------------------------------------------------------------------------- variants @@ -108,18 +110,18 @@ def _mlp_activation_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_fwd(inp, _pytorch_fwd), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_fwd), + fwd=lambda inputs: _run_fwd(inputs, _pytorch_fwd), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_fwd), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_fwd(inp, _pytorch_compiled_default), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_compiled_default), + fwd=lambda inputs: _run_fwd(inputs, _pytorch_compiled_default), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_compiled_default), ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_fwd(inp, 
_pytorch_compiled_max), - fwd_bwd=lambda inp: _run_fwd_bwd(inp, _pytorch_compiled_max), + fwd=lambda inputs: _run_fwd(inputs, _pytorch_compiled_max), + fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_compiled_max), ), ] if TritonConfig.enabled(): @@ -136,7 +138,7 @@ def _mlp_activation_variants() -> list[Variant]: # --------------------------------------------------------------------------- cases -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() @@ -144,11 +146,11 @@ def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: """fwd: read input (2*ffn_dim) + write output (ffn_dim). bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). Total: 8 × tokens × ffn_dim × elem_size.""" - return 8 * tokens * ffn_dim * _bytes_per_elem(dtype) + return 8 * tokens * ffn_dim * _bytes_per_element(dtype) def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: - # gated silu: fwd ≈ 6 FLOPs/elem, bwd ≈ 8 FLOPs/elem, total ≈ 14 per output element. + # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element, total ≈ 14 per output element. return 14 * tokens * ffn_dim @@ -156,7 +158,7 @@ def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("mlp_activation", (tokens, ffn_dim), dtype), - make_inputs=(lambda t=tokens, f=ffn_dim, d=dtype: _make_mlp_inputs(t, f, d)), + make_inputs=partial(_make_mlp_inputs, tokens, ffn_dim, dtype), expected_bytes=_mlp_activation_bytes(tokens, ffn_dim, dtype), expected_flops=_mlp_activation_flops(tokens, ffn_dim), compute_dtype=dtype, diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index 0600cbb8e..3a3c02044 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -15,6 +15,8 @@ - fast_llm_triton: triton_normalization_autograd """ +from functools import partial + import torch from fast_llm.functional.config import TritonConfig @@ -80,20 +82,20 @@ def _make_rms_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: } -def _layer_norm_inputs_fp32(inp: dict) -> dict: +def _layer_norm_inputs_fp32(inputs: dict) -> dict: return { - "input_": _to_fp32_input(inp["input_"]), - "weight": _to_fp32_param(inp["weight"]), - "bias": _to_fp32_param(inp["bias"]), - "grad_output": inp["grad_output"].float(), + "input_": _to_fp32_input(inputs["input_"]), + "weight": _to_fp32_param(inputs["weight"]), + "bias": _to_fp32_param(inputs["bias"]), + "grad_output": inputs["grad_output"].float(), } -def _rms_norm_inputs_fp32(inp: dict) -> dict: +def _rms_norm_inputs_fp32(inputs: dict) -> dict: return { - "input_": _to_fp32_input(inp["input_"]), - "weight": _to_fp32_param(inp["weight"]), - "grad_output": inp["grad_output"].float(), + "input_": _to_fp32_input(inputs["input_"]), + "weight": _to_fp32_param(inputs["weight"]), + "grad_output": inputs["grad_output"].float(), } @@ -143,30 +145,30 @@ def _param_grad(param: torch.Tensor) -> torch.Tensor: return param.grad if param.grad is not None else param.grad_buffer -def _run_layer_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["input_"], inp["weight"], inp["bias"])} +def _run_layer_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["input_"], inputs["weight"], inputs["bias"])} -def _run_layer_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["input_"], inp["weight"], inp["bias"]) - output.backward(inp["grad_output"]) +def 
_run_layer_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["input_"], inputs["weight"], inputs["bias"]) + output.backward(inputs["grad_output"]) return { - "grad_input": inp["input_"].grad, - "grad_weight": _param_grad(inp["weight"]), - "grad_bias": _param_grad(inp["bias"]), + "grad_input": inputs["input_"].grad, + "grad_weight": _param_grad(inputs["weight"]), + "grad_bias": _param_grad(inputs["bias"]), } -def _run_rms_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["input_"], inp["weight"])} +def _run_rms_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["input_"], inputs["weight"])} -def _run_rms_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["input_"], inp["weight"]) - output.backward(inp["grad_output"]) +def _run_rms_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["input_"], inputs["weight"]) + output.backward(inputs["grad_output"]) return { - "grad_input": inp["input_"].grad, - "grad_weight": _param_grad(inp["weight"]), + "grad_input": inputs["input_"].grad, + "grad_weight": _param_grad(inputs["weight"]), } @@ -177,32 +179,32 @@ def _layer_norm_variants() -> list[Variant]: variants = [ Variant( name="fp32_reference", - fwd=lambda inp: _run_layer_fwd(_layer_norm_inputs_fp32(inp), _layer_norm_eager), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(_layer_norm_inputs_fp32(inp), _layer_norm_eager), + fwd=lambda inputs: _run_layer_fwd(_layer_norm_inputs_fp32(inputs), _layer_norm_eager), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(_layer_norm_inputs_fp32(inputs), _layer_norm_eager), is_reference=True, ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_eager), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_eager), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_eager), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_eager), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_layer_fwd(inp, _layer_compiled_default), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_compiled_default), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_compiled_default), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_compiled_default), ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_layer_fwd(inp, _layer_compiled_max), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_compiled_max), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_compiled_max), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_compiled_max), ), ] if fused_normalization_available: variants.append( Variant( name="apex_fused", - fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_apex_fused), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_apex_fused), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_apex_fused), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_apex_fused), ) ) if fast_normalization_available: @@ -210,16 +212,16 @@ def _layer_norm_variants() -> list[Variant]: variants.append( Variant( name="apex_fast", - fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_apex_fast), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_apex_fast), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_apex_fast), + fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_apex_fast), ) ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=lambda inp: _run_layer_fwd(inp, _layer_norm_triton), - fwd_bwd=lambda inp: _run_layer_fwd_bwd(inp, _layer_norm_triton), + fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_triton), + 
fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_triton), ) ) return variants @@ -229,40 +231,40 @@ def _rms_norm_variants() -> list[Variant]: variants = [ Variant( name="fp32_reference", - fwd=lambda inp: _run_rms_fwd(_rms_norm_inputs_fp32(inp), _rms_norm_eager), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(_rms_norm_inputs_fp32(inp), _rms_norm_eager), + fwd=lambda inputs: _run_rms_fwd(_rms_norm_inputs_fp32(inputs), _rms_norm_eager), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(_rms_norm_inputs_fp32(inputs), _rms_norm_eager), is_reference=True, ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_eager), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_eager), + fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_eager), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_eager), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_rms_fwd(inp, _rms_compiled_default), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_compiled_default), + fwd=lambda inputs: _run_rms_fwd(inputs, _rms_compiled_default), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_compiled_default), ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_rms_fwd(inp, _rms_compiled_max), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_compiled_max), + fwd=lambda inputs: _run_rms_fwd(inputs, _rms_compiled_max), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_compiled_max), ), ] if fused_normalization_available: variants.append( Variant( name="apex_fused", - fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_apex_fused), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_apex_fused), + fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_apex_fused), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_apex_fused), ) ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=lambda inp: _run_rms_fwd(inp, _rms_norm_triton), - fwd_bwd=lambda inp: _run_rms_fwd_bwd(inp, _rms_norm_triton), + fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_triton), + fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_triton), ) ) return variants @@ -271,7 +273,7 @@ def _rms_norm_variants() -> list[Variant]: # --------------------------------------------------------------------------- cases -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() @@ -280,16 +282,16 @@ def _layer_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: fwd reads input + weight + bias and writes output (also stores inv_var). bwd reads grad_output, output, weight, bias, inv_var; writes grad_input, grad_weight, grad_bias. 
Activation tensors dominate.""" - elem = _bytes_per_elem(dtype) - activations = 4 * rows * cols * elem # fwd in/out + bwd grad_in/out - parameters = 6 * cols * elem # weight, bias × (read + grad write) twice + element_size = _bytes_per_element(dtype) + activations = 4 * rows * cols * element_size # fwd in/out + bwd grad_in/out + parameters = 6 * cols * element_size # weight, bias × (read + grad write) twice return activations + parameters def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - elem = _bytes_per_elem(dtype) - activations = 4 * rows * cols * elem - parameters = 3 * cols * elem + element_size = _bytes_per_element(dtype) + activations = 4 * rows * cols * element_size + parameters = 3 * cols * element_size return activations + parameters @@ -309,7 +311,7 @@ def _layer_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("layer_norm", shape, dtype), - make_inputs=(lambda s=shape, d=dtype: _make_layer_norm_inputs(s[0], s[1], d)), + make_inputs=partial(_make_layer_norm_inputs, shape[0], shape[1], dtype), expected_bytes=_layer_norm_bytes(shape[0], shape[1], dtype), expected_flops=_layer_norm_flops(shape[0], shape[1]), compute_dtype=dtype, @@ -323,7 +325,7 @@ def _rms_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("rms_norm", shape, dtype), - make_inputs=(lambda s=shape, d=dtype: _make_rms_norm_inputs(s[0], s[1], d)), + make_inputs=partial(_make_rms_norm_inputs, shape[0], shape[1], dtype), expected_bytes=_rms_norm_bytes(shape[0], shape[1], dtype), expected_flops=_rms_norm_flops(shape[0], shape[1]), compute_dtype=dtype, diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index b2ea1f00e..a9a059114 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -7,6 +7,8 @@ documented as being ~2x faster than the PyTorch equivalent on A100. """ +from functools import partial + import torch from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill @@ -42,7 +44,7 @@ def _copy_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("copy", (numel,), dtype), - make_inputs=(lambda n=numel, d=dtype: _make_copy_inputs(n, d)), + make_inputs=partial(_make_copy_inputs, numel, dtype), # Read input + write output. expected_bytes=2 * numel * torch.tensor([], dtype=dtype).element_size(), ) @@ -54,7 +56,7 @@ def _copy_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: _COPY_VARIANTS = standard_fwd_variants( eager_fn=_copy_eager, triton_fn=triton_copy, - unpack=lambda inp: (inp["input_"], inp["out"]), + unpack=lambda inputs: (inputs["input_"], inputs["out"]), ) @@ -73,7 +75,7 @@ def _fill_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("fill", (numel,), dtype), - make_inputs=(lambda n=numel, d=dtype: _make_fill_inputs(n, d)), + make_inputs=partial(_make_fill_inputs, numel, dtype), # Write only. 
expected_bytes=numel * torch.tensor([], dtype=dtype).element_size(), ) @@ -85,7 +87,7 @@ def _fill_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: _FILL_VARIANTS = standard_fwd_variants( eager_fn=_fill_eager, triton_fn=triton_fill, - unpack=lambda inp: (inp["input_"], inp["value"]), + unpack=lambda inputs: (inputs["input_"], inputs["value"]), ) @@ -108,7 +110,7 @@ def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("add", (numel,), dtype), - make_inputs=(lambda n=numel, d=dtype: _make_add_inputs(n, d)), + make_inputs=partial(_make_add_inputs, numel, dtype), # Read 2 inputs + write 1 output. expected_bytes=3 * numel * torch.tensor([], dtype=dtype).element_size(), # One fp add per element. @@ -123,7 +125,7 @@ def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: _ADD_VARIANTS = standard_fwd_variants( eager_fn=_add_eager, triton_fn=triton_add, - unpack=lambda inp: (inp["input_"], inp["other"], inp["out"]), + unpack=lambda inputs: (inputs["input_"], inputs["other"], inputs["out"]), ) diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 232d45fde..9d75e6815 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -12,6 +12,8 @@ - 8 heads × 128 → GQA key-value heads (Llama 3) """ +from functools import partial + import torch from fast_llm.functional.config import TritonConfig @@ -55,9 +57,9 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: - elem = torch.tensor([], dtype=dtype).element_size() + element_size = torch.tensor([], dtype=dtype).element_size() # Read + write input tensor; frequencies are float32. - return 2 * tokens * num_heads * head_size * elem + tokens * head_size * 4 + return 2 * tokens * num_heads * head_size * element_size + tokens * head_size * 4 def _rotary_flops(tokens: int, num_heads: int, head_size: int) -> int: @@ -69,28 +71,28 @@ def _rotary_variants() -> list[Variant]: variants = [ Variant( name="fp32_reference", - fwd=lambda inp: {"output": _rotary_eager(inp["input_"].float(), inp["frequencies"])}, + fwd=lambda inputs: {"output": _rotary_eager(inputs["input_"].float(), inputs["frequencies"])}, is_reference=True, ), Variant( name="pytorch_eager", - fwd=lambda inp: {"output": _rotary_eager(inp["input_"], inp["frequencies"])}, + fwd=lambda inputs: {"output": _rotary_eager(inputs["input_"], inputs["frequencies"])}, ), Variant( name="pytorch_compiled", - fwd=lambda inp: {"output": _rotary_compiled_default(inp["input_"], inp["frequencies"])}, + fwd=lambda inputs: {"output": _rotary_compiled_default(inputs["input_"], inputs["frequencies"])}, ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: {"output": _rotary_compiled_max(inp["input_"], inp["frequencies"])}, + fwd=lambda inputs: {"output": _rotary_compiled_max(inputs["input_"], inputs["frequencies"])}, ), ] if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=lambda inp: {"output": triton_rotary_(inp["work"], inp["frequencies"])}, - reset_inputs=lambda inp: inp["work"].copy_(inp["input_"]), + fwd=lambda inputs: {"output": triton_rotary_(inputs["work"], inputs["frequencies"])}, + reset_inputs=lambda inputs: inputs["work"].copy_(inputs["input_"]), ) ) return variants @@ -100,7 +102,7 @@ def _rotary_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("rotary", (tokens, num_heads, head_size), dtype), - make_inputs=(lambda 
t=tokens, h=num_heads, s=head_size, d=dtype: _make_rotary_inputs(t, h, s, d)), + make_inputs=partial(_make_rotary_inputs, tokens, num_heads, head_size, dtype), expected_bytes=_rotary_bytes(tokens, num_heads, head_size, dtype), expected_flops=_rotary_flops(tokens, num_heads, head_size), compute_dtype=dtype, diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 77768faf5..ddd58a9d5 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -20,6 +20,8 @@ The SparseMap is pre-computed once per case (routing structure, not data). """ +from functools import partial + import torch from fast_llm.functional.config import TritonConfig @@ -98,40 +100,40 @@ def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch _dispatch_compiled_max = torch.compile(_dispatch_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_dispatch_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["dense_input"], inp["sparse_map"])} +def _run_dispatch_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["dense_input"], inputs["sparse_map"])} -def _run_dispatch_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["dense_input"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} +def _run_dispatch_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["dense_input"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_dense": inputs["dense_input"].grad} -def _run_dispatch_fwd_fp32(inp: dict) -> dict: - dense_fp32 = inp["dense_input"].float().detach().requires_grad_(True) - return {"output": _dispatch_pytorch(dense_fp32, inp["sparse_map"])} +def _run_dispatch_fwd_fp32(inputs: dict) -> dict: + dense_fp32 = inputs["dense_input"].float().detach().requires_grad_(True) + return {"output": _dispatch_pytorch(dense_fp32, inputs["sparse_map"])} -def _run_dispatch_fwd_bwd_fp32(inp: dict) -> dict: - dense_fp32 = inp["dense_input"].float().detach().requires_grad_(True) - output = _dispatch_pytorch(dense_fp32, inp["sparse_map"]) - output.backward(inp["backward_grad"].float()) +def _run_dispatch_fwd_bwd_fp32(inputs: dict) -> dict: + dense_fp32 = inputs["dense_input"].float().detach().requires_grad_(True) + output = _dispatch_pytorch(dense_fp32, inputs["sparse_map"]) + output.backward(inputs["backward_grad"].float()) return {"output": output.detach(), "grad_dense": dense_fp32.grad} -def _run_dispatch_fwd_triton(inp: dict) -> dict: - return {"output": copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"])} +def _run_dispatch_fwd_triton(inputs: dict) -> dict: + return {"output": copy_dense_to_sparse_autograd(inputs["dense_input"], inputs["sparse_map"])} -def _run_dispatch_fwd_bwd_triton(inp: dict) -> dict: - output = copy_dense_to_sparse_autograd(inp["dense_input"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_dense": inp["dense_input"].grad} +def _run_dispatch_fwd_bwd_triton(inputs: dict) -> dict: + output = copy_dense_to_sparse_autograd(inputs["dense_input"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_dense": inputs["dense_input"].grad} -def _dispatch_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, torch.Tensor]: - out["output"].masked_fill_(inp["phantom_mask"], 0) +def _dispatch_postprocess(out: dict[str, torch.Tensor], inputs: 
dict) -> dict[str, torch.Tensor]: + out["output"].masked_fill_(inputs["phantom_mask"], 0) return out @@ -145,18 +147,18 @@ def _dispatch_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_pytorch), - fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_pytorch), + fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_pytorch), + fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_pytorch), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_compiled_default), - fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_compiled_default), + fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_compiled_default), + fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_compiled_default), ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_dispatch_fwd(inp, _dispatch_compiled_max), - fwd_bwd=lambda inp: _run_dispatch_fwd_bwd(inp, _dispatch_compiled_max), + fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_compiled_max), + fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_compiled_max), ), ] if TritonConfig.enabled(): @@ -186,31 +188,31 @@ def _combine_pytorch(sparse_input: torch.Tensor, scores: torch.Tensor, sparse_ma _combine_compiled_max = torch.compile(_combine_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) -def _run_combine_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["sparse_input"], inp["scores"], inp["sparse_map"])} +def _run_combine_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"])} -def _run_combine_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["sparse_input"], inp["scores"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) +def _run_combine_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) return { "output": output.detach(), - "grad_sparse": inp["sparse_input"].grad, - "grad_scores": inp["scores"].grad, + "grad_sparse": inputs["sparse_input"].grad, + "grad_scores": inputs["scores"].grad, } -def _run_combine_fwd_fp32(inp: dict) -> dict: - sparse_fp32 = inp["sparse_input"].float().detach().requires_grad_(True) - scores_fp32 = inp["scores"].float().detach().requires_grad_(True) - return {"output": _combine_pytorch(sparse_fp32, scores_fp32, inp["sparse_map"])} +def _run_combine_fwd_fp32(inputs: dict) -> dict: + sparse_fp32 = inputs["sparse_input"].float().detach().requires_grad_(True) + scores_fp32 = inputs["scores"].float().detach().requires_grad_(True) + return {"output": _combine_pytorch(sparse_fp32, scores_fp32, inputs["sparse_map"])} -def _run_combine_fwd_bwd_fp32(inp: dict) -> dict: - sparse_fp32 = inp["sparse_input"].float().detach().requires_grad_(True) - scores_fp32 = inp["scores"].float().detach().requires_grad_(True) - output = _combine_pytorch(sparse_fp32, scores_fp32, inp["sparse_map"]) - output.backward(inp["backward_grad"].float()) +def _run_combine_fwd_bwd_fp32(inputs: dict) -> dict: + sparse_fp32 = inputs["sparse_input"].float().detach().requires_grad_(True) + scores_fp32 = inputs["scores"].float().detach().requires_grad_(True) + output = _combine_pytorch(sparse_fp32, scores_fp32, inputs["sparse_map"]) + output.backward(inputs["backward_grad"].float()) return { "output": output.detach(), "grad_sparse": sparse_fp32.grad, @@ -218,23 +220,23 @@ def _run_combine_fwd_bwd_fp32(inp: dict) -> dict: } -def 
_run_combine_fwd_triton(inp: dict) -> dict: - return {"output": copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"])} +def _run_combine_fwd_triton(inputs: dict) -> dict: + return {"output": copy_sparse_to_dense_autograd(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"])} -def _run_combine_fwd_bwd_triton(inp: dict) -> dict: - output = copy_sparse_to_dense_autograd(inp["sparse_input"], inp["scores"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) +def _run_combine_fwd_bwd_triton(inputs: dict) -> dict: + output = copy_sparse_to_dense_autograd(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) return { "output": output.detach(), - "grad_sparse": inp["sparse_input"].grad, - "grad_scores": inp["scores"].grad, + "grad_sparse": inputs["sparse_input"].grad, + "grad_scores": inputs["scores"].grad, } -def _combine_postprocess(out: dict[str, torch.Tensor], inp: dict) -> dict[str, torch.Tensor]: +def _combine_postprocess(out: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: if "grad_sparse" in out: - out["grad_sparse"].masked_fill_(inp["phantom_mask"], 0) + out["grad_sparse"].masked_fill_(inputs["phantom_mask"], 0) return out @@ -248,18 +250,18 @@ def _combine_variants() -> list[Variant]: ), Variant( name="pytorch_eager", - fwd=lambda inp: _run_combine_fwd(inp, _combine_pytorch), - fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_pytorch), + fwd=lambda inputs: _run_combine_fwd(inputs, _combine_pytorch), + fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_pytorch), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_combine_fwd(inp, _combine_compiled_default), - fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_compiled_default), + fwd=lambda inputs: _run_combine_fwd(inputs, _combine_compiled_default), + fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_compiled_default), ), Variant( name="pytorch_compiled_max", - fwd=lambda inp: _run_combine_fwd(inp, _combine_compiled_max), - fwd_bwd=lambda inp: _run_combine_fwd_bwd(inp, _combine_compiled_max), + fwd=lambda inputs: _run_combine_fwd(inputs, _combine_compiled_max), + fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_compiled_max), ), ] if TritonConfig.enabled(): @@ -277,32 +279,30 @@ def _combine_variants() -> list[Variant]: # --------------------------------------------------------------------------- cases / bytes -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() def _dispatch_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: - elem = _bytes_per_elem(dtype) + element_size = _bytes_per_element(dtype) # fwd: read dense (tokens×h) + write sparse (top_k×tokens×h) # bwd: read sparse grad + write dense grad → same traffic reversed - return 2 * (1 + top_k) * tokens * hidden * elem + return 2 * (1 + top_k) * tokens * hidden * element_size def _combine_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: - elem = _bytes_per_elem(dtype) + element_size = _bytes_per_element(dtype) sparse_rows = top_k * tokens # fwd: read sparse (sparse×h) + read scores (tokens×top_k) + write dense (tokens×h) # bwd: read dense grad + read scores + write sparse grad + write score grad - return 2 * (sparse_rows + tokens) * hidden * elem + 4 * tokens * top_k * elem + return 2 * (sparse_rows + tokens) * hidden * element_size + 4 * tokens * top_k * element_size def 
_dispatch_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("dispatch", (tokens, top_k, num_experts, hidden), dtype), - make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, d=dtype: _make_dispatch_inputs( - t, k, n, h, d - ), + make_inputs=partial(_make_dispatch_inputs, tokens, top_k, num_experts, hidden, dtype), expected_bytes=_dispatch_bytes(tokens, top_k, hidden, dtype), expected_flops=0, compute_dtype=dtype, @@ -316,9 +316,7 @@ def _combine_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("combine", (tokens, top_k, num_experts, hidden), dtype), - make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, d=dtype: _make_combine_inputs( - t, k, n, h, d - ), + make_inputs=partial(_make_combine_inputs, tokens, top_k, num_experts, hidden, dtype), expected_bytes=_combine_bytes(tokens, top_k, hidden, dtype), expected_flops=0, compute_dtype=dtype, diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 01e0c4997..42d8e0fa5 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -23,6 +23,8 @@ Shapes: (tokens, top_k, num_experts, hidden, ffn_per_expert) matching MoE FFN configs. """ +from functools import partial + import torch from fast_llm.functional.config import TritonConfig @@ -132,38 +134,38 @@ def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: Sparse _output_sparse_compiled = torch.compile(_output_sparse_loop, mode="default", dynamic=False) -def _run_output_sparse_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["lhs"], inp["rhs"], inp["sparse_map"])} +def _run_output_sparse_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_output_sparse_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def _run_output_sparse_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} -def _run_output_sparse_fwd_fp32(inp: dict) -> dict: - lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) - return {"output": _output_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"])} +def _run_output_sparse_fwd_fp32(inputs: dict) -> dict: + lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) + return {"output": _output_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"])} -def _run_output_sparse_fwd_bwd_fp32(inp: dict) -> dict: - lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) - output = _output_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"]) - output.backward(inp["backward_grad"].float()) +def _run_output_sparse_fwd_bwd_fp32(inputs: dict) -> dict: + lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) + output = _output_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"]) + output.backward(inputs["backward_grad"].float()) return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} 
-def _run_output_sparse_fwd_triton(inp: dict) -> dict: - return {"output": OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} +def _run_output_sparse_fwd_triton(inputs: dict) -> dict: + return {"output": OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_output_sparse_fwd_bwd_triton(inp: dict) -> dict: - output = OutputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def _run_output_sparse_fwd_bwd_triton(inputs: dict) -> dict: + output = OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} def _output_sparse_variants() -> list[Variant]: @@ -176,13 +178,13 @@ def _output_sparse_variants() -> list[Variant]: ), Variant( name="pytorch_loop", - fwd=lambda inp: _run_output_sparse_fwd(inp, _output_sparse_loop), - fwd_bwd=lambda inp: _run_output_sparse_fwd_bwd(inp, _output_sparse_loop), + fwd=lambda inputs: _run_output_sparse_fwd(inputs, _output_sparse_loop), + fwd_bwd=lambda inputs: _run_output_sparse_fwd_bwd(inputs, _output_sparse_loop), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_output_sparse_fwd(inp, _output_sparse_compiled), - fwd_bwd=lambda inp: _run_output_sparse_fwd_bwd(inp, _output_sparse_compiled), + fwd=lambda inputs: _run_output_sparse_fwd(inputs, _output_sparse_compiled), + fwd_bwd=lambda inputs: _run_output_sparse_fwd_bwd(inputs, _output_sparse_compiled), ), ] if TritonConfig.enabled(): @@ -214,38 +216,38 @@ def _input_inner_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: S _input_inner_sparse_compiled = torch.compile(_input_inner_sparse_loop, mode="default", dynamic=False) -def _run_input_inner_sparse_fwd(inp: dict, fn) -> dict: - return {"output": fn(inp["lhs"], inp["rhs"], inp["sparse_map"])} +def _run_input_inner_sparse_fwd(inputs: dict, fn) -> dict: + return {"output": fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_input_inner_sparse_fwd_bwd(inp: dict, fn) -> dict: - output = fn(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def _run_input_inner_sparse_fwd_bwd(inputs: dict, fn) -> dict: + output = fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} -def _run_input_inner_sparse_fwd_fp32(inp: dict) -> dict: - lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) - return {"output": _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"])} +def _run_input_inner_sparse_fwd_fp32(inputs: dict) -> dict: + lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) + return {"output": _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"])} -def _run_input_inner_sparse_fwd_bwd_fp32(inp: dict) -> dict: - lhs_fp32 = inp["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inp["rhs"].float().detach().requires_grad_(True) - output = _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inp["sparse_map"]) - output.backward(inp["backward_grad"].float()) +def 
_run_input_inner_sparse_fwd_bwd_fp32(inputs: dict) -> dict: + lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) + rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) + output = _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"]) + output.backward(inputs["backward_grad"].float()) return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} -def _run_input_inner_sparse_fwd_triton(inp: dict) -> dict: - return {"output": InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"])} +def _run_input_inner_sparse_fwd_triton(inputs: dict) -> dict: + return {"output": InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_input_inner_sparse_fwd_bwd_triton(inp: dict) -> dict: - output = InputSparseLinear.apply(inp["lhs"], inp["rhs"], inp["sparse_map"]) - output.backward(inp["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inp["lhs"].grad, "grad_rhs": inp["rhs"].grad} +def _run_input_inner_sparse_fwd_bwd_triton(inputs: dict) -> dict: + output = InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} def _input_inner_sparse_variants() -> list[Variant]: @@ -258,13 +260,13 @@ def _input_inner_sparse_variants() -> list[Variant]: ), Variant( name="pytorch_loop", - fwd=lambda inp: _run_input_inner_sparse_fwd(inp, _input_inner_sparse_loop), - fwd_bwd=lambda inp: _run_input_inner_sparse_fwd_bwd(inp, _input_inner_sparse_loop), + fwd=lambda inputs: _run_input_inner_sparse_fwd(inputs, _input_inner_sparse_loop), + fwd_bwd=lambda inputs: _run_input_inner_sparse_fwd_bwd(inputs, _input_inner_sparse_loop), ), Variant( name="pytorch_compiled", - fwd=lambda inp: _run_input_inner_sparse_fwd(inp, _input_inner_sparse_compiled), - fwd_bwd=lambda inp: _run_input_inner_sparse_fwd_bwd(inp, _input_inner_sparse_compiled), + fwd=lambda inputs: _run_input_inner_sparse_fwd(inputs, _input_inner_sparse_compiled), + fwd_bwd=lambda inputs: _run_input_inner_sparse_fwd_bwd(inputs, _input_inner_sparse_compiled), ), ] if TritonConfig.enabled(): @@ -281,20 +283,20 @@ def _input_inner_sparse_variants() -> list[Variant]: # --------------------------------------------------------------------------- cases / bytes / flops -def _bytes_per_elem(dtype: torch.dtype) -> int: +def _bytes_per_element(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() def _sparse_linear_bytes( sparse_tokens: int, hidden: int, ffn_per_expert: int, num_experts: int, dtype: torch.dtype ) -> int: - elem = _bytes_per_elem(dtype) + element_size = _bytes_per_element(dtype) # fwd: read lhs + read rhs_full + write output # bwd: read grad_output + read rhs_full + write grad_lhs + read lhs + read grad_output + write grad_rhs # Simplification: 3× lhs traffic + 3× rhs traffic + 2× output traffic - lhs_bytes = sparse_tokens * hidden * elem - rhs_bytes = hidden * ffn_per_expert * num_experts * elem - out_bytes = sparse_tokens * ffn_per_expert * elem + lhs_bytes = sparse_tokens * hidden * element_size + rhs_bytes = hidden * ffn_per_expert * num_experts * element_size + out_bytes = sparse_tokens * ffn_per_expert * element_size return 3 * lhs_bytes + 3 * rhs_bytes + 2 * out_bytes @@ -307,9 +309,7 @@ def _output_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("output_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), 
dtype), - make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, f=ffn_per_expert, d=dtype: ( - _make_output_sparse_inputs(t, k, n, h, f, d) - ), + make_inputs=partial(_make_output_sparse_inputs, tokens, top_k, num_experts, hidden, ffn_per_expert, dtype), expected_bytes=_sparse_linear_bytes(tokens * top_k, hidden, ffn_per_expert, num_experts, dtype), expected_flops=_sparse_linear_flops(tokens * top_k, hidden, ffn_per_expert), compute_dtype=dtype, @@ -323,8 +323,8 @@ def _input_inner_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: return [ Case( name=case_name("input_inner_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), - make_inputs=lambda t=tokens, k=top_k, n=num_experts, h=hidden, f=ffn_per_expert, d=dtype: ( - _make_input_inner_sparse_inputs(t, k, n, h, f, d) + make_inputs=partial( + _make_input_inner_sparse_inputs, tokens, top_k, num_experts, hidden, ffn_per_expert, dtype ), expected_bytes=_sparse_linear_bytes(tokens * top_k, ffn_per_expert, hidden, num_experts, dtype), expected_flops=_sparse_linear_flops(tokens * top_k, ffn_per_expert, hidden), diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index e166dd04d..e9547d968 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -104,7 +104,7 @@ class TimingStats: min_ms: float max_ms: float std_ms: float - n_reps: int + num_reps: int @dataclasses.dataclass @@ -175,8 +175,8 @@ def bench_fn( one_rep_ms = warmup_start.elapsed_time(warmup_end) # Additional warmup to stabilize (covers autotune misses on first call) - n_warmup = max(1, int(warmup_ms / max(one_rep_ms, 0.01))) - for _ in range(n_warmup): + num_warmup = max(1, int(warmup_ms / max(one_rep_ms, 0.01))) + for _ in range(num_warmup): fn() torch.cuda.synchronize() @@ -189,11 +189,11 @@ def bench_fn( torch.cuda.synchronize() one_rep_ms = max(post_start.elapsed_time(post_end), 0.001) - n_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms))) + num_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms))) - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_reps)] - for i in range(n_reps): + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] + for i in range(num_reps): if reset is not None: reset() flush() @@ -202,14 +202,14 @@ def bench_fn( end_events[i].record() torch.cuda.synchronize() - times = [start_events[i].elapsed_time(end_events[i]) for i in range(n_reps)] + times = [start_events[i].elapsed_time(end_events[i]) for i in range(num_reps)] return TimingStats( median_ms=statistics.median(times), mean_ms=statistics.fmean(times), min_ms=min(times), max_ms=max(times), std_ms=statistics.pstdev(times) if len(times) > 1 else 0.0, - n_reps=n_reps, + num_reps=num_reps, ) @@ -590,19 +590,19 @@ def _time_for_throughput(r: VariantResult) -> float | None: columns.pop() widths = [max(len(header), *(len(v) for v in values)) for header, values in columns] - sep = " " + separator = " " # First column (case name + variant names) is text — left-justify. All other # columns are numeric — right-justify so decimal points line up across rows. 
def _justify(text: str, width: int, column_index: int) -> str: return text.ljust(width) if column_index == 0 else text.rjust(width) - header_line = sep.join(_justify(h, w, i) for i, ((h, _), w) in enumerate(zip(columns, widths))) - divider = sep.join("-" * w for w in widths) + header_line = separator.join(_justify(h, w, i) for i, ((h, _), w) in enumerate(zip(columns, widths))) + divider = separator.join("-" * w for w in widths) body_lines = [] for row in range(len(results)): body_lines.append( - sep.join(_justify(values[row], w, i) for i, ((_, values), w) in enumerate(zip(columns, widths))) + separator.join(_justify(values[row], w, i) for i, ((_, values), w) in enumerate(zip(columns, widths))) ) return "\n".join([header_line, divider, *body_lines]) From 031980cc6ba2508bd8d5e27e42c13b778a828484 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 21:24:48 -0400 Subject: [PATCH 22/41] Convert inspect_rotary_compile.py to pathlib Replace os.path / os.makedirs / os.walk with pathlib.Path per project convention. Co-Authored-By: Claude Sonnet 4.6 --- tools/inspect_rotary_compile.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tools/inspect_rotary_compile.py b/tools/inspect_rotary_compile.py index 0493e6cbf..eac1624a2 100644 --- a/tools/inspect_rotary_compile.py +++ b/tools/inspect_rotary_compile.py @@ -9,14 +9,15 @@ """ import os +from pathlib import Path import torch import torch._inductor.config as inductor_config # Route torch.compile output to a known directory. -_OUT = "/tmp/torchinductor_rotary_inspect" -os.makedirs(_OUT, exist_ok=True) -os.environ["TORCHINDUCTOR_CACHE_DIR"] = _OUT +_OUT = Path("/tmp/torchinductor_rotary_inspect") +_OUT.mkdir(parents=True, exist_ok=True) +os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(_OUT) inductor_config.debug = True # writes generated Triton .py files alongside the cache @@ -49,15 +50,11 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens print(f"\nInductor cache / debug output dir: {_OUT}") # Find and print the generated Triton kernel files. -for root, dirs, files in os.walk(_OUT): - for fname in sorted(files): - if fname.endswith(".py"): - path = os.path.join(root, fname) - print(f"\n{'='*80}") - print(f"FILE: {path}") - print("=" * 80) - with open(path) as f: - lines = f.readlines() - print("".join(lines[:300])) - if len(lines) > 300: - print(f"... ({len(lines) - 300} more lines)") +for path in sorted(_OUT.rglob("*.py")): + print(f"\n{'='*80}") + print(f"FILE: {path}") + print("=" * 80) + lines = path.read_text().splitlines(keepends=True) + print("".join(lines[:300])) + if len(lines) > 300: + print(f"... 
({len(lines) - 300} more lines)") From 71f1184f66160c96700f821590973f23d6c57ef3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 21:59:26 -0400 Subject: [PATCH 23/41] Add sparse linear autograd tests and benchmark smoke test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test_sparse_matmul.py: - Extend _SPARSE_TEST_DATAS with single-expert and fully-packed (no-padding) cases; all four existing kernel tests now run against these too - Add _output_sparse_linear_ref / _input_sparse_linear_ref Python-loop references (PyTorch-differentiable) - Add test_output_sparse_linear_autograd / test_input_sparse_linear_autograd comparing fwd output and both gradients against the loop reference — would have caught the swapped lhs/grad_out bug in the autograd backward tests/tools/test_benchmark_smoke.py: - test_run_benchmark_wiring: exercises Case/Variant/run_benchmark end-to-end with a trivial relu kernel; always runs without a GPU - test_bench_pointwise_smoke: monkeypatches _SIZES_NUMEL=[1024] and calls the real bench_pointwise pipeline, asserting fp32_reference and pytorch_eager variants succeed Co-Authored-By: Claude Sonnet 4.6 --- tests/functional/test_sparse_matmul.py | 89 +++++++++++++++++++++++++- tests/tools/test_benchmark_smoke.py | 51 +++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 tests/tools/test_benchmark_smoke.py diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py index 0ebf9c5a5..f48a216e9 100644 --- a/tests/functional/test_sparse_matmul.py +++ b/tests/functional/test_sparse_matmul.py @@ -6,6 +6,8 @@ from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.functional.triton.sparse_linear import ( + InputSparseLinear, + OutputSparseLinear, dense_matmul, input_inner_sparse_matmul, input_row_sparse_matmul, @@ -74,7 +76,21 @@ def normal(self, dim_0: int, dim_1: int, device: torch.device) -> torch.Tensor: dense_dim=256, sparse_dim=512, expert_ends=(128, 256, 256, 384), - tokens_per_expert=(52, 125, 0, 97), + tokens_per_expert=(52, 125, 0, 97), # expert 2 has zero real tokens + ), + # Single expert — the simplest non-trivial case; also exercises the no-padding path. + _SparseTestData( + dense_dim=512, + sparse_dim=256, + expert_ends=(256,), + tokens_per_expert=(200,), + ), + # Four experts, fully packed (no padding rows) — exercises the pad_begin == expert_end path. 
+ _SparseTestData( + dense_dim=384, + sparse_dim=128, + expert_ends=(64, 128, 192, 256), + tokens_per_expert=(64, 64, 64, 64), ), ) @@ -151,3 +167,74 @@ def test_input_row_sparse_matmul(sparse_test_data, testing_device): ) Assert.rms_close(output, output_ref, 1e-3) + + +# --------------------------------------------------------------------------- autograd wrappers + + +def _output_sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData) -> torch.Tensor: + """OutputSparseLinear forward via Python loop (PyTorch-differentiable).""" + ffn = rhs.shape[1] // data.num_experts + out = lhs.new_zeros(data.token_dim, ffn) + for i in range(data.num_experts): + begin, end = data.expert_begins[i], data.expert_pad_begins[i] + if end > begin: + out[begin:end] = lhs[begin:end] @ rhs[:, i * ffn : (i + 1) * ffn] + return out + + +def _input_sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData) -> torch.Tensor: + """InputSparseLinear forward via Python loop (PyTorch-differentiable).""" + ffn = rhs.shape[0] // data.num_experts + out = lhs.new_zeros(data.token_dim, rhs.shape[1]) + for i in range(data.num_experts): + begin, end = data.expert_begins[i], data.expert_pad_begins[i] + if end > begin: + out[begin:end] = lhs[begin:end] @ rhs[i * ffn : (i + 1) * ffn] + return out + + +@requires_triton +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_output_sparse_linear_autograd(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded, testing_device) + grad_output = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + + lhs_ref = lhs.detach().requires_grad_(True) + rhs_ref = rhs.detach().requires_grad_(True) + out_ref = _output_sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data) + out_ref.backward(grad_output) + + lhs_t = lhs.detach().requires_grad_(True) + rhs_t = rhs.detach().requires_grad_(True) + out_t = OutputSparseLinear.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device)) + out_t.backward(grad_output) + + Assert.rms_close(out_t, out_ref, 1e-3) + Assert.rms_close(lhs_t.grad, lhs_ref.grad, 1e-3) + Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3) + + +@requires_triton +@pytest.mark.slow +@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) +def test_input_sparse_linear_autograd(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim, testing_device) + grad_output = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + + lhs_ref = lhs.detach().requires_grad_(True) + rhs_ref = rhs.detach().requires_grad_(True) + out_ref = _input_sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data) + out_ref.backward(grad_output) + + lhs_t = lhs.detach().requires_grad_(True) + rhs_t = rhs.detach().requires_grad_(True) + out_t = InputSparseLinear.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device)) + out_t.backward(grad_output) + + Assert.rms_close(out_t, out_ref, 1e-3) + Assert.rms_close(lhs_t.grad, lhs_ref.grad, 1e-3) + Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3) diff --git a/tests/tools/test_benchmark_smoke.py b/tests/tools/test_benchmark_smoke.py 
new file mode 100644 index 000000000..5d14280b9 --- /dev/null +++ b/tests/tools/test_benchmark_smoke.py @@ -0,0 +1,51 @@ +""" +Smoke tests for the benchmark runner and bench_pointwise wiring. + +Both tests run without a GPU: bench_fn falls back to wall-clock timing (one +warmup call + one timed call) when CUDA is unavailable, and device() returns +"cpu". The compiled variants may fail on CPU — that is expected and does not +cause the test to fail; the runner records the error per-variant rather than +raising. +""" + +import torch + +import tools.benchmark.bench_pointwise as bench_pointwise +from tools.benchmark.runner import Case, Variant, run_benchmark + + +def test_run_benchmark_wiring(): + """Core runner machinery (Case, Variant, correctness comparison, table printing) works end-to-end.""" + + def _relu_fp32(inputs: dict) -> dict: + return {"out": torch.relu(inputs["x"].float())} + + def _relu(inputs: dict) -> dict: + return {"out": torch.relu(inputs["x"])} + + cases = [Case(name="relu_256", make_inputs=lambda: {"x": torch.randn(256)})] + variants = [ + Variant(name="fp32_reference", fwd=_relu_fp32, is_reference=True), + Variant(name="eager", fwd=_relu), + ] + results = run_benchmark("smoke: relu", cases, variants) + assert len(results) == 1 + _case, variant_results = results[0] + assert all(r.error is None for r in variant_results), [r.error for r in variant_results] + + +def test_bench_pointwise_smoke(monkeypatch): + """bench_pointwise case/variant wiring is intact end-to-end with tiny inputs.""" + monkeypatch.setattr(bench_pointwise, "_SIZES_NUMEL", [1024]) + + for make_cases, variants, label in [ + (bench_pointwise._copy_cases, bench_pointwise._COPY_VARIANTS, "copy"), + (bench_pointwise._fill_cases, bench_pointwise._FILL_VARIANTS, "fill"), + (bench_pointwise._add_cases, bench_pointwise._ADD_VARIANTS, "add"), + ]: + results = run_benchmark(f"smoke: {label}", make_cases((torch.float32,)), variants) + assert len(results) == 1, f"{label}: expected 1 case, got {len(results)}" + _case, variant_results = results[0] + for r in variant_results: + if r.variant_name in ("fp32_reference", "pytorch_eager"): + assert r.error is None, f"{label}/{r.variant_name}: {r.error}" From a6c6e51adfc666b77527ab4757f06ec79398325f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 22:02:57 -0400 Subject: [PATCH 24/41] Remove benchmark smoke test The runner is a dev tool; its correctness is covered by the sparse matmul autograd tests that exercise the same kernel path in CI. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_benchmark_smoke.py | 51 ----------------------------- 1 file changed, 51 deletions(-) delete mode 100644 tests/tools/test_benchmark_smoke.py diff --git a/tests/tools/test_benchmark_smoke.py b/tests/tools/test_benchmark_smoke.py deleted file mode 100644 index 5d14280b9..000000000 --- a/tests/tools/test_benchmark_smoke.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Smoke tests for the benchmark runner and bench_pointwise wiring. - -Both tests run without a GPU: bench_fn falls back to wall-clock timing (one -warmup call + one timed call) when CUDA is unavailable, and device() returns -"cpu". The compiled variants may fail on CPU — that is expected and does not -cause the test to fail; the runner records the error per-variant rather than -raising. 
-""" - -import torch - -import tools.benchmark.bench_pointwise as bench_pointwise -from tools.benchmark.runner import Case, Variant, run_benchmark - - -def test_run_benchmark_wiring(): - """Core runner machinery (Case, Variant, correctness comparison, table printing) works end-to-end.""" - - def _relu_fp32(inputs: dict) -> dict: - return {"out": torch.relu(inputs["x"].float())} - - def _relu(inputs: dict) -> dict: - return {"out": torch.relu(inputs["x"])} - - cases = [Case(name="relu_256", make_inputs=lambda: {"x": torch.randn(256)})] - variants = [ - Variant(name="fp32_reference", fwd=_relu_fp32, is_reference=True), - Variant(name="eager", fwd=_relu), - ] - results = run_benchmark("smoke: relu", cases, variants) - assert len(results) == 1 - _case, variant_results = results[0] - assert all(r.error is None for r in variant_results), [r.error for r in variant_results] - - -def test_bench_pointwise_smoke(monkeypatch): - """bench_pointwise case/variant wiring is intact end-to-end with tiny inputs.""" - monkeypatch.setattr(bench_pointwise, "_SIZES_NUMEL", [1024]) - - for make_cases, variants, label in [ - (bench_pointwise._copy_cases, bench_pointwise._COPY_VARIANTS, "copy"), - (bench_pointwise._fill_cases, bench_pointwise._FILL_VARIANTS, "fill"), - (bench_pointwise._add_cases, bench_pointwise._ADD_VARIANTS, "add"), - ]: - results = run_benchmark(f"smoke: {label}", make_cases((torch.float32,)), variants) - assert len(results) == 1, f"{label}: expected 1 case, got {len(results)}" - _case, variant_results = results[0] - for r in variant_results: - if r.variant_name in ("fp32_reference", "pytorch_eager"): - assert r.error is None, f"{label}/{r.variant_name}: {r.error}" From e4eaab7c6352bb89012a0aeaccc184a557a16892 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 22:06:01 -0400 Subject: [PATCH 25/41] Fix two nits from rename pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - bench_mlp_activation.py: FLOPs/element_size → FLOPs/element in comment (rename script incorrectly replaced elem in a comment where it meant element, not element_size) - runner.py: Variant.reset_inputs return type None → Any (copy_() returns the destination tensor; runner discards it, so functionally fine) Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/bench_mlp_activation.py | 2 +- tools/benchmark/runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 8d20423a3..b6ceceb13 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -150,7 +150,7 @@ def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: - # gated silu: fwd ≈ 6 FLOPs/element_size, bwd ≈ 8 FLOPs/element_size, total ≈ 14 per output element. + # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element, total ≈ 14 per output element. return 14 * tokens * ffn_dim diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index e9547d968..35daacfc8 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -71,7 +71,7 @@ class Variant: # Called between timing reps (outside the timed region) to restore any # input tensors the variant mutates in-place. Use this instead of cloning # inside the timed callable so the mutation cost is not measured. 
- reset_inputs: Callable[[Inputs], None] | None = None + reset_inputs: Callable[[Inputs], Any] | None = None @dataclasses.dataclass From 2fbe934e05aec4188bf1533ef32ddf159fc5e601 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 22:07:05 -0400 Subject: [PATCH 26/41] Fix bench_fn: call reset() during warmup and calibration phases Warmup calls and the post-warmup one_rep_ms calibration were not calling reset(), so variants with reset_inputs (in-place rotary, PyTorch fwd_bwd entropy/grpo) ran their warmup reps in a dirty input state. The timed reps were already correct. This biased the num_reps estimate slightly low. Co-Authored-By: Claude Sonnet 4.6 --- tools/benchmark/runner.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index 35daacfc8..f9f101d21 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -156,7 +156,11 @@ def bench_fn( """ if not torch.cuda.is_available(): # CPU / Triton interpret: single timed run with wall clock. + if reset is not None: + reset() fn() # warmup + if reset is not None: + reset() start = time.perf_counter() fn() elapsed_ms = (time.perf_counter() - start) * 1000 @@ -168,6 +172,8 @@ def bench_fn( torch.cuda.synchronize() warmup_start = torch.cuda.Event(enable_timing=True) warmup_end = torch.cuda.Event(enable_timing=True) + if reset is not None: + reset() warmup_start.record() fn() warmup_end.record() @@ -177,12 +183,16 @@ def bench_fn( # Additional warmup to stabilize (covers autotune misses on first call) num_warmup = max(1, int(warmup_ms / max(one_rep_ms, 0.01))) for _ in range(num_warmup): + if reset is not None: + reset() fn() torch.cuda.synchronize() # Re-estimate after warmup (autotune usually settles to a faster config). post_start = torch.cuda.Event(enable_timing=True) post_end = torch.cuda.Event(enable_timing=True) + if reset is not None: + reset() post_start.record() fn() post_end.record() From 8441615efff54195b618371498cb9ce6285eecf1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 23:28:49 -0400 Subject: [PATCH 27/41] Add shapes parameter to bench run() functions and smoke tests All eight benchmark modules now accept a `shapes` parameter on their `run()` entry point (and the internal `_*_cases()` helpers), so callers can supply a custom list of input sizes without monkey-patching module globals. `tests/tools/test_triton_benchmark.py` uses this to run every module with one tiny shape and float32 dtype, keeping the full runner code path exercised without the long compile/autotune time of production shapes. The `_disable_dynamo` fixture suppresses torch.compile cold-start (~20 s per variant on CPU). The sparse_linear shape uses hidden=ffn_per_expert=256 to satisfy the Triton kernel's block-size divisibility assertions. 
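For reference, a call such as the following sweeps caller-chosen shapes without
touching module globals (the shape values below are arbitrary examples; each
module defines its own tuple layout, e.g. (rows, cols) for normalization):

    import torch
    from tools.benchmark import bench_normalization

    # Benchmark LayerNorm/RMSNorm on two custom (rows, cols) shapes in bfloat16.
    bench_normalization.run(dtypes=(torch.bfloat16,), shapes=[(4096, 4096), (8192, 4096)])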
Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_triton_benchmark.py | 49 +++++++++++++++++++++++++ tools/benchmark/bench_entropy_loss.py | 28 +++++++++----- tools/benchmark/bench_grpo_loss.py | 13 +++++-- tools/benchmark/bench_mlp_activation.py | 16 ++++++-- tools/benchmark/bench_normalization.py | 22 +++++++---- tools/benchmark/bench_pointwise.py | 27 +++++++++----- tools/benchmark/bench_rotary.py | 13 +++++-- tools/benchmark/bench_sparse_copy.py | 24 ++++++++---- tools/benchmark/bench_sparse_linear.py | 24 ++++++++---- 9 files changed, 164 insertions(+), 52 deletions(-) create mode 100644 tests/tools/test_triton_benchmark.py diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py new file mode 100644 index 000000000..4ec1f9b33 --- /dev/null +++ b/tests/tools/test_triton_benchmark.py @@ -0,0 +1,49 @@ +""" +Smoke tests for all benchmark modules. + +Each test runs a single benchmark module with one tiny shape and float32 dtype +so the full runner code path is exercised quickly on CPU. torch.compile is +disabled via a fixture to avoid the JIT cold-start (~20 s on CPU per variant). +""" + +import pytest +import torch + +from tools.benchmark import ( + bench_entropy_loss, + bench_grpo_loss, + bench_mlp_activation, + bench_normalization, + bench_pointwise, + bench_rotary, + bench_sparse_copy, + bench_sparse_linear, +) + +_DTYPES = (torch.float32,) + +_PARAMS = [ + pytest.param(bench_entropy_loss, {"shapes": [(64, 256)]}, id="entropy_loss"), + pytest.param(bench_grpo_loss, {"shapes": [(64, 256)]}, id="grpo_loss"), + pytest.param(bench_mlp_activation, {"shapes": [(64, 128)]}, id="mlp_activation"), + pytest.param(bench_normalization, {"shapes": [(64, 128)]}, id="normalization"), + pytest.param(bench_pointwise, {"shapes": [1024]}, id="pointwise"), + pytest.param(bench_rotary, {"shapes": [(64, 4, 64)]}, id="rotary"), + pytest.param(bench_sparse_copy, {"shapes": [(64, 2, 4, 128)]}, id="sparse_copy"), + pytest.param(bench_sparse_linear, {"shapes": [(64, 2, 4, 256, 256)]}, id="sparse_linear"), +] + + +@pytest.fixture(autouse=True) +def _disable_dynamo(): + import torch._dynamo + + orig = torch._dynamo.config.disable + torch._dynamo.config.disable = True + yield + torch._dynamo.config.disable = orig + + +@pytest.mark.parametrize("module,kwargs", _PARAMS) +def test_triton_benchmark(module, kwargs): + module.run(dtypes=_DTYPES, **kwargs) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index b5a479018..da686b8c9 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -423,7 +423,10 @@ def _entropy_loss_flops(tokens: int, vocab: int) -> int: return 4 * tokens * vocab -def _label_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _label_cases( + kernel_name: str, dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name(kernel_name, (tokens, vocab), dtype), @@ -433,11 +436,14 @@ def _label_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case compute_dtype=dtype, ) for dtype in dtypes - for tokens, vocab in _SHAPES + for tokens, vocab in shapes ] -def _dist_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _dist_cases( + kernel_name: str, dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( 
name=case_name(kernel_name, (tokens, vocab), dtype), @@ -447,34 +453,38 @@ def _dist_cases(kernel_name: str, dtypes: tuple[torch.dtype, ...]) -> list[Case] compute_dtype=dtype, ) for dtype in dtypes - for tokens, vocab in _SHAPES + for tokens, vocab in shapes ] # --------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES run_benchmark( "entropy_loss: cross_entropy (labels)", - _label_cases("cross_entropy_labels", dtypes), + _label_cases("cross_entropy_labels", dtypes, shapes), _ce_labels_variants(), verbose=verbose, ) run_benchmark( "entropy_loss: cross_entropy (logits)", - _dist_cases("cross_entropy_logits", dtypes), + _dist_cases("cross_entropy_logits", dtypes, shapes), _ce_dist_variants(), verbose=verbose, ) run_benchmark( "entropy_loss: reverse_kl (logits)", - _dist_cases("reverse_kl_logits", dtypes), + _dist_cases("reverse_kl_logits", dtypes, shapes), _reverse_kl_variants(), verbose=verbose, ) - run_benchmark("entropy_loss: z_loss", _label_cases("z_loss", dtypes), _z_loss_variants(), verbose=verbose) + run_benchmark("entropy_loss: z_loss", _label_cases("z_loss", dtypes, shapes), _z_loss_variants(), verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 4dd5cdd5a..cc5ec46f1 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -180,7 +180,8 @@ def _grpo_flops(tokens: int, vocab: int) -> int: return 14 * tokens * vocab -def _grpo_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _grpo_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("grpo_loss", (tokens, vocab), dtype), @@ -190,13 +191,17 @@ def _grpo_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, vocab in _SHAPES + for tokens, vocab in shapes ] -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] 
| None = None, + shapes: list[tuple[int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("grpo_loss", _grpo_cases(dtypes), _grpo_variants(), verbose=verbose) + run_benchmark("grpo_loss", _grpo_cases(dtypes, shapes), _grpo_variants(), verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index b6ceceb13..df92c3fa9 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -154,7 +154,8 @@ def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: return 14 * tokens * ffn_dim -def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("mlp_activation", (tokens, ffn_dim), dtype), @@ -164,17 +165,24 @@ def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, ffn_dim in _SHAPES + for tokens, ffn_dim in shapes ] # --------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES run_benchmark( - "mlp_activation (gated silu)", _mlp_activation_cases(dtypes), _mlp_activation_variants(), verbose=verbose + "mlp_activation (gated silu)", + _mlp_activation_cases(dtypes, shapes), + _mlp_activation_variants(), + verbose=verbose, ) diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index 3a3c02044..48717202b 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -307,7 +307,8 @@ def _rms_norm_flops(rows: int, cols: int) -> int: return 15 * rows * cols -def _layer_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _layer_norm_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("layer_norm", shape, dtype), @@ -317,11 +318,12 @@ def _layer_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for shape in _SHAPES + for shape in shapes ] -def _rms_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _rms_norm_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("rms_norm", shape, dtype), @@ -331,17 +333,23 @@ def _rms_norm_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for shape in _SHAPES + for shape in shapes ] # --------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] 
| None = None, + shapes: list[tuple[int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("normalization: layer_norm", _layer_norm_cases(dtypes), _layer_norm_variants(), verbose=verbose) - run_benchmark("normalization: rms_norm", _rms_norm_cases(dtypes), _rms_norm_variants(), verbose=verbose) + run_benchmark( + "normalization: layer_norm", _layer_norm_cases(dtypes, shapes), _layer_norm_variants(), verbose=verbose + ) + run_benchmark("normalization: rms_norm", _rms_norm_cases(dtypes, shapes), _rms_norm_variants(), verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index a9a059114..2648d6e80 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -40,7 +40,8 @@ def _make_copy_inputs(numel: int, dtype: torch.dtype) -> dict: return {"input_": input_, "out": out} -def _copy_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _copy_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: + sizes = shapes if shapes is not None else _SIZES_NUMEL return [ Case( name=case_name("copy", (numel,), dtype), @@ -49,7 +50,7 @@ def _copy_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: expected_bytes=2 * numel * torch.tensor([], dtype=dtype).element_size(), ) for dtype in dtypes - for numel in _SIZES_NUMEL + for numel in sizes ] @@ -71,7 +72,8 @@ def _make_fill_inputs(numel: int, dtype: torch.dtype) -> dict: return {"input_": torch.empty(numel, dtype=dtype, device=device()), "value": 1.5} -def _fill_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _fill_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: + sizes = shapes if shapes is not None else _SIZES_NUMEL return [ Case( name=case_name("fill", (numel,), dtype), @@ -80,7 +82,7 @@ def _fill_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: expected_bytes=numel * torch.tensor([], dtype=dtype).element_size(), ) for dtype in dtypes - for numel in _SIZES_NUMEL + for numel in sizes ] @@ -106,7 +108,8 @@ def _make_add_inputs(numel: int, dtype: torch.dtype) -> dict: } -def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _add_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: + sizes = shapes if shapes is not None else _SIZES_NUMEL return [ Case( name=case_name("add", (numel,), dtype), @@ -118,7 +121,7 @@ def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for numel in _SIZES_NUMEL + for numel in sizes ] @@ -132,11 +135,15 @@ def _add_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: # --------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] 
| None = None, + shapes: list[int] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("pointwise: copy", _copy_cases(dtypes), _COPY_VARIANTS, verbose=verbose) - run_benchmark("pointwise: fill", _fill_cases(dtypes), _FILL_VARIANTS, verbose=verbose) - run_benchmark("pointwise: add", _add_cases(dtypes), _ADD_VARIANTS, verbose=verbose) + run_benchmark("pointwise: copy", _copy_cases(dtypes, shapes), _COPY_VARIANTS, verbose=verbose) + run_benchmark("pointwise: fill", _fill_cases(dtypes, shapes), _FILL_VARIANTS, verbose=verbose) + run_benchmark("pointwise: add", _add_cases(dtypes, shapes), _ADD_VARIANTS, verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 9d75e6815..5592bdc1a 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -98,7 +98,8 @@ def _rotary_variants() -> list[Variant]: return variants -def _rotary_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _rotary_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int]] | None = None) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("rotary", (tokens, num_heads, head_size), dtype), @@ -108,13 +109,17 @@ def _rotary_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, num_heads, head_size in _SHAPES + for tokens, num_heads, head_size in shapes ] -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("rotary", _rotary_cases(dtypes), _rotary_variants(), verbose=verbose) + run_benchmark("rotary", _rotary_cases(dtypes, shapes), _rotary_variants(), verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index ddd58a9d5..a2997bf7d 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -298,7 +298,10 @@ def _combine_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> return 2 * (sparse_rows + tokens) * hidden * element_size + 4 * tokens * top_k * element_size -def _dispatch_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _dispatch_cases( + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("dispatch", (tokens, top_k, num_experts, hidden), dtype), @@ -308,11 +311,14 @@ def _dispatch_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, top_k, num_experts, hidden in _SHAPES + for tokens, top_k, num_experts, hidden in shapes ] -def _combine_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _combine_cases( + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("combine", (tokens, top_k, num_experts, hidden), dtype), @@ -322,17 +328,21 @@ def _combine_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, top_k, num_experts, hidden in _SHAPES + for tokens, top_k, num_experts, hidden in shapes ] # 
--------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int, int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("sparse_copy: dispatch", _dispatch_cases(dtypes), _dispatch_variants(), verbose=verbose) - run_benchmark("sparse_copy: combine", _combine_cases(dtypes), _combine_variants(), verbose=verbose) + run_benchmark("sparse_copy: dispatch", _dispatch_cases(dtypes, shapes), _dispatch_variants(), verbose=verbose) + run_benchmark("sparse_copy: combine", _combine_cases(dtypes, shapes), _combine_variants(), verbose=verbose) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 42d8e0fa5..e799944d7 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -305,7 +305,10 @@ def _sparse_linear_flops(sparse_tokens_unpadded: int, hidden: int, ffn_per_exper return 3 * 2 * sparse_tokens_unpadded * hidden * ffn_per_expert -def _output_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _output_sparse_cases( + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("output_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), @@ -315,11 +318,14 @@ def _output_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, top_k, num_experts, hidden, ffn_per_expert in _SHAPES + for tokens, top_k, num_experts, hidden, ffn_per_expert in shapes ] -def _input_inner_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: +def _input_inner_sparse_cases( + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None +) -> list[Case]: + shapes = shapes if shapes is not None else _SHAPES return [ Case( name=case_name("input_inner_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), @@ -331,24 +337,28 @@ def _input_inner_sparse_cases(dtypes: tuple[torch.dtype, ...]) -> list[Case]: compute_dtype=dtype, ) for dtype in dtypes - for tokens, top_k, num_experts, hidden, ffn_per_expert in _SHAPES + for tokens, top_k, num_experts, hidden, ffn_per_expert in shapes ] # --------------------------------------------------------------------------- entry point -def run(verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None) -> None: +def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] 
| None = None, + shapes: list[tuple[int, int, int, int, int]] | None = None, +) -> None: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES run_benchmark( "sparse_linear: output_sparse (layer 1 / up-proj)", - _output_sparse_cases(dtypes), + _output_sparse_cases(dtypes, shapes), _output_sparse_variants(), verbose=verbose, ) run_benchmark( "sparse_linear: input_inner_sparse (layer 2 / down-proj)", - _input_inner_sparse_cases(dtypes), + _input_inner_sparse_cases(dtypes, shapes), _input_inner_sparse_variants(), verbose=verbose, ) From 80631753f59b052b22ece32d18a03c45dc032c09 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 29 Apr 2026 23:55:16 -0400 Subject: [PATCH 28/41] Refactor benchmark tests: per-kernel parametrization, single timed rep Add benchmarks() to all 8 bench modules (returning a list of (name, cases, variants) tuples) so each sub-benchmark is addressable independently. Thread min_reps through runner.py's bench_fn and run_benchmark so callers can cap rep count. Rewrite tests/tools/test_triton_benchmark.py to build one pytest parameter per kernel from benchmarks() calls (16 tests instead of 8 file-level tests) and run each with warmup_ms=0, rep_ms=0, min_reps=1, reducing the test suite from ~80s to ~45s while covering every kernel. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_triton_benchmark.py | 44 ++++++++++++++-------- tools/benchmark/bench_entropy_loss.py | 50 +++++++++++++++---------- tools/benchmark/bench_grpo_loss.py | 15 +++++++- tools/benchmark/bench_mlp_activation.py | 20 ++++++---- tools/benchmark/bench_normalization.py | 21 ++++++++--- tools/benchmark/bench_pointwise.py | 21 +++++++++-- tools/benchmark/bench_rotary.py | 15 +++++++- tools/benchmark/bench_sparse_copy.py | 19 ++++++++-- tools/benchmark/bench_sparse_linear.py | 37 +++++++++++------- tools/benchmark/runner.py | 10 +++-- 10 files changed, 177 insertions(+), 75 deletions(-) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index 4ec1f9b33..2f5f2f772 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -1,9 +1,10 @@ """ Smoke tests for all benchmark modules. -Each test runs a single benchmark module with one tiny shape and float32 dtype -so the full runner code path is exercised quickly on CPU. torch.compile is -disabled via a fixture to avoid the JIT cold-start (~20 s on CPU per variant). +One test per sub-benchmark (kernel): inputs are tiny so the runner code path is +exercised quickly on CPU. torch.compile is disabled via a fixture to avoid the +JIT cold-start (~20 s on CPU per variant). warmup_ms=0 and min_reps=1 cap the +rep count to one timed call per variant so the suite stays fast. 
""" import pytest @@ -19,19 +20,30 @@ bench_sparse_copy, bench_sparse_linear, ) +from tools.benchmark.runner import run_benchmark _DTYPES = (torch.float32,) -_PARAMS = [ - pytest.param(bench_entropy_loss, {"shapes": [(64, 256)]}, id="entropy_loss"), - pytest.param(bench_grpo_loss, {"shapes": [(64, 256)]}, id="grpo_loss"), - pytest.param(bench_mlp_activation, {"shapes": [(64, 128)]}, id="mlp_activation"), - pytest.param(bench_normalization, {"shapes": [(64, 128)]}, id="normalization"), - pytest.param(bench_pointwise, {"shapes": [1024]}, id="pointwise"), - pytest.param(bench_rotary, {"shapes": [(64, 4, 64)]}, id="rotary"), - pytest.param(bench_sparse_copy, {"shapes": [(64, 2, 4, 128)]}, id="sparse_copy"), - pytest.param(bench_sparse_linear, {"shapes": [(64, 2, 4, 256, 256)]}, id="sparse_linear"), -] + +def _build_params() -> list: + modules_and_shapes = [ + (bench_entropy_loss, {"shapes": [(64, 256)]}), + (bench_grpo_loss, {"shapes": [(64, 256)]}), + (bench_mlp_activation, {"shapes": [(64, 128)]}), + (bench_normalization, {"shapes": [(64, 128)]}), + (bench_pointwise, {"shapes": [1024]}), + (bench_rotary, {"shapes": [(64, 4, 64)]}), + (bench_sparse_copy, {"shapes": [(64, 2, 4, 128)]}), + (bench_sparse_linear, {"shapes": [(64, 2, 4, 256, 256)]}), + ] + params = [] + for module, kwargs in modules_and_shapes: + for name, cases, variants in module.benchmarks(dtypes=_DTYPES, **kwargs): + params.append(pytest.param(name, cases, variants, id=name)) + return params + + +_PARAMS = _build_params() @pytest.fixture(autouse=True) @@ -44,6 +56,6 @@ def _disable_dynamo(): torch._dynamo.config.disable = orig -@pytest.mark.parametrize("module,kwargs", _PARAMS) -def test_triton_benchmark(module, kwargs): - module.run(dtypes=_DTYPES, **kwargs) +@pytest.mark.parametrize("name,cases,variants", _PARAMS) +def test_triton_benchmark(name, cases, variants): + run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index da686b8c9..2e936bd04 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -460,31 +460,41 @@ def _dist_cases( # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [ + ( + "entropy_loss: cross_entropy (labels)", + _label_cases("cross_entropy_labels", dtypes, shapes), + _ce_labels_variants(), + ), + ( + "entropy_loss: cross_entropy (logits)", + _dist_cases("cross_entropy_logits", dtypes, shapes), + _ce_dist_variants(), + ), + ( + "entropy_loss: reverse_kl (logits)", + _dist_cases("reverse_kl_logits", dtypes, shapes), + _reverse_kl_variants(), + ), + ("entropy_loss: z_loss", _label_cases("z_loss", dtypes, shapes), _z_loss_variants()), + ] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] 
| None = None, shapes: list[tuple[int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark( - "entropy_loss: cross_entropy (labels)", - _label_cases("cross_entropy_labels", dtypes, shapes), - _ce_labels_variants(), - verbose=verbose, - ) - run_benchmark( - "entropy_loss: cross_entropy (logits)", - _dist_cases("cross_entropy_logits", dtypes, shapes), - _ce_dist_variants(), - verbose=verbose, - ) - run_benchmark( - "entropy_loss: reverse_kl (logits)", - _dist_cases("reverse_kl_logits", dtypes, shapes), - _reverse_kl_variants(), - verbose=verbose, - ) - run_benchmark("entropy_loss: z_loss", _label_cases("z_loss", dtypes, shapes), _z_loss_variants(), verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index cc5ec46f1..969a9217f 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -195,13 +195,24 @@ def _grpo_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | ] +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [("grpo_loss", _grpo_cases(dtypes, shapes), _grpo_variants())] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("grpo_loss", _grpo_cases(dtypes, shapes), _grpo_variants(), verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index df92c3fa9..ed96b6476 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -172,18 +172,24 @@ def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[in # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [("mlp_activation (gated silu)", _mlp_activation_cases(dtypes, shapes), _mlp_activation_variants())] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] 
| None = None, shapes: list[tuple[int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark( - "mlp_activation (gated silu)", - _mlp_activation_cases(dtypes, shapes), - _mlp_activation_variants(), - verbose=verbose, - ) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index 48717202b..cced806a0 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -340,16 +340,27 @@ def _rms_norm_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [ + ("normalization: layer_norm", _layer_norm_cases(dtypes, shapes), _layer_norm_variants()), + ("normalization: rms_norm", _rms_norm_cases(dtypes, shapes), _rms_norm_variants()), + ] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark( - "normalization: layer_norm", _layer_norm_cases(dtypes, shapes), _layer_norm_variants(), verbose=verbose - ) - run_benchmark("normalization: rms_norm", _rms_norm_cases(dtypes, shapes), _rms_norm_variants(), verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index 2648d6e80..e00702617 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -135,15 +135,28 @@ def _add_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[int] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [ + ("pointwise: copy", _copy_cases(dtypes, shapes), _COPY_VARIANTS), + ("pointwise: fill", _fill_cases(dtypes, shapes), _FILL_VARIANTS), + ("pointwise: add", _add_cases(dtypes, shapes), _ADD_VARIANTS), + ] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] 
| None = None, shapes: list[int] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("pointwise: copy", _copy_cases(dtypes, shapes), _COPY_VARIANTS, verbose=verbose) - run_benchmark("pointwise: fill", _fill_cases(dtypes, shapes), _FILL_VARIANTS, verbose=verbose) - run_benchmark("pointwise: add", _add_cases(dtypes, shapes), _ADD_VARIANTS, verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 5592bdc1a..fb798c234 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -113,13 +113,24 @@ def _rotary_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, ] +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [("rotary", _rotary_cases(dtypes, shapes), _rotary_variants())] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("rotary", _rotary_cases(dtypes, shapes), _rotary_variants(), verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index a2997bf7d..2d4464dd9 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -335,14 +335,27 @@ def _combine_cases( # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int, int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [ + ("sparse_copy: dispatch", _dispatch_cases(dtypes, shapes), _dispatch_variants()), + ("sparse_copy: combine", _combine_cases(dtypes, shapes), _combine_variants()), + ] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] 
| None = None, shapes: list[tuple[int, int, int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark("sparse_copy: dispatch", _dispatch_cases(dtypes, shapes), _dispatch_variants(), verbose=verbose) - run_benchmark("sparse_copy: combine", _combine_cases(dtypes, shapes), _combine_variants(), verbose=verbose) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index e799944d7..07faf16c3 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -344,24 +344,35 @@ def _input_inner_sparse_cases( # --------------------------------------------------------------------------- entry point +def benchmarks( + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list[tuple[int, int, int, int, int]] | None = None, +) -> list[tuple[str, list, list]]: + dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + return [ + ( + "sparse_linear: output_sparse (layer 1 / up-proj)", + _output_sparse_cases(dtypes, shapes), + _output_sparse_variants(), + ), + ( + "sparse_linear: input_inner_sparse (layer 2 / down-proj)", + _input_inner_sparse_cases(dtypes, shapes), + _input_inner_sparse_variants(), + ), + ] + + def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int, int, int, int]] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, ) -> None: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - run_benchmark( - "sparse_linear: output_sparse (layer 1 / up-proj)", - _output_sparse_cases(dtypes, shapes), - _output_sparse_variants(), - verbose=verbose, - ) - run_benchmark( - "sparse_linear: input_inner_sparse (layer 2 / down-proj)", - _input_inner_sparse_cases(dtypes, shapes), - _input_inner_sparse_variants(), - verbose=verbose, - ) + for name, cases, variants in benchmarks(dtypes, shapes): + run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) if __name__ == "__main__": diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index f9f101d21..e3c649cbd 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -280,6 +280,7 @@ def _run_one_variant( reference_outputs: dict[str, dict[str, torch.Tensor]] | None, warmup_ms: float, rep_ms: float, + min_reps: int = 5, ) -> VariantResult: result = VariantResult(variant_name=variant.name) try: @@ -313,7 +314,9 @@ def _fwd_once() -> Any: # Timing: reuse the same input tensors, fn closes over them. _reset_fwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None - result.fwd_timing = bench_fn(_guarded_fwd, reset=_reset_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms) + result.fwd_timing = bench_fn( + _guarded_fwd, reset=_reset_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps + ) del inputs # fwd+bwd mode @@ -353,7 +356,7 @@ def _fwd_bwd_once() -> Any: # Timing. 
_reset_fwd_bwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None result.fwd_bwd_timing = bench_fn( - _guarded_fwd_bwd, reset=_reset_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms + _guarded_fwd_bwd, reset=_reset_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps ) del inputs elif variant.fwd is not None and result.memory is None: @@ -627,6 +630,7 @@ def run_benchmark( *, warmup_ms: float = 25.0, rep_ms: float = 100.0, + min_reps: int = 5, verbose: bool = False, print_fn: Callable[[str], None] = print, ) -> list[tuple[Case, list[VariantResult]]]: @@ -658,7 +662,7 @@ def run_benchmark( has_fwd_bwd = False rms_keys_seen: list[str] = [] for variant in variants: - r = _run_one_variant(variant, case, ref_outputs, warmup_ms=warmup_ms, rep_ms=rep_ms) + r = _run_one_variant(variant, case, ref_outputs, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) results.append(r) has_fwd = has_fwd or r.fwd_timing is not None has_fwd_bwd = has_fwd_bwd or r.fwd_bwd_timing is not None From df1730f176a046e0000b03c897e0a9573d478f8c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 30 Apr 2026 00:04:55 -0400 Subject: [PATCH 29/41] Skip compiled variants in benchmark smoke tests pytorch_compiled and pytorch_compiled_max aren't needed for correctness checking and slow down the suite without covering new code paths. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_triton_benchmark.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index 2f5f2f772..fbaab7fb6 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -56,6 +56,10 @@ def _disable_dynamo(): torch._dynamo.config.disable = orig +_SKIP_VARIANTS = {"pytorch_compiled", "pytorch_compiled_max"} + + @pytest.mark.parametrize("name,cases,variants", _PARAMS) def test_triton_benchmark(name, cases, variants): + variants = [v for v in variants if v.name not in _SKIP_VARIANTS] run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1) From 6ccfca3f781a7310b7eb85c86d81984efe1b5d2b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 30 Apr 2026 01:12:30 -0400 Subject: [PATCH 30/41] Skip sparse benchmarks in Triton interpreter, revert broken histogram patch tl.histogram is broken in the Triton interpreter; skip sparse_copy and sparse_linear in interpreter mode rather than patching around multiple cascading bugs. The np.histogram monkeypatch is reverted as it is no longer needed. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_triton_benchmark.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index fbaab7fb6..a0505d3b4 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -10,6 +10,7 @@ import pytest import torch +from fast_llm.functional.triton import triton_interpret from tools.benchmark import ( bench_entropy_loss, bench_grpo_loss, @@ -24,6 +25,15 @@ _DTYPES = (torch.float32,) +# sparse_copy and sparse_linear use tl.histogram, which has unfixed bugs in the +# Triton interpreter. Skip them in interpreter mode; they're covered on GPU. 
+_INTERPRETER_SKIP = { + "sparse_copy: dispatch", + "sparse_copy: combine", + "sparse_linear: output_sparse (layer 1 / up-proj)", + "sparse_linear: input_inner_sparse (layer 2 / down-proj)", +} + def _build_params() -> list: modules_and_shapes = [ @@ -61,5 +71,7 @@ def _disable_dynamo(): @pytest.mark.parametrize("name,cases,variants", _PARAMS) def test_triton_benchmark(name, cases, variants): + if triton_interpret and name in _INTERPRETER_SKIP: + pytest.skip("tl.histogram is broken in the Triton interpreter") variants = [v for v in variants if v.name not in _SKIP_VARIANTS] run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1) From 7a668e82628b3923d77078d9a0c93b6a634f9484 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 03:45:30 -0400 Subject: [PATCH 31/41] Speed up benchmark smoke tests from 27s to ~4s Three changes: 1. Remove gc.collect() from the runner before each variant's timed run. Python's reference counting handles cleanup immediately; gc.collect() is only needed for cyclic references, which are unlikely to hold GPU tensors. Each call took ~80ms, adding ~240ms per test for no benefit. 2. Monkeypatch the test to avoid kernel compilation and CUDA sync overhead: - Replace fast_llm_triton variants with fp32 reference (no Triton JIT). - Disable TritonConfig.enabled so sparse_linear make_inputs warmup doesn't trigger kernel compilation. - Set _cudagraph_mark_step_begin=None and synchronize=no-op to eliminate C-level CUDA syncs (~400ms/call) that _guarded() inserts before every fn() invocation. Co-Authored-By: Claude Sonnet 4.6 --- tests/tools/test_triton_benchmark.py | 30 ++++++++++++++++++++++++---- tools/benchmark/runner.py | 5 ----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index a0505d3b4..4e8f4efb6 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -2,14 +2,24 @@ Smoke tests for all benchmark modules. One test per sub-benchmark (kernel): inputs are tiny so the runner code path is -exercised quickly on CPU. torch.compile is disabled via a fixture to avoid the -JIT cold-start (~20 s on CPU per variant). warmup_ms=0 and min_reps=1 cap the -rep count to one timed call per variant so the suite stays fast. +exercised quickly without requiring a full benchmark run. + +Patches applied to keep each test under ~100 ms: +- torch.compile disabled (avoids JIT cold-start). +- fast_llm_triton variants replaced with fp32 reference (no Triton compilation; + kernel correctness is covered by the main test suite). +- TritonConfig.enabled → False (prevents make_inputs warmup in sparse_linear). +- _cudagraph_mark_step_begin → None and synchronize → no-op (both cause C-level + CUDA syncs per fn() call that dominate the wall time without this). 
""" +import dataclasses + import pytest import torch +import tools.benchmark.runner as _bench_runner +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import triton_interpret from tools.benchmark import ( bench_entropy_loss, @@ -70,8 +80,20 @@ def _disable_dynamo(): @pytest.mark.parametrize("name,cases,variants", _PARAMS) -def test_triton_benchmark(name, cases, variants): +def test_triton_benchmark(name, cases, variants, monkeypatch): if triton_interpret and name in _INTERPRETER_SKIP: pytest.skip("tl.histogram is broken in the Triton interpreter") + + monkeypatch.setattr(TritonConfig, "enabled", lambda *a, **kw: False) + monkeypatch.setattr(_bench_runner, "_cudagraph_mark_step_begin", None) + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + variants = [v for v in variants if v.name not in _SKIP_VARIANTS] + # Replace fast_llm_triton fwd/fwd_bwd with the fp32 reference so no Triton + # kernels are compiled. The runner still exercises the full variant code path. + ref = next(v for v in variants if v.is_reference) + variants = [ + dataclasses.replace(v, fwd=ref.fwd, fwd_bwd=ref.fwd_bwd) if v.name == "fast_llm_triton" else v + for v in variants + ] run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1) diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index e3c649cbd..a5f823563 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -10,7 +10,6 @@ """ import dataclasses -import gc import math import statistics import time @@ -288,7 +287,6 @@ def _run_one_variant( # fwd mode if variant.fwd is not None: inputs = _seeded_inputs(case) - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -322,7 +320,6 @@ def _fwd_once() -> Any: # fwd+bwd mode if variant.fwd_bwd is not None: inputs = _seeded_inputs(case) - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -347,7 +344,6 @@ def _fwd_bwd_once() -> Any: # Memory measurement: one fresh call on fresh inputs. fresh_inputs = _seeded_inputs(case) - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() result.memory = measure_memory(_guarded(lambda: variant.fwd_bwd(fresh_inputs))) @@ -362,7 +358,6 @@ def _fwd_bwd_once() -> Any: elif variant.fwd is not None and result.memory is None: # No backward — measure fwd-mode memory. fresh_inputs = _seeded_inputs(case) - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() result.memory = measure_memory(_guarded(lambda: variant.fwd(fresh_inputs))) From d7d96d842fcba5a8f91c8059d020b3b475bc97f1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 05:47:05 -0400 Subject: [PATCH 32/41] Mask padded entries in comparison instead of zeroing kernel inputs The previous approach forced kernel and reference into bit-equivalent regions by (a) writing zeros to phantom blocks in the kernel and (b) zeroing padded rows of lhs / backward_grad in the benchmark setup. Both were extra work to mask values that downstream consumers ignore. Now the kernel writes garbage to those positions (early-returns on phantom blocks; matmuls random padded inputs in within-expert padding) and the comparison side masks them out before rel_rms. - sparse_linear.py: revert the phantom-block "write zeros" change, back to early-return. Keep the bf16-accumulation and grad_rhs arg-swap fixes from the same commit. 
- bench_sparse_linear.py: drop _zero_padded_rows, generate plain randn / ones inputs, add an output_postprocess that masks both within-expert padding and phantom rows past expert_ends[-1] on the candidate. All six (op, shape) configurations are back at bf16 noise floor on H100. - test_sparse_matmul.py: apply the same masking to the autograd test, which was silently failing on every padded data case. Collapse the two near-duplicate autograd tests into one parametrized over (autograd_class, expert_axis) and unify the two reference helpers. Adjust sparse_test_data3 expert sizes to multiples of 128 so blocks don't straddle expert boundaries (kernel TODO: "Assumes sparse_index is constant within a block"). - test_triton_benchmark.py: fold the env patches into a single monkeypatch-based autouse fixture. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/functional/triton/sparse_linear.py | 32 +++----- tests/functional/test_sparse_matmul.py | 91 ++++++++++----------- tests/tools/test_triton_benchmark.py | 23 ++---- tools/benchmark/bench_sparse_linear.py | 47 +++++++---- 4 files changed, 92 insertions(+), 101 deletions(-) diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index 601ae0fa5..14b15b319 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -231,23 +231,18 @@ def output_sparse_matmul_kernel( sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa - - # Pointers - row_range = tl_arange(0, block_size_row)[:, None] - col_range = tl_arange(0, block_size_col)[None, :] - out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col - if sparse_index == sparse_dim: - # Phantom block: row_offset is past the last expert. Write zeros so the - # output is fully defined regardless of the caller's allocation. - if not accumulate: - tl.store(out_ptr, tl.zeros((block_size_row, block_size_col), dtype=out_ptr.dtype.element_ty)) + # Phantom block past the last expert; the caller is expected to ignore these rows. 
return col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim + # Pointers + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + (col_dense_offset + col_range) * rhs_stride_col + out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col # Matrix multiplication out = tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) @@ -356,30 +351,25 @@ def input_inner_sparse_matmul_kernel( # Grid offsets row_offset = pid_row * block_size_row - col_offset = pid_col * block_size_col sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa - - # Pointers - row_range = tl_arange(0, block_size_row)[:, None] - col_range = tl_arange(0, block_size_col)[None, :] - out_ptr += (row_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col - if sparse_index == sparse_dim: - # Phantom block: row_offset is past the last expert. Write zeros so the - # output is fully defined regardless of the caller's allocation. - if not accumulate: - tl.store(out_ptr, tl.zeros((block_size_row, block_size_col), dtype=out_ptr.dtype.element_ty)) + # Phantom block past the last expert; the caller is expected to ignore these rows. return inner_dense_offset = sparse_index * inner_sparse_dim + col_offset = pid_col * block_size_col + # Pointers + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += (inner_dense_offset + inner_range[:, None]) * rhs_stride_inner + ( col_offset + col_range ) * rhs_stride_col + out_ptr += (row_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col # Matrix multiplication out = tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32) diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py index f48a216e9..7269e893f 100644 --- a/tests/functional/test_sparse_matmul.py +++ b/tests/functional/test_sparse_matmul.py @@ -86,11 +86,14 @@ def normal(self, dim_0: int, dim_1: int, device: torch.device) -> torch.Tensor: tokens_per_expert=(200,), ), # Four experts, fully packed (no padding rows) — exercises the pad_begin == expert_end path. + # Expert sizes must be multiples of the largest autotune block_size_row (128); otherwise + # blocks straddle expert boundaries and the kernel's "sparse_index constant within a block" + # assumption breaks. 
_SparseTestData( dense_dim=384, sparse_dim=128, - expert_ends=(64, 128, 192, 256), - tokens_per_expert=(64, 64, 64, 64), + expert_ends=(128, 256, 384, 512), + tokens_per_expert=(128, 128, 128, 128), ), ) @@ -172,69 +175,59 @@ def test_input_row_sparse_matmul(sparse_test_data, testing_device): # --------------------------------------------------------------------------- autograd wrappers -def _output_sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData) -> torch.Tensor: - """OutputSparseLinear forward via Python loop (PyTorch-differentiable).""" - ffn = rhs.shape[1] // data.num_experts - out = lhs.new_zeros(data.token_dim, ffn) - for i in range(data.num_experts): - begin, end = data.expert_begins[i], data.expert_pad_begins[i] - if end > begin: - out[begin:end] = lhs[begin:end] @ rhs[:, i * ffn : (i + 1) * ffn] +def _sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData, expert_axis: int) -> torch.Tensor: + """Per-expert matmul reference; rows past `expert_pad_begins` are zero in the output.""" + rhs_per_expert = rhs.chunk(data.num_experts, dim=expert_axis) + out = lhs.new_zeros(data.token_dim, rhs_per_expert[0].shape[1]) + for i, (begin, end) in enumerate(zip(data.expert_begins, data.expert_pad_begins, strict=True)): + out[begin:end] = lhs[begin:end] @ rhs_per_expert[i] return out -def _input_sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData) -> torch.Tensor: - """InputSparseLinear forward via Python loop (PyTorch-differentiable).""" - ffn = rhs.shape[0] // data.num_experts - out = lhs.new_zeros(data.token_dim, rhs.shape[1]) - for i in range(data.num_experts): - begin, end = data.expert_begins[i], data.expert_pad_begins[i] +def _zero_padded_rows(tensor: torch.Tensor, data: _SparseTestData) -> torch.Tensor: + # The autograd kernels treat padded tokens as regular ones; forward output and grad_lhs + # contain matmul-of-random garbage in [pad_begin, expert_end). Zero those rows so the + # comparison vs the reference (which produces zeros there) reflects only real-row error. 
+ masked = tensor.clone() + for begin, end in zip(data.expert_pad_begins, data.expert_ends, strict=True): if end > begin: - out[begin:end] = lhs[begin:end] @ rhs[i * ffn : (i + 1) * ffn] - return out - - -@requires_triton -@pytest.mark.slow -@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_output_sparse_linear_autograd(sparse_test_data, testing_device): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) - rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded, testing_device) - grad_output = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) - - lhs_ref = lhs.detach().requires_grad_(True) - rhs_ref = rhs.detach().requires_grad_(True) - out_ref = _output_sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data) - out_ref.backward(grad_output) - - lhs_t = lhs.detach().requires_grad_(True) - rhs_t = rhs.detach().requires_grad_(True) - out_t = OutputSparseLinear.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device)) - out_t.backward(grad_output) - - Assert.rms_close(out_t, out_ref, 1e-3) - Assert.rms_close(lhs_t.grad, lhs_ref.grad, 1e-3) - Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3) + masked[begin:end] = 0 + return masked @requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_input_sparse_linear_autograd(sparse_test_data, testing_device): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) - rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim, testing_device) - grad_output = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) +@pytest.mark.parametrize( + "autograd_class,expert_axis", + [(OutputSparseLinear, 1), (InputSparseLinear, 0)], + ids=["output_sparse", "input_sparse"], +) +def test_sparse_linear_autograd(sparse_test_data, testing_device, autograd_class, expert_axis): + # `expert_axis` is the rhs axis split per expert. Matmul contracts rhs's axis 0, so + # OutputSparseLinear (expert_axis=1) splits the output dim, InputSparseLinear (expert_axis=0) + # splits the contracting dim. 
+ if expert_axis == 1: + lhs_features, out_features = sparse_test_data.dense_dim, sparse_test_data.sparse_dim + rhs_shape = (sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded) + else: + lhs_features, out_features = sparse_test_data.sparse_dim, sparse_test_data.dense_dim + rhs_shape = (sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim) + + lhs = sparse_test_data.normal(sparse_test_data.token_dim, lhs_features, testing_device) + rhs = sparse_test_data.normal(*rhs_shape, testing_device) + grad_output = sparse_test_data.normal(sparse_test_data.token_dim, out_features, testing_device) lhs_ref = lhs.detach().requires_grad_(True) rhs_ref = rhs.detach().requires_grad_(True) - out_ref = _input_sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data) + out_ref = _sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data, expert_axis) out_ref.backward(grad_output) lhs_t = lhs.detach().requires_grad_(True) rhs_t = rhs.detach().requires_grad_(True) - out_t = InputSparseLinear.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device)) + out_t = autograd_class.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device)) out_t.backward(grad_output) - Assert.rms_close(out_t, out_ref, 1e-3) - Assert.rms_close(lhs_t.grad, lhs_ref.grad, 1e-3) + Assert.rms_close(_zero_padded_rows(out_t, sparse_test_data), out_ref, 1e-3) + Assert.rms_close(_zero_padded_rows(lhs_t.grad, sparse_test_data), lhs_ref.grad, 1e-3) Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index 4e8f4efb6..cd2dd3db3 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -44,6 +44,8 @@ "sparse_linear: input_inner_sparse (layer 2 / down-proj)", } +_SKIP_VARIANTS = {"pytorch_compiled", "pytorch_compiled_max"} + def _build_params() -> list: modules_and_shapes = [ @@ -67,30 +69,23 @@ def _build_params() -> list: @pytest.fixture(autouse=True) -def _disable_dynamo(): +def _patch_benchmark_env(monkeypatch): import torch._dynamo - orig = torch._dynamo.config.disable - torch._dynamo.config.disable = True - yield - torch._dynamo.config.disable = orig - - -_SKIP_VARIANTS = {"pytorch_compiled", "pytorch_compiled_max"} + monkeypatch.setattr(torch._dynamo.config, "disable", True) + monkeypatch.setattr(TritonConfig, "enabled", lambda *a, **kw: False) + monkeypatch.setattr(_bench_runner, "_cudagraph_mark_step_begin", None) + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) @pytest.mark.parametrize("name,cases,variants", _PARAMS) -def test_triton_benchmark(name, cases, variants, monkeypatch): +def test_triton_benchmark(name, cases, variants): if triton_interpret and name in _INTERPRETER_SKIP: pytest.skip("tl.histogram is broken in the Triton interpreter") - monkeypatch.setattr(TritonConfig, "enabled", lambda *a, **kw: False) - monkeypatch.setattr(_bench_runner, "_cudagraph_mark_step_begin", None) - monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) - - variants = [v for v in variants if v.name not in _SKIP_VARIANTS] # Replace fast_llm_triton fwd/fwd_bwd with the fp32 reference so no Triton # kernels are compiled. The runner still exercises the full variant code path. 
+ variants = [v for v in variants if v.name not in _SKIP_VARIANTS] ref = next(v for v in variants if v.is_reference) variants = [ dataclasses.replace(v, fwd=ref.fwd, fwd_bwd=ref.fwd_bwd) if v.name == "fast_llm_triton" else v diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 07faf16c3..a5741bed5 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -53,24 +53,39 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: return get_sparse_map(top_experts, num_experts) -def _zero_padded_rows(tensor: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: - for e in range(sparse_map.num_experts): - pad_start = int(sparse_map.expert_pad_begins[e]) - pad_end = int(sparse_map.expert_ends[e]) - if pad_end > pad_start: - tensor[pad_start:pad_end] = 0 - return tensor +def _mask_padded_rows(cand: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: + # Two regions in the kernel's forward output and grad_lhs are by-design garbage that + # downstream consumers ignore: per-expert padding [pad_begin, expert_end) (where the + # kernel does a matmul on random padding inputs) and phantom rows [expert_ends[-1], + # num_rows) past the last expert (where the kernel early-returns and leaves the output + # buffer uninitialized). The loop reference produces zeros in both regions, so without + # masking those mismatches would dominate rel_rms. grad_rhs already excludes padded + # contributions in both the kernel and reference, so it needs no masking. + sparse_map = inputs["sparse_map"] + pad_begins = sparse_map.expert_pad_begins.tolist() + pad_ends = sparse_map.expert_ends.tolist() + last_expert_end = pad_ends[-1] + masked = dict(cand) + for key in ("output", "grad_lhs"): + if key not in masked: + continue + clone = masked[key].clone() + for begin, end in zip(pad_begins, pad_ends, strict=True): + if end > begin: + clone[begin:end] = 0 + if clone.shape[0] > last_expert_end: + clone[last_expert_end:] = 0 + masked[key] = clone + return masked def _make_output_sparse_inputs( tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype ) -> dict: sparse_map = _make_sparse_map(tokens, top_k, num_experts) - lhs_data = _zero_padded_rows(torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map) + lhs_data = torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()) rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) - backward_grad = _zero_padded_rows( - torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map - ) + backward_grad = torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) if TritonConfig.enabled() and _warmup_key not in _output_sparse_warmed_up: _w_lhs = lhs_data.detach().requires_grad_(True) @@ -92,13 +107,9 @@ def _make_input_inner_sparse_inputs( tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype ) -> dict: sparse_map = _make_sparse_map(tokens, top_k, num_experts) - lhs_data = _zero_padded_rows( - torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()), sparse_map - ) + lhs_data = torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) - backward_grad = 
_zero_padded_rows( - torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), sparse_map - ) + backward_grad = torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()) _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) if TritonConfig.enabled() and _warmup_key not in _input_inner_sparse_warmed_up: _w_lhs = lhs_data.detach().requires_grad_(True) @@ -193,6 +204,7 @@ def _output_sparse_variants() -> list[Variant]: name="fast_llm_triton", fwd=_run_output_sparse_fwd_triton, fwd_bwd=_run_output_sparse_fwd_bwd_triton, + output_postprocess=_mask_padded_rows, ) ) return variants @@ -275,6 +287,7 @@ def _input_inner_sparse_variants() -> list[Variant]: name="fast_llm_triton", fwd=_run_input_inner_sparse_fwd_triton, fwd_bwd=_run_input_inner_sparse_fwd_bwd_triton, + output_postprocess=_mask_padded_rows, ) ) return variants From 5da9509a0c122fc46ec0a86324a27a93a97d0540 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 06:20:50 -0400 Subject: [PATCH 33/41] Factor common variant/case scaffolding into utils helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each bench_*.py was repeating the same ~80-line scaffold per kernel: six _run_* wrappers (fwd / fwd_bwd / fp32 / triton variants), a five-Variant list, a _*_cases case-builder, and identical run()/__main__ boilerplate. New helpers in utils.py — make_cases, bench_main, and standard_fwd_bwd_pytorch_variants — let each bench file specify just the kernel signature (input_keys, grad_input_keys, grad_output_key, output_key) and an eager function. Triton variants stay explicit since their signatures vary too much for a uniform helper. Net: ~760 lines removed across the suite. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/benchmark/bench_entropy_loss.py | 466 +++++------------------- tools/benchmark/bench_grpo_loss.py | 125 ++----- tools/benchmark/bench_mlp_activation.py | 145 ++------ tools/benchmark/bench_normalization.py | 259 +++---------- tools/benchmark/bench_pointwise.py | 95 ++--- tools/benchmark/bench_rotary.py | 44 +-- tools/benchmark/bench_sparse_copy.py | 252 +++---------- tools/benchmark/bench_sparse_linear.py | 269 ++++---------- tools/benchmark/utils.py | 193 +++++++++- 9 files changed, 547 insertions(+), 1301 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index 2e936bd04..8d22d6386 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -15,16 +15,14 @@ Shapes fix tokens=4096, sweep vocab size from Llama-2 (32K) to Llama-3 (128K). 
""" -from functools import partial - import torch import torch.nn.functional as F from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, vocab_size) _SHAPES = [ @@ -53,96 +51,15 @@ def _make_distribution_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> di } -# --------------------------------------------------------------------------- cross_entropy (labels) - - -def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return F.cross_entropy(logits, labels) - - -_ce_labels_compiled_default = torch.compile(_ce_labels_eager, mode="default", dynamic=False) -_ce_labels_compiled_max = torch.compile(_ce_labels_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_ce_labels_fwd(inputs: dict, fn) -> dict: - return {"loss": fn(inputs["logits"], inputs["labels"])} - - -def _run_ce_labels_fwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - return {"loss": _ce_labels_eager(logits_fp32, inputs["labels"])} - - def _reset_logits_grad(inputs: dict) -> None: inputs["logits"].grad = None -def _run_ce_labels_fwd_bwd(inputs: dict, fn) -> dict: - loss = fn(inputs["logits"], inputs["labels"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} - - -def _run_ce_labels_fwd_bwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - loss = _ce_labels_eager(logits_fp32, inputs["labels"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} - - -def _run_ce_labels_fwd_triton(inputs: dict) -> dict: - loss, _ = triton_entropy_loss_forward_backward( - inputs["logits"], inputs["labels"], loss_mask=None, grad_output=None - ) - return {"loss": loss} - - -def _run_ce_labels_fwd_bwd_triton(inputs: dict) -> dict: - loss, grad_logits = triton_entropy_loss_forward_backward( - inputs["logits"], inputs["labels"], loss_mask=None, grad_output=1.0 - ) - return {"loss": loss, "grad_logits": grad_logits} - - -def _ce_labels_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_ce_labels_fwd_fp32, - fwd_bwd=_run_ce_labels_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_ce_labels_fwd(inputs, _ce_labels_eager), - fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_eager), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_ce_labels_fwd(inputs, _ce_labels_compiled_default), - fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_compiled_default), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_ce_labels_fwd(inputs, _ce_labels_compiled_max), - fwd_bwd=lambda inputs: _run_ce_labels_fwd_bwd(inputs, _ce_labels_compiled_max), - reset_inputs=_reset_logits_grad, - ), - ] - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_run_ce_labels_fwd_triton, - fwd_bwd=_run_ce_labels_fwd_bwd_triton, - ) - ) - return 
variants +# --------------------------------------------------------------------------- eager kernels -# --------------------------------------------------------------------------- cross_entropy (logits / distribution) +def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + return F.cross_entropy(logits, labels) def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: @@ -150,185 +67,9 @@ def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.T return F.cross_entropy(logits, target_logits.softmax(dim=-1)) -_ce_dist_compiled_default = torch.compile(_ce_dist_eager, mode="default", dynamic=False) -_ce_dist_compiled_max = torch.compile(_ce_dist_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_dist_fwd(inputs: dict, fn) -> dict: - return {"loss": fn(inputs["logits"], inputs["target_logits"])} - - -def _run_ce_dist_fwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - return {"loss": _ce_dist_eager(logits_fp32, inputs["target_logits"].float())} - - -def _run_dist_fwd_bwd(inputs: dict, fn) -> dict: - loss = fn(inputs["logits"], inputs["target_logits"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} - - -def _run_ce_dist_fwd_bwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - loss = _ce_dist_eager(logits_fp32, inputs["target_logits"].float()) - loss.backward() - return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} - - -def _run_ce_dist_fwd_triton(inputs: dict) -> dict: - loss, _ = triton_entropy_loss_forward_backward( - inputs["logits"], - inputs["target_logits"], - loss_mask=None, - grad_output=None, - target_format=TargetFormat.logits, - entropy_loss_type=EntropyLossType.cross_entropy, - ) - return {"loss": loss} - - -def _run_ce_dist_fwd_bwd_triton(inputs: dict) -> dict: - loss, grad_logits = triton_entropy_loss_forward_backward( - inputs["logits"], - inputs["target_logits"], - loss_mask=None, - grad_output=1.0, - target_format=TargetFormat.logits, - entropy_loss_type=EntropyLossType.cross_entropy, - ) - return {"loss": loss, "grad_logits": grad_logits} - - -def _ce_dist_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_ce_dist_fwd_fp32, - fwd_bwd=_run_ce_dist_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_eager), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _ce_dist_eager), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_compiled_default), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _ce_dist_compiled_default), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_dist_fwd(inputs, _ce_dist_compiled_max), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _ce_dist_compiled_max), - reset_inputs=_reset_logits_grad, - ), - ] - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_run_ce_dist_fwd_triton, - fwd_bwd=_run_ce_dist_fwd_bwd_triton, - ) - ) - return variants - - -# --------------------------------------------------------------------------- reverse_kl (logits / distribution) - - def _reverse_kl_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: """KL(q||p) where q = softmax(logits), p = 
softmax(target_logits).""" - return F.kl_div( - target_logits.log_softmax(dim=-1), - logits.softmax(dim=-1), - reduction="batchmean", - ) - - -_reverse_kl_compiled_default = torch.compile(_reverse_kl_eager, mode="default", dynamic=False) -_reverse_kl_compiled_max = torch.compile(_reverse_kl_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_rkl_fwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - return {"loss": _reverse_kl_eager(logits_fp32, inputs["target_logits"].float())} - - -def _run_rkl_fwd_bwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - loss = _reverse_kl_eager(logits_fp32, inputs["target_logits"].float()) - loss.backward() - return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} - - -def _run_rkl_fwd_triton(inputs: dict) -> dict: - loss, _ = triton_entropy_loss_forward_backward( - inputs["logits"], - inputs["target_logits"], - loss_mask=None, - grad_output=None, - target_format=TargetFormat.logits, - entropy_loss_type=EntropyLossType.reverse_kl, - ) - return {"loss": loss} - - -def _run_rkl_fwd_bwd_triton(inputs: dict) -> dict: - loss, grad_logits = triton_entropy_loss_forward_backward( - inputs["logits"], - inputs["target_logits"], - loss_mask=None, - grad_output=1.0, - target_format=TargetFormat.logits, - entropy_loss_type=EntropyLossType.reverse_kl, - ) - return {"loss": loss, "grad_logits": grad_logits} - - -def _reverse_kl_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_rkl_fwd_fp32, - fwd_bwd=_run_rkl_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_eager), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_eager), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_compiled_default), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_compiled_default), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_dist_fwd(inputs, _reverse_kl_compiled_max), - fwd_bwd=lambda inputs: _run_dist_fwd_bwd(inputs, _reverse_kl_compiled_max), - reset_inputs=_reset_logits_grad, - ), - ] - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_run_rkl_fwd_triton, - fwd_bwd=_run_rkl_fwd_bwd_triton, - ) - ) - return variants - - -# --------------------------------------------------------------------------- z_loss + return F.kl_div(target_logits.log_softmax(dim=-1), logits.softmax(dim=-1), reduction="batchmean") def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: @@ -336,86 +77,70 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: return (log_z * log_z).mean() -_z_loss_compiled_default = torch.compile(_z_loss_eager, mode="default", dynamic=False) -_z_loss_compiled_max = torch.compile(_z_loss_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_zl_fwd(inputs: dict, fn) -> dict: - return {"loss": fn(inputs["logits"])} - - -def _run_zl_fwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - return {"loss": _z_loss_eager(logits_fp32)} - +# --------------------------------------------------------------------------- variant assembly -def _run_zl_fwd_bwd(inputs: dict, fn) -> dict: - loss = fn(inputs["logits"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": 
inputs["logits"].grad} - - -def _run_zl_fwd_bwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_(True) - loss = _z_loss_eager(logits_fp32) - loss.backward() - return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} +def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list[Variant]: + variants = standard_fwd_bwd_pytorch_variants( + eager_function, + input_keys=input_keys, + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + ) + if TritonConfig.enabled(): + target_key = input_keys[1] + kwargs = triton_kwargs or {} -def _run_zl_fwd_triton(inputs: dict) -> dict: - loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=None) - return {"loss": loss} + def triton_fwd(inputs: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward( + inputs["logits"], inputs[target_key], loss_mask=None, grad_output=None, **kwargs + ) + return {"loss": loss} + def triton_fwd_bwd(inputs: dict) -> dict: + loss, grad_logits = triton_entropy_loss_forward_backward( + inputs["logits"], inputs[target_key], loss_mask=None, grad_output=1.0, **kwargs + ) + return {"loss": loss, "grad_logits": grad_logits} -def _run_zl_fwd_bwd_triton(inputs: dict) -> dict: - loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0) - return {"loss": loss, "grad_logits": grad_logits} + variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) + return variants def _z_loss_variants() -> list[Variant]: - variants = [ - Variant(name="fp32_reference", fwd=_run_zl_fwd_fp32, fwd_bwd=_run_zl_fwd_bwd_fp32, is_reference=True), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_eager), - fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_eager), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_compiled_default), - fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_compiled_default), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_zl_fwd(inputs, _z_loss_compiled_max), - fwd_bwd=lambda inputs: _run_zl_fwd_bwd(inputs, _z_loss_compiled_max), - reset_inputs=_reset_logits_grad, - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _z_loss_eager, + input_keys=("logits",), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + ) if TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=_run_zl_fwd_triton, fwd_bwd=_run_zl_fwd_bwd_triton)) - return variants + def triton_fwd(inputs: dict) -> dict: + loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=None) + return {"loss": loss} + + def triton_fwd_bwd(inputs: dict) -> dict: + loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0) + return {"loss": loss, "grad_logits": grad_logits} -# --------------------------------------------------------------------------- cases + variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) + return variants -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() +# --------------------------------------------------------------------------- bytes / flops def _label_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - """fwd+bwd: read logits, read 
labels (int32), write grad_logits.""" - element_size = _bytes_per_element(dtype) - return 2 * tokens * vocab * element_size + tokens * 4 + # fwd+bwd: read logits, read labels (int32), write grad_logits. + return 2 * tokens * vocab * dtype.itemsize + tokens * 4 def _dist_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - """fwd+bwd: read logits, read target_logits, write grad_logits.""" - element_size = _bytes_per_element(dtype) - return 3 * tokens * vocab * element_size + # fwd+bwd: read logits, read target_logits, write grad_logits. + return 3 * tokens * vocab * dtype.itemsize def _entropy_loss_flops(tokens: int, vocab: int) -> int: @@ -423,40 +148,6 @@ def _entropy_loss_flops(tokens: int, vocab: int) -> int: return 4 * tokens * vocab -def _label_cases( - kernel_name: str, dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None -) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name(kernel_name, (tokens, vocab), dtype), - make_inputs=partial(_make_label_inputs, tokens, vocab, dtype), - expected_bytes=_label_loss_bytes(tokens, vocab, dtype), - expected_flops=_entropy_loss_flops(tokens, vocab), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, vocab in shapes - ] - - -def _dist_cases( - kernel_name: str, dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None -) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name(kernel_name, (tokens, vocab), dtype), - make_inputs=partial(_make_distribution_inputs, tokens, vocab, dtype), - expected_bytes=_dist_loss_bytes(tokens, vocab, dtype), - expected_flops=_entropy_loss_flops(tokens, vocab), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, vocab in shapes - ] - - # --------------------------------------------------------------------------- entry point @@ -465,36 +156,57 @@ def benchmarks( shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + shapes = shapes if shapes is not None else _SHAPES return [ ( "entropy_loss: cross_entropy (labels)", - _label_cases("cross_entropy_labels", dtypes, shapes), - _ce_labels_variants(), + make_cases( + "cross_entropy_labels", dtypes, shapes, _make_label_inputs, _label_loss_bytes, _entropy_loss_flops + ), + _entropy_variants(_ce_labels_eager, input_keys=("logits", "labels")), ), ( "entropy_loss: cross_entropy (logits)", - _dist_cases("cross_entropy_logits", dtypes, shapes), - _ce_dist_variants(), + make_cases( + "cross_entropy_logits", + dtypes, + shapes, + _make_distribution_inputs, + _dist_loss_bytes, + _entropy_loss_flops, + ), + _entropy_variants( + _ce_dist_eager, + input_keys=("logits", "target_logits"), + triton_kwargs={ + "target_format": TargetFormat.logits, + "entropy_loss_type": EntropyLossType.cross_entropy, + }, + ), ), ( "entropy_loss: reverse_kl (logits)", - _dist_cases("reverse_kl_logits", dtypes, shapes), - _reverse_kl_variants(), + make_cases( + "reverse_kl_logits", dtypes, shapes, _make_distribution_inputs, _dist_loss_bytes, _entropy_loss_flops + ), + _entropy_variants( + _reverse_kl_eager, + input_keys=("logits", "target_logits"), + triton_kwargs={ + "target_format": TargetFormat.logits, + "entropy_loss_type": EntropyLossType.reverse_kl, + }, + ), + ), + ( + "entropy_loss: z_loss", + make_cases("z_loss", dtypes, shapes, _make_label_inputs, _label_loss_bytes, _entropy_loss_flops), + _z_loss_variants(), ), - ("entropy_loss: z_loss", 
_label_cases("z_loss", dtypes, shapes), _z_loss_variants()), ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[tuple[int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 969a9217f..7725e1ab2 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -18,14 +18,12 @@ Shapes match bench_entropy_loss: tokens=4096, vocab swept over 32K/64K/128K. """ -from functools import partial - import torch from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants _SHAPES = [ (4096, 32768), @@ -62,43 +60,11 @@ def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Te return per_token_loss.mean() -_grpo_compiled_default = torch.compile(_grpo_eager, mode="default", dynamic=False) -_grpo_compiled_max = torch.compile(_grpo_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_fwd(inputs: dict, fn) -> dict: - return {"loss": fn(inputs["logits"], inputs["labels"], inputs["advantages"], inputs["old_log_probs"])} - - -def _run_fwd_fp32(inputs: dict) -> dict: - return { - "loss": _grpo_eager( - inputs["logits"].float().detach().requires_grad_(), - inputs["labels"], - inputs["advantages"], - inputs["old_log_probs"], - ) - } - - def _reset_logits_grad(inputs: dict) -> None: inputs["logits"].grad = None -def _run_fwd_bwd(inputs: dict, fn) -> dict: - loss = fn(inputs["logits"], inputs["labels"], inputs["advantages"], inputs["old_log_probs"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": inputs["logits"].grad} - - -def _run_fwd_bwd_fp32(inputs: dict) -> dict: - logits_fp32 = inputs["logits"].float().detach().requires_grad_() - loss = _grpo_eager(logits_fp32, inputs["labels"], inputs["advantages"], inputs["old_log_probs"]) - loss.backward() - return {"loss": loss.detach(), "grad_logits": logits_fp32.grad} - - -def _run_fwd_triton(inputs: dict) -> dict: +def _triton_fwd(inputs: dict) -> dict: loss, _, _ = triton_grpo_loss_forward_backward( inputs["logits"], inputs["labels"], @@ -111,7 +77,7 @@ def _run_fwd_triton(inputs: dict) -> dict: return {"loss": loss} -def _run_fwd_bwd_triton(inputs: dict) -> dict: +def _triton_fwd_bwd(inputs: dict) -> dict: loss, grad_logits, _ = triton_grpo_loss_forward_backward( inputs["logits"], inputs["labels"], @@ -125,51 +91,21 @@ def _run_fwd_bwd_triton(inputs: dict) -> dict: def _grpo_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_fwd_fp32, - fwd_bwd=_run_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_fwd(inputs, _grpo_eager), - fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _grpo_eager), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_fwd(inputs, _grpo_compiled_default), - fwd_bwd=lambda inputs: 
_run_fwd_bwd(inputs, _grpo_compiled_default), - reset_inputs=_reset_logits_grad, - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_fwd(inputs, _grpo_compiled_max), - fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _grpo_compiled_max), - reset_inputs=_reset_logits_grad, - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _grpo_eager, + input_keys=("logits", "labels", "advantages", "old_log_probs"), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + ) if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_run_fwd_triton, - fwd_bwd=_run_fwd_bwd_triton, - ) - ) + variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) return variants -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() - - def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - element_size = _bytes_per_element(dtype) # fwd: read logits + bwd: read logits + write grad_logits - logit_traffic = 3 * tokens * vocab * element_size + logit_traffic = 3 * tokens * vocab * dtype.itemsize # labels (int64), advantages (fp32), old_log_probs (fp32) scalar_traffic = tokens * (8 + 4 + 4) return logit_traffic + scalar_traffic @@ -180,39 +116,22 @@ def _grpo_flops(tokens: int, vocab: int) -> int: return 14 * tokens * vocab -def _grpo_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("grpo_loss", (tokens, vocab), dtype), - make_inputs=partial(_make_grpo_inputs, tokens, vocab, dtype), - expected_bytes=_grpo_bytes(tokens, vocab, dtype), - expected_flops=_grpo_flops(tokens, vocab), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, vocab in shapes - ] - - def benchmarks( dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - return [("grpo_loss", _grpo_cases(dtypes, shapes), _grpo_variants())] + shapes = shapes if shapes is not None else _SHAPES + return [ + ( + "grpo_loss", + make_cases("grpo_loss", dtypes, shapes, _make_grpo_inputs, _grpo_bytes, _grpo_flops), + _grpo_variants(), + ) + ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[tuple[int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index ed96b6476..411dd3355 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -15,8 +15,6 @@ Shapes fix tokens=8192 and sweep ffn_dim across typical MLP widths. 
""" -from functools import partial - import torch from fast_llm.functional.config import ActivationType, TritonConfig @@ -25,8 +23,8 @@ triton_mlp_activation_autograd, triton_mlp_activation_forward, ) -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, ffn_dim) — input tensor has shape (tokens, 2*ffn_dim) for gated. _SHAPES = [ @@ -39,114 +37,43 @@ _DEFAULT_DTYPES = (torch.bfloat16,) -# --------------------------------------------------------------------------- inputs - - def _make_mlp_inputs(tokens: int, ffn_dim: int, dtype: torch.dtype) -> dict: return { - "input_": torch.randn(tokens, 2 * ffn_dim, dtype=dtype, device=device(), requires_grad=True), + "input": torch.randn(tokens, 2 * ffn_dim, dtype=dtype, device=device(), requires_grad=True), "grad_output": torch.randn(tokens, ffn_dim, dtype=dtype, device=device()), "gated": True, "activation_type": _ACTIVATION, } -# --------------------------------------------------------------------------- forward wrappers - - -def _pytorch_fwd(input_: torch.Tensor, gated: bool, activation_type: ActivationType) -> torch.Tensor: - return torch_mlp_activation(input_, gated, activation_type) - - -_pytorch_compiled_default = torch.compile(_pytorch_fwd, mode="default", dynamic=False) -_pytorch_compiled_max = torch.compile(_pytorch_fwd, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["input_"], inputs["gated"], inputs["activation_type"])} - - -def _run_fwd_fp32(inputs: dict) -> dict: - return {"output": _pytorch_fwd(inputs["input_"].float(), inputs["gated"], inputs["activation_type"])} - - -def _run_fwd_triton(inputs: dict) -> dict: - output, _ = triton_mlp_activation_forward(inputs["input_"], inputs["gated"], inputs["activation_type"]) +def _triton_fwd(inputs: dict) -> dict: + output, _ = triton_mlp_activation_forward(inputs["input"], inputs["gated"], inputs["activation_type"]) return {"output": output} -# --------------------------------------------------------------------------- fwd+bwd wrappers - - -def _run_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["input_"], inputs["gated"], inputs["activation_type"]) - output.backward(inputs["grad_output"]) - return {"output": output.detach(), "grad_input": inputs["input_"].grad} - - -def _run_fwd_bwd_fp32(inputs: dict) -> dict: - input_fp32 = inputs["input_"].float().detach().requires_grad_(True) - output = _pytorch_fwd(input_fp32, inputs["gated"], inputs["activation_type"]) - output.backward(inputs["grad_output"].float()) - return {"output": output.detach(), "grad_input": input_fp32.grad} - - -def _run_fwd_bwd_triton(inputs: dict) -> dict: - output = triton_mlp_activation_autograd(inputs["input_"], inputs["gated"], inputs["activation_type"]) +def _triton_fwd_bwd(inputs: dict) -> dict: + output = triton_mlp_activation_autograd(inputs["input"], inputs["gated"], inputs["activation_type"]) output.backward(inputs["grad_output"]) - return {"output": output.detach(), "grad_input": inputs["input_"].grad} - - -# --------------------------------------------------------------------------- variants + return {"output": output.detach(), "grad_input": inputs["input"].grad} def _mlp_activation_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_fwd_fp32, - fwd_bwd=_run_fwd_bwd_fp32, - 
is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_fwd(inputs, _pytorch_fwd), - fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_fwd), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_fwd(inputs, _pytorch_compiled_default), - fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_compiled_default), - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_fwd(inputs, _pytorch_compiled_max), - fwd_bwd=lambda inputs: _run_fwd_bwd(inputs, _pytorch_compiled_max), - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + torch_mlp_activation, + input_keys=("input", "gated", "activation_type"), + grad_input_keys=("input",), + grad_output_key="grad_output", + ) if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_run_fwd_triton, - fwd_bwd=_run_fwd_bwd_triton, - ) - ) + variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) return variants -# --------------------------------------------------------------------------- cases - - -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() - - def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: """fwd: read input (2*ffn_dim) + write output (ffn_dim). bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). Total: 8 × tokens × ffn_dim × elem_size.""" - return 8 * tokens * ffn_dim * _bytes_per_element(dtype) + return 8 * tokens * ffn_dim * dtype.itemsize def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: @@ -154,42 +81,24 @@ def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: return 14 * tokens * ffn_dim -def _mlp_activation_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("mlp_activation", (tokens, ffn_dim), dtype), - make_inputs=partial(_make_mlp_inputs, tokens, ffn_dim, dtype), - expected_bytes=_mlp_activation_bytes(tokens, ffn_dim, dtype), - expected_flops=_mlp_activation_flops(tokens, ffn_dim), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, ffn_dim in shapes - ] - - -# --------------------------------------------------------------------------- entry point - - def benchmarks( dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - return [("mlp_activation (gated silu)", _mlp_activation_cases(dtypes, shapes), _mlp_activation_variants())] + shapes = shapes if shapes is not None else _SHAPES + return [ + ( + "mlp_activation (gated silu)", + make_cases( + "mlp_activation", dtypes, shapes, _make_mlp_inputs, _mlp_activation_bytes, _mlp_activation_flops + ), + _mlp_activation_variants(), + ) + ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] 
| None = None, - shapes: list[tuple[int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index cced806a0..a7b904ce5 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -15,8 +15,6 @@ - fast_llm_triton: triton_normalization_autograd """ -from functools import partial - import torch from fast_llm.functional.config import TritonConfig @@ -28,8 +26,8 @@ fast_normalization_available, fused_normalization_available, ) -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # Activation shape (batch*seq, hidden). Numel fixed at 32M to mimic a constant # training memory budget across model widths; hidden swept from 1K to 16K covers @@ -45,9 +43,6 @@ _EPS = 1e-5 -# --------------------------------------------------------------------------- input setup - - def _setup_param(tensor: torch.Tensor) -> torch.Tensor: """Triton's normalization backward writes weight/bias gradients to a `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`. @@ -57,17 +52,9 @@ def _setup_param(tensor: torch.Tensor) -> torch.Tensor: return tensor -def _to_fp32_input(tensor: torch.Tensor) -> torch.Tensor: - return tensor.float().detach().requires_grad_() - - -def _to_fp32_param(tensor: torch.Tensor) -> torch.Tensor: - return _setup_param(tensor.float().detach().requires_grad_()) - - def _make_layer_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: return { - "input_": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), + "input": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), "bias": _setup_param(torch.zeros(cols, dtype=dtype, device=device(), requires_grad=True)), "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), @@ -76,32 +63,12 @@ def _make_layer_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: def _make_rms_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: return { - "input_": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), + "input": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), } -def _layer_norm_inputs_fp32(inputs: dict) -> dict: - return { - "input_": _to_fp32_input(inputs["input_"]), - "weight": _to_fp32_param(inputs["weight"]), - "bias": _to_fp32_param(inputs["bias"]), - "grad_output": inputs["grad_output"].float(), - } - - -def _rms_norm_inputs_fp32(inputs: dict) -> dict: - return { - "input_": _to_fp32_input(inputs["input_"]), - "weight": _to_fp32_param(inputs["weight"]), - "grad_output": inputs["grad_output"].float(), - } - - -# --------------------------------------------------------------------------- forward functions - - def _layer_norm_eager(input_, 
weight, bias): return torch.layer_norm(input_, weight.shape, weight, bias, _EPS) @@ -110,14 +77,6 @@ def _rms_norm_eager(input_, weight): return torch.rms_norm(input_, weight.shape, weight, _EPS) -def _layer_norm_triton(input_, weight, bias): - return triton_normalization_autograd(input_, weight, bias, _EPS, True, False) - - -def _rms_norm_triton(input_, weight): - return triton_normalization_autograd(input_, weight, None, _EPS, True, False) - - def _layer_norm_apex_fused(input_, weight, bias): return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) @@ -130,168 +89,93 @@ def _rms_norm_apex_fused(input_, weight): return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS) -_layer_compiled_default = torch.compile(_layer_norm_eager, mode="default", dynamic=False) -_layer_compiled_max = torch.compile(_layer_norm_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -_rms_compiled_default = torch.compile(_rms_norm_eager, mode="default", dynamic=False) -_rms_compiled_max = torch.compile(_rms_norm_eager, mode="max-autotune-no-cudagraphs", dynamic=False) - - -# --------------------------------------------------------------------------- variant wrappers - - def _param_grad(param: torch.Tensor) -> torch.Tensor: """Pull the parameter gradient from wherever the kernel wrote it. Triton writes to `grad_buffer`; autograd writes to `.grad`.""" return param.grad if param.grad is not None else param.grad_buffer -def _run_layer_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["input_"], inputs["weight"], inputs["bias"])} +def _layer_norm_triton_fwd(inputs: dict) -> dict: + return { + "output": triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False) + } -def _run_layer_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["input_"], inputs["weight"], inputs["bias"]) +def _layer_norm_triton_fwd_bwd(inputs: dict) -> dict: + output = triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False) output.backward(inputs["grad_output"]) return { - "grad_input": inputs["input_"].grad, + "output": output.detach(), + "grad_input": inputs["input"].grad, "grad_weight": _param_grad(inputs["weight"]), "grad_bias": _param_grad(inputs["bias"]), } -def _run_rms_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["input_"], inputs["weight"])} +def _rms_norm_triton_fwd(inputs: dict) -> dict: + return {"output": triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False)} -def _run_rms_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["input_"], inputs["weight"]) +def _rms_norm_triton_fwd_bwd(inputs: dict) -> dict: + output = triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False) output.backward(inputs["grad_output"]) return { - "grad_input": inputs["input_"].grad, + "output": output.detach(), + "grad_input": inputs["input"].grad, "grad_weight": _param_grad(inputs["weight"]), } -# --------------------------------------------------------------------------- variants - - def _layer_norm_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=lambda inputs: _run_layer_fwd(_layer_norm_inputs_fp32(inputs), _layer_norm_eager), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(_layer_norm_inputs_fp32(inputs), _layer_norm_eager), - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_eager), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_eager), 
- ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_compiled_default), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_compiled_default), - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_compiled_max), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_compiled_max), - ), - ] + extras: dict = {} if fused_normalization_available: - variants.append( - Variant( - name="apex_fused", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_apex_fused), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_apex_fused), - ) - ) + extras["apex_fused"] = _layer_norm_apex_fused if fast_normalization_available: # apex_fast only supports widths in _PERSIST_LN_SIZES; all shapes in _SHAPES qualify. - variants.append( - Variant( - name="apex_fast", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_apex_fast), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_apex_fast), - ) - ) + extras["apex_fast"] = _layer_norm_apex_fast + variants = standard_fwd_bwd_pytorch_variants( + _layer_norm_eager, + input_keys=("input", "weight", "bias"), + grad_input_keys=("input", "weight", "bias"), + grad_output_key="grad_output", + extra_functions=extras, + ) if TritonConfig.enabled(): variants.append( - Variant( - name="fast_llm_triton", - fwd=lambda inputs: _run_layer_fwd(inputs, _layer_norm_triton), - fwd_bwd=lambda inputs: _run_layer_fwd_bwd(inputs, _layer_norm_triton), - ) + Variant(name="fast_llm_triton", fwd=_layer_norm_triton_fwd, fwd_bwd=_layer_norm_triton_fwd_bwd) ) return variants def _rms_norm_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=lambda inputs: _run_rms_fwd(_rms_norm_inputs_fp32(inputs), _rms_norm_eager), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(_rms_norm_inputs_fp32(inputs), _rms_norm_eager), - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_eager), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_eager), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_rms_fwd(inputs, _rms_compiled_default), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_compiled_default), - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_rms_fwd(inputs, _rms_compiled_max), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_compiled_max), - ), - ] + extras: dict = {} if fused_normalization_available: - variants.append( - Variant( - name="apex_fused", - fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_apex_fused), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_apex_fused), - ) - ) + extras["apex_fused"] = _rms_norm_apex_fused + variants = standard_fwd_bwd_pytorch_variants( + _rms_norm_eager, + input_keys=("input", "weight"), + grad_input_keys=("input", "weight"), + grad_output_key="grad_output", + extra_functions=extras, + ) if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=lambda inputs: _run_rms_fwd(inputs, _rms_norm_triton), - fwd_bwd=lambda inputs: _run_rms_fwd_bwd(inputs, _rms_norm_triton), - ) - ) + variants.append(Variant(name="fast_llm_triton", fwd=_rms_norm_triton_fwd, fwd_bwd=_rms_norm_triton_fwd_bwd)) return variants -# --------------------------------------------------------------------------- cases - - -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() - - def _layer_norm_bytes(rows: int, cols: int, dtype: 
torch.dtype) -> int: """Approximate fwd+bwd memory traffic for LayerNorm. fwd reads input + weight + bias and writes output (also stores inv_var). bwd reads grad_output, output, weight, bias, inv_var; writes grad_input, grad_weight, grad_bias. Activation tensors dominate.""" - element_size = _bytes_per_element(dtype) - activations = 4 * rows * cols * element_size # fwd in/out + bwd grad_in/out - parameters = 6 * cols * element_size # weight, bias × (read + grad write) twice + activations = 4 * rows * cols * dtype.itemsize # fwd in/out + bwd grad_in/out + parameters = 6 * cols * dtype.itemsize # weight, bias × (read + grad write) twice return activations + parameters def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - element_size = _bytes_per_element(dtype) - activations = 4 * rows * cols * element_size - parameters = 3 * cols * element_size + activations = 4 * rows * cols * dtype.itemsize + parameters = 3 * cols * dtype.itemsize return activations + parameters @@ -307,60 +191,27 @@ def _rms_norm_flops(rows: int, cols: int) -> int: return 15 * rows * cols -def _layer_norm_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("layer_norm", shape, dtype), - make_inputs=partial(_make_layer_norm_inputs, shape[0], shape[1], dtype), - expected_bytes=_layer_norm_bytes(shape[0], shape[1], dtype), - expected_flops=_layer_norm_flops(shape[0], shape[1]), - compute_dtype=dtype, - ) - for dtype in dtypes - for shape in shapes - ] - - -def _rms_norm_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("rms_norm", shape, dtype), - make_inputs=partial(_make_rms_norm_inputs, shape[0], shape[1], dtype), - expected_bytes=_rms_norm_bytes(shape[0], shape[1], dtype), - expected_flops=_rms_norm_flops(shape[0], shape[1]), - compute_dtype=dtype, - ) - for dtype in dtypes - for shape in shapes - ] - - -# --------------------------------------------------------------------------- entry point - - def benchmarks( dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + shapes = shapes if shapes is not None else _SHAPES return [ - ("normalization: layer_norm", _layer_norm_cases(dtypes, shapes), _layer_norm_variants()), - ("normalization: rms_norm", _rms_norm_cases(dtypes, shapes), _rms_norm_variants()), + ( + "normalization: layer_norm", + make_cases("layer_norm", dtypes, shapes, _make_layer_norm_inputs, _layer_norm_bytes, _layer_norm_flops), + _layer_norm_variants(), + ), + ( + "normalization: rms_norm", + make_cases("rms_norm", dtypes, shapes, _make_rms_norm_inputs, _rms_norm_bytes, _rms_norm_flops), + _rms_norm_variants(), + ), ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] 
| None = None, - shapes: list[tuple[int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index e00702617..0e62a2b2b 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -7,13 +7,10 @@ documented as being ~2x faster than the PyTorch equivalent on A100. """ -from functools import partial - import torch from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill -from tools.benchmark.runner import Case, run_benchmark -from tools.benchmark.utils import case_name, device, standard_fwd_variants +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_variants # Sizes span from L2-resident to comfortably HBM-bound, in 4× steps so the # regime transitions (L2 → HBM, mid-HBM → saturated-HBM) are visible. @@ -36,27 +33,17 @@ def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor: def _make_copy_inputs(numel: int, dtype: torch.dtype) -> dict: input_ = torch.randn(numel, dtype=dtype, device=device()) - out = torch.empty_like(input_) - return {"input_": input_, "out": out} + return {"input_": input_, "out": torch.empty_like(input_)} -def _copy_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: - sizes = shapes if shapes is not None else _SIZES_NUMEL - return [ - Case( - name=case_name("copy", (numel,), dtype), - make_inputs=partial(_make_copy_inputs, numel, dtype), - # Read input + write output. - expected_bytes=2 * numel * torch.tensor([], dtype=dtype).element_size(), - ) - for dtype in dtypes - for numel in sizes - ] +def _copy_bytes(numel: int, dtype: torch.dtype) -> int: + # Read input + write output. + return 2 * numel * dtype.itemsize _COPY_VARIANTS = standard_fwd_variants( - eager_fn=_copy_eager, - triton_fn=triton_copy, + eager_function=_copy_eager, + triton_function=triton_copy, unpack=lambda inputs: (inputs["input_"], inputs["out"]), ) @@ -72,23 +59,14 @@ def _make_fill_inputs(numel: int, dtype: torch.dtype) -> dict: return {"input_": torch.empty(numel, dtype=dtype, device=device()), "value": 1.5} -def _fill_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: - sizes = shapes if shapes is not None else _SIZES_NUMEL - return [ - Case( - name=case_name("fill", (numel,), dtype), - make_inputs=partial(_make_fill_inputs, numel, dtype), - # Write only. - expected_bytes=numel * torch.tensor([], dtype=dtype).element_size(), - ) - for dtype in dtypes - for numel in sizes - ] +def _fill_bytes(numel: int, dtype: torch.dtype) -> int: + # Write only. 
+ return numel * dtype.itemsize _FILL_VARIANTS = standard_fwd_variants( - eager_fn=_fill_eager, - triton_fn=triton_fill, + eager_function=_fill_eager, + triton_function=triton_fill, unpack=lambda inputs: (inputs["input_"], inputs["value"]), ) @@ -108,26 +86,19 @@ def _make_add_inputs(numel: int, dtype: torch.dtype) -> dict: } -def _add_cases(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[Case]: - sizes = shapes if shapes is not None else _SIZES_NUMEL - return [ - Case( - name=case_name("add", (numel,), dtype), - make_inputs=partial(_make_add_inputs, numel, dtype), - # Read 2 inputs + write 1 output. - expected_bytes=3 * numel * torch.tensor([], dtype=dtype).element_size(), - # One fp add per element. - expected_flops=numel, - compute_dtype=dtype, - ) - for dtype in dtypes - for numel in sizes - ] +def _add_bytes(numel: int, dtype: torch.dtype) -> int: + # Read 2 inputs + write 1 output. + return 3 * numel * dtype.itemsize + + +def _add_flops(numel: int) -> int: + # One fp add per element. + return numel _ADD_VARIANTS = standard_fwd_variants( - eager_fn=_add_eager, - triton_fn=triton_add, + eager_function=_add_eager, + triton_function=triton_add, unpack=lambda inputs: (inputs["input_"], inputs["other"], inputs["out"]), ) @@ -140,23 +111,19 @@ def benchmarks( shapes: list[int] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + shapes = shapes if shapes is not None else _SIZES_NUMEL return [ - ("pointwise: copy", _copy_cases(dtypes, shapes), _COPY_VARIANTS), - ("pointwise: fill", _fill_cases(dtypes, shapes), _FILL_VARIANTS), - ("pointwise: add", _add_cases(dtypes, shapes), _ADD_VARIANTS), + ("pointwise: copy", make_cases("copy", dtypes, shapes, _make_copy_inputs, _copy_bytes), _COPY_VARIANTS), + ("pointwise: fill", make_cases("fill", dtypes, shapes, _make_fill_inputs, _fill_bytes), _FILL_VARIANTS), + ( + "pointwise: add", + make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), + _ADD_VARIANTS, + ), ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[int] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index fb798c234..8d4616097 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -12,14 +12,12 @@ - 8 heads × 128 → GQA key-value heads (Llama 3) """ -from functools import partial - import torch from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.rotary import triton_rotary_ -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases # (tokens, num_heads, head_size) — tokens = batch * seq_len _SHAPES = [ @@ -57,9 +55,8 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: - element_size = torch.tensor([], dtype=dtype).element_size() # Read + write input tensor; frequencies are float32. 
- return 2 * tokens * num_heads * head_size * element_size + tokens * head_size * 4 + return 2 * tokens * num_heads * head_size * dtype.itemsize + tokens * head_size * 4 def _rotary_flops(tokens: int, num_heads: int, head_size: int) -> int: @@ -98,39 +95,22 @@ def _rotary_variants() -> list[Variant]: return variants -def _rotary_cases(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int]] | None = None) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("rotary", (tokens, num_heads, head_size), dtype), - make_inputs=partial(_make_rotary_inputs, tokens, num_heads, head_size, dtype), - expected_bytes=_rotary_bytes(tokens, num_heads, head_size, dtype), - expected_flops=_rotary_flops(tokens, num_heads, head_size), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, num_heads, head_size in shapes - ] - - def benchmarks( dtypes: tuple[torch.dtype, ...] | None = None, shapes: list[tuple[int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES - return [("rotary", _rotary_cases(dtypes, shapes), _rotary_variants())] + shapes = shapes if shapes is not None else _SHAPES + return [ + ( + "rotary", + make_cases("rotary", dtypes, shapes, _make_rotary_inputs, _rotary_bytes, _rotary_flops), + _rotary_variants(), + ) + ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[tuple[int, int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 2d4464dd9..c991aa792 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -20,8 +20,6 @@ The SparseMap is pre-computed once per case (routing structure, not data). """ -from functools import partial - import torch from fast_llm.functional.config import TritonConfig @@ -31,8 +29,8 @@ copy_sparse_to_dense_autograd, get_sparse_map, ) -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden_size) _SHAPES = [ @@ -53,9 +51,9 @@ def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: # and the static tail beyond expert_ends[-1]). Precomputed once per case and # used with masked_fill_ in output_postprocess — never inside the timed path. 
mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device()) - for e in range(sparse_map.num_experts): - pad_begin = int(sparse_map.expert_pad_begins[e]) - pad_end = int(sparse_map.expert_ends[e]) + for expert in range(sparse_map.num_experts): + pad_begin = int(sparse_map.expert_pad_begins[expert]) + pad_end = int(sparse_map.expert_ends[expert]) if pad_end > pad_begin: mask[pad_begin:pad_end] = True tail_begin = int(sparse_map.expert_ends[-1]) @@ -67,7 +65,7 @@ def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: sparse_map = _make_sparse_map(tokens, top_k, num_experts) return { - "dense_input": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), + "dense": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), "sparse_map": sparse_map, "phantom_mask": _make_phantom_mask(sparse_map), "backward_grad": torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), @@ -77,7 +75,7 @@ def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: sparse_map = _make_sparse_map(tokens, top_k, num_experts) return { - "sparse_input": torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device(), requires_grad=True), + "sparse": torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device(), requires_grad=True), "scores": torch.softmax(torch.randn(tokens, top_k, dtype=dtype, device=device()), dim=-1).requires_grad_(True), "sparse_map": sparse_map, "phantom_mask": _make_phantom_mask(sparse_map), @@ -96,77 +94,34 @@ def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch return out -_dispatch_compiled_default = torch.compile(_dispatch_pytorch, mode="default", dynamic=False) -_dispatch_compiled_max = torch.compile(_dispatch_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_dispatch_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["dense_input"], inputs["sparse_map"])} - - -def _run_dispatch_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["dense_input"], inputs["sparse_map"]) - output.backward(inputs["backward_grad"]) - return {"output": output.detach(), "grad_dense": inputs["dense_input"].grad} - - -def _run_dispatch_fwd_fp32(inputs: dict) -> dict: - dense_fp32 = inputs["dense_input"].float().detach().requires_grad_(True) - return {"output": _dispatch_pytorch(dense_fp32, inputs["sparse_map"])} +def _dispatch_triton_fwd(inputs: dict) -> dict: + return {"output": copy_dense_to_sparse_autograd(inputs["dense"], inputs["sparse_map"])} -def _run_dispatch_fwd_bwd_fp32(inputs: dict) -> dict: - dense_fp32 = inputs["dense_input"].float().detach().requires_grad_(True) - output = _dispatch_pytorch(dense_fp32, inputs["sparse_map"]) - output.backward(inputs["backward_grad"].float()) - return {"output": output.detach(), "grad_dense": dense_fp32.grad} - - -def _run_dispatch_fwd_triton(inputs: dict) -> dict: - return {"output": copy_dense_to_sparse_autograd(inputs["dense_input"], inputs["sparse_map"])} - - -def _run_dispatch_fwd_bwd_triton(inputs: dict) -> dict: - output = copy_dense_to_sparse_autograd(inputs["dense_input"], inputs["sparse_map"]) +def _dispatch_triton_fwd_bwd(inputs: dict) -> dict: + output = copy_dense_to_sparse_autograd(inputs["dense"], inputs["sparse_map"]) output.backward(inputs["backward_grad"]) - 
return {"output": output.detach(), "grad_dense": inputs["dense_input"].grad} + return {"output": output.detach(), "grad_dense": inputs["dense"].grad} -def _dispatch_postprocess(out: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: - out["output"].masked_fill_(inputs["phantom_mask"], 0) - return out +def _dispatch_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: + output["output"].masked_fill_(inputs["phantom_mask"], 0) + return output def _dispatch_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_dispatch_fwd_fp32, - fwd_bwd=_run_dispatch_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_pytorch), - fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_pytorch), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_compiled_default), - fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_compiled_default), - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_dispatch_fwd(inputs, _dispatch_compiled_max), - fwd_bwd=lambda inputs: _run_dispatch_fwd_bwd(inputs, _dispatch_compiled_max), - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _dispatch_pytorch, + input_keys=("dense", "sparse_map"), + grad_input_keys=("dense",), + grad_output_key="backward_grad", + ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=_run_dispatch_fwd_triton, - fwd_bwd=_run_dispatch_fwd_bwd_triton, + fwd=_dispatch_triton_fwd, + fwd_bwd=_dispatch_triton_fwd_bwd, output_postprocess=_dispatch_postprocess, ) ) @@ -184,152 +139,59 @@ def _combine_pytorch(sparse_input: torch.Tensor, scores: torch.Tensor, sparse_ma return out -_combine_compiled_default = torch.compile(_combine_pytorch, mode="default", dynamic=False) -_combine_compiled_max = torch.compile(_combine_pytorch, mode="max-autotune-no-cudagraphs", dynamic=False) - - -def _run_combine_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"])} +def _combine_triton_fwd(inputs: dict) -> dict: + return {"output": copy_sparse_to_dense_autograd(inputs["sparse"], inputs["scores"], inputs["sparse_map"])} -def _run_combine_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"]) +def _combine_triton_fwd_bwd(inputs: dict) -> dict: + output = copy_sparse_to_dense_autograd(inputs["sparse"], inputs["scores"], inputs["sparse_map"]) output.backward(inputs["backward_grad"]) return { "output": output.detach(), - "grad_sparse": inputs["sparse_input"].grad, + "grad_sparse": inputs["sparse"].grad, "grad_scores": inputs["scores"].grad, } -def _run_combine_fwd_fp32(inputs: dict) -> dict: - sparse_fp32 = inputs["sparse_input"].float().detach().requires_grad_(True) - scores_fp32 = inputs["scores"].float().detach().requires_grad_(True) - return {"output": _combine_pytorch(sparse_fp32, scores_fp32, inputs["sparse_map"])} - - -def _run_combine_fwd_bwd_fp32(inputs: dict) -> dict: - sparse_fp32 = inputs["sparse_input"].float().detach().requires_grad_(True) - scores_fp32 = inputs["scores"].float().detach().requires_grad_(True) - output = _combine_pytorch(sparse_fp32, scores_fp32, inputs["sparse_map"]) - output.backward(inputs["backward_grad"].float()) - return { - "output": output.detach(), - "grad_sparse": sparse_fp32.grad, - "grad_scores": scores_fp32.grad, - } - - -def 
_run_combine_fwd_triton(inputs: dict) -> dict: - return {"output": copy_sparse_to_dense_autograd(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"])} - - -def _run_combine_fwd_bwd_triton(inputs: dict) -> dict: - output = copy_sparse_to_dense_autograd(inputs["sparse_input"], inputs["scores"], inputs["sparse_map"]) - output.backward(inputs["backward_grad"]) - return { - "output": output.detach(), - "grad_sparse": inputs["sparse_input"].grad, - "grad_scores": inputs["scores"].grad, - } - - -def _combine_postprocess(out: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: - if "grad_sparse" in out: - out["grad_sparse"].masked_fill_(inputs["phantom_mask"], 0) - return out +def _combine_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: + if "grad_sparse" in output: + output["grad_sparse"].masked_fill_(inputs["phantom_mask"], 0) + return output def _combine_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_combine_fwd_fp32, - fwd_bwd=_run_combine_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_eager", - fwd=lambda inputs: _run_combine_fwd(inputs, _combine_pytorch), - fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_pytorch), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_combine_fwd(inputs, _combine_compiled_default), - fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_compiled_default), - ), - Variant( - name="pytorch_compiled_max", - fwd=lambda inputs: _run_combine_fwd(inputs, _combine_compiled_max), - fwd_bwd=lambda inputs: _run_combine_fwd_bwd(inputs, _combine_compiled_max), - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _combine_pytorch, + input_keys=("sparse", "scores", "sparse_map"), + grad_input_keys=("sparse", "scores"), + grad_output_key="backward_grad", + ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=_run_combine_fwd_triton, - fwd_bwd=_run_combine_fwd_bwd_triton, + fwd=_combine_triton_fwd, + fwd_bwd=_combine_triton_fwd_bwd, output_postprocess=_combine_postprocess, ) ) return variants -# --------------------------------------------------------------------------- cases / bytes - +# --------------------------------------------------------------------------- bytes -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() - -def _dispatch_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: - element_size = _bytes_per_element(dtype) +def _dispatch_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: # fwd: read dense (tokens×h) + write sparse (top_k×tokens×h) # bwd: read sparse grad + write dense grad → same traffic reversed - return 2 * (1 + top_k) * tokens * hidden * element_size + return 2 * (1 + top_k) * tokens * hidden * dtype.itemsize -def _combine_bytes(tokens: int, top_k: int, hidden: int, dtype: torch.dtype) -> int: - element_size = _bytes_per_element(dtype) +def _combine_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: sparse_rows = top_k * tokens # fwd: read sparse (sparse×h) + read scores (tokens×top_k) + write dense (tokens×h) # bwd: read dense grad + read scores + write sparse grad + write score grad - return 2 * (sparse_rows + tokens) * hidden * element_size + 4 * tokens * top_k * element_size - - -def _dispatch_cases( - dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None -) -> 
list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("dispatch", (tokens, top_k, num_experts, hidden), dtype), - make_inputs=partial(_make_dispatch_inputs, tokens, top_k, num_experts, hidden, dtype), - expected_bytes=_dispatch_bytes(tokens, top_k, hidden, dtype), - expected_flops=0, - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, top_k, num_experts, hidden in shapes - ] - - -def _combine_cases( - dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None -) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("combine", (tokens, top_k, num_experts, hidden), dtype), - make_inputs=partial(_make_combine_inputs, tokens, top_k, num_experts, hidden, dtype), - expected_bytes=_combine_bytes(tokens, top_k, hidden, dtype), - expected_flops=0, - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, top_k, num_experts, hidden in shapes - ] + return 2 * (sparse_rows + tokens) * hidden * dtype.itemsize + 4 * tokens * top_k * dtype.itemsize # --------------------------------------------------------------------------- entry point @@ -340,22 +202,22 @@ def benchmarks( shapes: list[tuple[int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + shapes = shapes if shapes is not None else _SHAPES return [ - ("sparse_copy: dispatch", _dispatch_cases(dtypes, shapes), _dispatch_variants()), - ("sparse_copy: combine", _combine_cases(dtypes, shapes), _combine_variants()), + ( + "sparse_copy: dispatch", + make_cases("dispatch", dtypes, shapes, _make_dispatch_inputs, _dispatch_bytes), + _dispatch_variants(), + ), + ( + "sparse_copy: combine", + make_cases("combine", dtypes, shapes, _make_combine_inputs, _combine_bytes), + _combine_variants(), + ), ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[tuple[int, int, int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index a5741bed5..5ee5833c8 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -23,15 +23,13 @@ Shapes: (tokens, top_k, num_experts, hidden, ffn_per_expert) matching MoE FFN configs. 
""" -from functools import partial - import torch from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear -from tools.benchmark.runner import Case, Variant, run_benchmark -from tools.benchmark.utils import case_name, device +from tools.benchmark.runner import Variant +from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden, ffn_per_expert) _SHAPES = [ @@ -53,7 +51,7 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: return get_sparse_map(top_experts, num_experts) -def _mask_padded_rows(cand: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: +def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: # Two regions in the kernel's forward output and grad_lhs are by-design garbage that # downstream consumers ignore: per-expert padding [pad_begin, expert_end) (where the # kernel does a matmul on random padding inputs) and phantom rows [expert_ends[-1], @@ -65,7 +63,7 @@ def _mask_padded_rows(cand: dict[str, torch.Tensor], inputs: dict) -> dict[str, pad_begins = sparse_map.expert_pad_begins.tolist() pad_ends = sparse_map.expert_ends.tolist() last_expert_end = pad_ends[-1] - masked = dict(cand) + masked = dict(candidate) for key in ("output", "grad_lhs"): if key not in masked: continue @@ -86,19 +84,18 @@ def _make_output_sparse_inputs( lhs_data = torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()) rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) backward_grad = torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) - _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) - if TritonConfig.enabled() and _warmup_key not in _output_sparse_warmed_up: - _w_lhs = lhs_data.detach().requires_grad_(True) - _w_rhs = rhs_data.detach().requires_grad_(True) - _w_out = OutputSparseLinear.apply(_w_lhs, _w_rhs, sparse_map) - _w_out.backward(backward_grad) - del _w_lhs, _w_rhs, _w_out - _output_sparse_warmed_up.add(_warmup_key) + warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) + if TritonConfig.enabled() and warmup_key not in _output_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = OutputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _output_sparse_warmed_up.add(warmup_key) return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), "sparse_map": sparse_map, - "ffn_per_expert": ffn_per_expert, "backward_grad": backward_grad, } @@ -110,248 +107,133 @@ def _make_input_inner_sparse_inputs( lhs_data = torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) backward_grad = torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()) - _warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) - if TritonConfig.enabled() and _warmup_key not in _input_inner_sparse_warmed_up: - _w_lhs = lhs_data.detach().requires_grad_(True) - _w_rhs = rhs_data.detach().requires_grad_(True) - _w_out = InputSparseLinear.apply(_w_lhs, _w_rhs, 
sparse_map) - _w_out.backward(backward_grad) - del _w_lhs, _w_rhs, _w_out - _input_inner_sparse_warmed_up.add(_warmup_key) + warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) + if TritonConfig.enabled() and warmup_key not in _input_inner_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = InputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _input_inner_sparse_warmed_up.add(warmup_key) return { "lhs": lhs_data.requires_grad_(True), "rhs": rhs_data.requires_grad_(True), "sparse_map": sparse_map, - "ffn_per_expert": ffn_per_expert, "backward_grad": backward_grad, } -# --------------------------------------------------------------------------- output_sparse references +# --------------------------------------------------------------------------- output_sparse def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: ffn_per_expert = rhs.shape[1] // sparse_map.num_experts out = lhs.new_zeros(sparse_map.num_rows, ffn_per_expert) - for e in range(sparse_map.num_experts): - row_begin = int(sparse_map.expert_ends[e - 1]) if e > 0 else 0 - row_end = int(sparse_map.expert_pad_begins[e]) + for expert in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[expert - 1]) if expert > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[expert]) if row_end > row_begin: - col_begin = e * ffn_per_expert + col_begin = expert * ffn_per_expert out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[:, col_begin : col_begin + ffn_per_expert] return out -_output_sparse_compiled = torch.compile(_output_sparse_loop, mode="default", dynamic=False) - - -def _run_output_sparse_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} - - -def _run_output_sparse_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) - output.backward(inputs["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} - - -def _run_output_sparse_fwd_fp32(inputs: dict) -> dict: - lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) - return {"output": _output_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"])} - - -def _run_output_sparse_fwd_bwd_fp32(inputs: dict) -> dict: - lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) - output = _output_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"]) - output.backward(inputs["backward_grad"].float()) - return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} - - -def _run_output_sparse_fwd_triton(inputs: dict) -> dict: +def _output_sparse_triton_fwd(inputs: dict) -> dict: return {"output": OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_output_sparse_fwd_bwd_triton(inputs: dict) -> dict: +def _output_sparse_triton_fwd_bwd(inputs: dict) -> dict: output = OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) output.backward(inputs["backward_grad"]) return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} def _output_sparse_variants() -> list[Variant]: - variants = [ - Variant( - 
name="fp32_reference", - fwd=_run_output_sparse_fwd_fp32, - fwd_bwd=_run_output_sparse_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_loop", - fwd=lambda inputs: _run_output_sparse_fwd(inputs, _output_sparse_loop), - fwd_bwd=lambda inputs: _run_output_sparse_fwd_bwd(inputs, _output_sparse_loop), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_output_sparse_fwd(inputs, _output_sparse_compiled), - fwd_bwd=lambda inputs: _run_output_sparse_fwd_bwd(inputs, _output_sparse_compiled), - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _output_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=_run_output_sparse_fwd_triton, - fwd_bwd=_run_output_sparse_fwd_bwd_triton, + fwd=_output_sparse_triton_fwd, + fwd_bwd=_output_sparse_triton_fwd_bwd, output_postprocess=_mask_padded_rows, ) ) return variants -# --------------------------------------------------------------------------- input_inner_sparse references +# --------------------------------------------------------------------------- input_inner_sparse def _input_inner_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: ffn_per_expert = rhs.shape[0] // sparse_map.num_experts out = lhs.new_zeros(sparse_map.num_rows, rhs.shape[1]) - for e in range(sparse_map.num_experts): - row_begin = int(sparse_map.expert_ends[e - 1]) if e > 0 else 0 - row_end = int(sparse_map.expert_pad_begins[e]) + for expert in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[expert - 1]) if expert > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[expert]) if row_end > row_begin: - inner_begin = e * ffn_per_expert + inner_begin = expert * ffn_per_expert out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[inner_begin : inner_begin + ffn_per_expert] return out -_input_inner_sparse_compiled = torch.compile(_input_inner_sparse_loop, mode="default", dynamic=False) - - -def _run_input_inner_sparse_fwd(inputs: dict, fn) -> dict: - return {"output": fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} - - -def _run_input_inner_sparse_fwd_bwd(inputs: dict, fn) -> dict: - output = fn(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) - output.backward(inputs["backward_grad"]) - return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} - - -def _run_input_inner_sparse_fwd_fp32(inputs: dict) -> dict: - lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) - return {"output": _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"])} - - -def _run_input_inner_sparse_fwd_bwd_fp32(inputs: dict) -> dict: - lhs_fp32 = inputs["lhs"].float().detach().requires_grad_(True) - rhs_fp32 = inputs["rhs"].float().detach().requires_grad_(True) - output = _input_inner_sparse_loop(lhs_fp32, rhs_fp32, inputs["sparse_map"]) - output.backward(inputs["backward_grad"].float()) - return {"output": output.detach(), "grad_lhs": lhs_fp32.grad, "grad_rhs": rhs_fp32.grad} - - -def _run_input_inner_sparse_fwd_triton(inputs: dict) -> dict: +def _input_inner_sparse_triton_fwd(inputs: dict) -> dict: return {"output": InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} -def _run_input_inner_sparse_fwd_bwd_triton(inputs: dict) -> dict: +def 
_input_inner_sparse_triton_fwd_bwd(inputs: dict) -> dict: output = InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) output.backward(inputs["backward_grad"]) return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} def _input_inner_sparse_variants() -> list[Variant]: - variants = [ - Variant( - name="fp32_reference", - fwd=_run_input_inner_sparse_fwd_fp32, - fwd_bwd=_run_input_inner_sparse_fwd_bwd_fp32, - is_reference=True, - ), - Variant( - name="pytorch_loop", - fwd=lambda inputs: _run_input_inner_sparse_fwd(inputs, _input_inner_sparse_loop), - fwd_bwd=lambda inputs: _run_input_inner_sparse_fwd_bwd(inputs, _input_inner_sparse_loop), - ), - Variant( - name="pytorch_compiled", - fwd=lambda inputs: _run_input_inner_sparse_fwd(inputs, _input_inner_sparse_compiled), - fwd_bwd=lambda inputs: _run_input_inner_sparse_fwd_bwd(inputs, _input_inner_sparse_compiled), - ), - ] + variants = standard_fwd_bwd_pytorch_variants( + _input_inner_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) if TritonConfig.enabled(): variants.append( Variant( name="fast_llm_triton", - fwd=_run_input_inner_sparse_fwd_triton, - fwd_bwd=_run_input_inner_sparse_fwd_bwd_triton, + fwd=_input_inner_sparse_triton_fwd, + fwd_bwd=_input_inner_sparse_triton_fwd_bwd, output_postprocess=_mask_padded_rows, ) ) return variants -# --------------------------------------------------------------------------- cases / bytes / flops - - -def _bytes_per_element(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() +# --------------------------------------------------------------------------- bytes / flops def _sparse_linear_bytes( - sparse_tokens: int, hidden: int, ffn_per_expert: int, num_experts: int, dtype: torch.dtype + tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype ) -> int: - element_size = _bytes_per_element(dtype) # fwd: read lhs + read rhs_full + write output # bwd: read grad_output + read rhs_full + write grad_lhs + read lhs + read grad_output + write grad_rhs # Simplification: 3× lhs traffic + 3× rhs traffic + 2× output traffic - lhs_bytes = sparse_tokens * hidden * element_size - rhs_bytes = hidden * ffn_per_expert * num_experts * element_size - out_bytes = sparse_tokens * ffn_per_expert * element_size - return 3 * lhs_bytes + 3 * rhs_bytes + 2 * out_bytes + sparse_tokens = tokens * top_k + lhs_bytes = sparse_tokens * hidden * dtype.itemsize + rhs_bytes = hidden * ffn_per_expert * num_experts * dtype.itemsize + output_bytes = sparse_tokens * ffn_per_expert * dtype.itemsize + return 3 * lhs_bytes + 3 * rhs_bytes + 2 * output_bytes -def _sparse_linear_flops(sparse_tokens_unpadded: int, hidden: int, ffn_per_expert: int) -> int: +def _sparse_linear_flops(tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int) -> int: # fwd + bwd ≈ 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad) - return 3 * 2 * sparse_tokens_unpadded * hidden * ffn_per_expert - - -def _output_sparse_cases( - dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None -) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("output_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), - make_inputs=partial(_make_output_sparse_inputs, tokens, top_k, 
num_experts, hidden, ffn_per_expert, dtype), - expected_bytes=_sparse_linear_bytes(tokens * top_k, hidden, ffn_per_expert, num_experts, dtype), - expected_flops=_sparse_linear_flops(tokens * top_k, hidden, ffn_per_expert), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, top_k, num_experts, hidden, ffn_per_expert in shapes - ] - - -def _input_inner_sparse_cases( - dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None -) -> list[Case]: - shapes = shapes if shapes is not None else _SHAPES - return [ - Case( - name=case_name("input_inner_sparse", (tokens, top_k, num_experts, hidden, ffn_per_expert), dtype), - make_inputs=partial( - _make_input_inner_sparse_inputs, tokens, top_k, num_experts, hidden, ffn_per_expert, dtype - ), - expected_bytes=_sparse_linear_bytes(tokens * top_k, ffn_per_expert, hidden, num_experts, dtype), - expected_flops=_sparse_linear_flops(tokens * top_k, ffn_per_expert, hidden), - compute_dtype=dtype, - ) - for dtype in dtypes - for tokens, top_k, num_experts, hidden, ffn_per_expert in shapes - ] + return 3 * 2 * tokens * top_k * hidden * ffn_per_expert # --------------------------------------------------------------------------- entry point @@ -362,30 +244,31 @@ def benchmarks( shapes: list[tuple[int, int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES + shapes = shapes if shapes is not None else _SHAPES return [ ( "sparse_linear: output_sparse (layer 1 / up-proj)", - _output_sparse_cases(dtypes, shapes), + make_cases( + "output_sparse", dtypes, shapes, _make_output_sparse_inputs, _sparse_linear_bytes, _sparse_linear_flops + ), _output_sparse_variants(), ), ( "sparse_linear: input_inner_sparse (layer 2 / down-proj)", - _input_inner_sparse_cases(dtypes, shapes), + make_cases( + "input_inner_sparse", + dtypes, + shapes, + _make_input_inner_sparse_inputs, + _sparse_linear_bytes, + _sparse_linear_flops, + ), _input_inner_sparse_variants(), ), ] -def run( - verbose: bool = False, - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[tuple[int, int, int, int, int]] | None = None, - warmup_ms: float = 25.0, - rep_ms: float = 100.0, - min_reps: int = 5, -) -> None: - for name, cases, variants in benchmarks(dtypes, shapes): - run_benchmark(name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) +run = bench_main(benchmarks) if __name__ == "__main__": diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py index 1a5ab417c..2284c2497 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/utils.py @@ -5,12 +5,13 @@ """ from collections.abc import Callable +from functools import partial import torch from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from tools.benchmark.runner import Inputs, Variant +from tools.benchmark.runner import Case, Inputs, Variant, run_benchmark # --------------------------------------------------------------------------- formatting @@ -40,12 +41,68 @@ def device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" +# --------------------------------------------------------------------------- cases + + +def make_cases( + kernel_name: str, + dtypes: tuple[torch.dtype, ...], + shapes: list, + make_inputs: Callable, + bytes_fn: Callable | None = None, + flops_fn: Callable | None = None, +) -> list[Case]: + """Build the standard `Case` list as the cross-product of `dtypes × shapes`. 
+ + Each `shape` may be a tuple or a scalar; tuples are unpacked positionally + into `make_inputs(*shape, dtype)`, `bytes_fn(*shape, dtype)`, and `flops_fn(*shape)`. + """ + cases = [] + for dtype in dtypes: + for shape in shapes: + shape_tuple = shape if isinstance(shape, tuple) else (shape,) + cases.append( + Case( + name=case_name(kernel_name, shape_tuple, dtype), + make_inputs=partial(make_inputs, *shape_tuple, dtype), + expected_bytes=bytes_fn(*shape_tuple, dtype) if bytes_fn else None, + expected_flops=flops_fn(*shape_tuple) if flops_fn else None, + compute_dtype=dtype, + ) + ) + return cases + + +# --------------------------------------------------------------------------- run/main + + +def bench_main(benchmarks_fn: Callable) -> Callable: + """Build the standard `run()` callable that loops `benchmarks_fn(dtypes, shapes)` + through `run_benchmark`. Each `bench_*.py` exports `run = bench_main(benchmarks)` + so the package CLI in `__main__.py` can dispatch to it.""" + + def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] | None = None, + shapes: list | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, + ) -> None: + for name, cases, variants in benchmarks_fn(dtypes, shapes): + run_benchmark( + name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps + ) + + return run + + # --------------------------------------------------------------------------- variant builders def standard_fwd_variants( - eager_fn: Callable, - triton_fn: Callable | None, + eager_function: Callable, + triton_function: Callable | None, unpack: Callable[[Inputs], tuple], ) -> list[Variant]: """Build the canonical 5-variant set for a forward-only kernel. @@ -53,33 +110,139 @@ def standard_fwd_variants( Generates: fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, and (if `TritonConfig.enabled()`) fast_llm_triton. - `eager_fn` is the plain PyTorch implementation taking positional tensor args. - `triton_fn` is the Fast-LLM Triton wrapper; pass `None` if the kernel has no + `eager_function` is the plain PyTorch implementation taking positional tensor args. + `triton_function` is the Fast-LLM Triton wrapper; pass `None` if the kernel has no Triton variant. Both are invoked with `unpack(inputs)` unpacked positionally; - `triton_fn` is called with an extra `use_triton=True` kwarg. + `triton_function` is called with an extra `use_triton=True` kwarg. The fp32 reference upcasts every floating-point tensor in the unpacked arguments to fp32 (non-tensor / non-float arguments are passed through). 
""" - def _fp32_unpack(inputs: Inputs) -> tuple: + def fp32_unpack(inputs: Inputs) -> tuple: return tuple( arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg for arg in unpack(inputs) ) - compiled_default = torch.compile(eager_fn, mode="default", dynamic=False) - compiled_max = torch.compile(eager_fn, mode="max-autotune-no-cudagraphs", dynamic=False) + compiled_default = torch.compile(eager_function, mode="default", dynamic=False) + compiled_max = torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False) + + variants = [ + Variant( + name="fp32_reference", + fwd=lambda inputs: eager_function(*fp32_unpack(inputs)), + is_reference=True, + ), + Variant(name="pytorch_eager", fwd=lambda inputs: eager_function(*unpack(inputs))), + Variant(name="pytorch_compiled", fwd=lambda inputs: compiled_default(*unpack(inputs))), + Variant(name="pytorch_compiled_max", fwd=lambda inputs: compiled_max(*unpack(inputs))), + ] + if triton_function is not None and TritonConfig.enabled(): + variants.append( + Variant(name="fast_llm_triton", fwd=lambda inputs: triton_function(*unpack(inputs), use_triton=True)) + ) + return variants + + +def _run_pytorch_fwd( + inputs: Inputs, + function: Callable, + input_keys: tuple[str, ...], + output_key: str, +) -> dict: + return {output_key: function(*(inputs[key] for key in input_keys))} + + +def _run_pytorch_fwd_bwd( + inputs: Inputs, + function: Callable, + input_keys: tuple[str, ...], + grad_input_keys: tuple[str, ...], + grad_output_key: str | None, + output_key: str, +) -> dict: + output = function(*(inputs[key] for key in input_keys)) + if grad_output_key is None: + output.backward() + else: + output.backward(inputs[grad_output_key]) + result = {output_key: output.detach()} + for key in grad_input_keys: + result[f"grad_{key}"] = inputs[key].grad + return result + + +def _to_fp32_inputs(inputs: Inputs, grad_input_keys: tuple[str, ...]) -> Inputs: + result = dict(inputs) + for key, value in inputs.items(): + if isinstance(value, torch.Tensor) and value.is_floating_point(): + float_value = value.float().detach() + result[key] = float_value.requires_grad_(True) if key in grad_input_keys else float_value + return result + + +def standard_fwd_bwd_pytorch_variants( + eager_function: Callable, + input_keys: tuple[str, ...], + grad_input_keys: tuple[str, ...], + *, + grad_output_key: str | None = None, + output_key: str = "output", + reset_inputs: Callable[[Inputs], None] | None = None, + extra_functions: dict[str, Callable] | None = None, + eager_name: str = "pytorch_eager", + enable_max_autotune: bool = True, +) -> list[Variant]: + """Build the canonical pytorch variant chunk for a forward-backward kernel. + + Generates: fp32_reference, , pytorch_compiled, [pytorch_compiled_max,] + plus any callables in `extra_functions` (e.g. apex implementations) appended + at the end with their dict-key as the variant name. + + `eager_function(*[inputs[key] for key in input_keys])` computes the forward output. + `grad_input_keys` lists input dict keys whose `.grad` is collected and returned + as `grad_` in the output dict. `grad_output_key` is the input dict key for + `output.backward(grad_output)`; pass `None` for scalar-loss kernels (uses bare + `output.backward()`). `output_key` is the output dict key for the forward result. + + The fp32 reference upcasts every floating-point tensor in the input dict to + fp32, re-attaching `requires_grad=True` for `grad_input_keys`. Non-float and + non-tensor entries (e.g. 
ints, enums, SparseMap) are passed through. + """ + fwd_kwargs = {"input_keys": input_keys, "output_key": output_key} + fwd_bwd_kwargs = { + "input_keys": input_keys, + "grad_input_keys": grad_input_keys, + "grad_output_key": grad_output_key, + "output_key": output_key, + } + + def variant(name: str, function: Callable) -> Variant: + return Variant( + name=name, + fwd=partial(_run_pytorch_fwd, function=function, **fwd_kwargs), + fwd_bwd=partial(_run_pytorch_fwd_bwd, function=function, **fwd_bwd_kwargs), + reset_inputs=reset_inputs, + ) + compiled_default = torch.compile(eager_function, mode="default", dynamic=False) variants = [ Variant( name="fp32_reference", - fwd=lambda inp: eager_fn(*_fp32_unpack(inp)), + fwd=lambda inputs: _run_pytorch_fwd( + _to_fp32_inputs(inputs, grad_input_keys), eager_function, **fwd_kwargs + ), + fwd_bwd=lambda inputs: _run_pytorch_fwd_bwd( + _to_fp32_inputs(inputs, grad_input_keys), eager_function, **fwd_bwd_kwargs + ), is_reference=True, ), - Variant(name="pytorch_eager", fwd=lambda inp: eager_fn(*unpack(inp))), - Variant(name="pytorch_compiled", fwd=lambda inp: compiled_default(*unpack(inp))), - Variant(name="pytorch_compiled_max", fwd=lambda inp: compiled_max(*unpack(inp))), + variant(eager_name, eager_function), + variant("pytorch_compiled", compiled_default), ] - if triton_fn is not None and TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=lambda inp: triton_fn(*unpack(inp), use_triton=True))) + if enable_max_autotune: + compiled_max = torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False) + variants.append(variant("pytorch_compiled_max", compiled_max)) + for name, function in (extra_functions or {}).items(): + variants.append(variant(name, function)) return variants From 16eecb61e63a0acca69c2564eed66cfb8adee057 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 07:00:51 -0400 Subject: [PATCH 34/41] Trim per-file boilerplate from benchmark suite Push optional triton_fwd/triton_fwd_bwd/triton_output_postprocess into standard_fwd_bwd_pytorch_variants so kernels with simple triton variants no longer need a wrapper. Move DEFAULT_DTYPES into utils.bench_main and drop the per-file _DEFAULT_DTYPES + dtype-resolution line. Drop the `if __name__ == "__main__": run()` tail (only entrypoint is __main__.py). Trim verbose file-top docstrings. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/benchmark/bench_entropy_loss.py | 123 +++++++--------------- tools/benchmark/bench_grpo_loss.py | 57 +++-------- tools/benchmark/bench_mlp_activation.py | 59 +++-------- tools/benchmark/bench_normalization.py | 63 +++--------- tools/benchmark/bench_pointwise.py | 45 +------- tools/benchmark/bench_rotary.py | 27 +---- tools/benchmark/bench_sparse_copy.py | 113 +++++--------------- tools/benchmark/bench_sparse_linear.py | 131 ++++++------------------ tools/benchmark/utils.py | 89 ++++++---------- 9 files changed, 183 insertions(+), 524 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index 8d22d6386..ba0de7a60 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -1,39 +1,21 @@ -""" -Benchmark entropy loss kernels. - -All Triton kernels fuse fwd+bwd into a single logits-tensor pass; `grad_output=1.0` -triggers gradient computation alongside the loss. 
- -Three main training cases benchmarked: - - cross_entropy + labels — standard LM training (integer targets) - cross_entropy + logits — distillation CE with soft targets, p=softmax(target_logits) - reverse_kl + logits — reverse KL divergence KL(q||p), p=softmax(target_logits) - -z_loss is also included (shared input structure with the labels case). - -Shapes fix tokens=4096, sweep vocab size from Llama-2 (32K) to Llama-3 (128K). -""" +"""Entropy loss kernels: cross_entropy (labels and logits target formats), +reverse_kl, and z_loss. All Triton kernels fuse fwd+bwd into a single +logits-tensor pass; `grad_output=1.0` triggers gradient computation.""" import torch import torch.nn.functional as F -from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, vocab_size) _SHAPES = [ (4096, 32768), # 7B / Llama-2 vocab - (4096, 65536), # 64K vocab + (4096, 65536), (4096, 131072), # Llama-3 vocab ] -_DEFAULT_DTYPES = (torch.bfloat16,) - - -# --------------------------------------------------------------------------- inputs def _make_label_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: @@ -46,7 +28,6 @@ def _make_label_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: def _make_distribution_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: return { "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), - # target_logits: teacher logits; no gradient needed w.r.t. these. 
"target_logits": torch.randn(tokens, vocab, dtype=dtype, device=device()), } @@ -55,20 +36,15 @@ def _reset_logits_grad(inputs: dict) -> None: inputs["logits"].grad = None -# --------------------------------------------------------------------------- eager kernels - - def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return F.cross_entropy(logits, labels) def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: - """CE(p, q) where p = softmax(target_logits), q = softmax(logits).""" return F.cross_entropy(logits, target_logits.softmax(dim=-1)) def _reverse_kl_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor: - """KL(q||p) where q = softmax(logits), p = softmax(target_logits).""" return F.kl_div(target_logits.log_softmax(dim=-1), logits.softmax(dim=-1), reduction="batchmean") @@ -77,85 +53,60 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: return (log_z * log_z).mean() -# --------------------------------------------------------------------------- variant assembly +def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list: + target_key = input_keys[1] + triton_kwargs = triton_kwargs or {} + + def triton_fwd(inputs: dict) -> dict: + loss, _ = triton_entropy_loss_forward_backward( + inputs["logits"], inputs[target_key], loss_mask=None, grad_output=None, **triton_kwargs + ) + return {"loss": loss} + def triton_fwd_bwd(inputs: dict) -> dict: + loss, grad_logits = triton_entropy_loss_forward_backward( + inputs["logits"], inputs[target_key], loss_mask=None, grad_output=1.0, **triton_kwargs + ) + return {"loss": loss, "grad_logits": grad_logits} -def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( + return standard_fwd_bwd_pytorch_variants( eager_function, input_keys=input_keys, grad_input_keys=("logits",), output_key="loss", reset_inputs=_reset_logits_grad, + triton_fwd=triton_fwd, + triton_fwd_bwd=triton_fwd_bwd, ) - if TritonConfig.enabled(): - target_key = input_keys[1] - kwargs = triton_kwargs or {} - - def triton_fwd(inputs: dict) -> dict: - loss, _ = triton_entropy_loss_forward_backward( - inputs["logits"], inputs[target_key], loss_mask=None, grad_output=None, **kwargs - ) - return {"loss": loss} - - def triton_fwd_bwd(inputs: dict) -> dict: - loss, grad_logits = triton_entropy_loss_forward_backward( - inputs["logits"], inputs[target_key], loss_mask=None, grad_output=1.0, **kwargs - ) - return {"loss": loss, "grad_logits": grad_logits} - - variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) - return variants - - -def _z_loss_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _z_loss_eager, - input_keys=("logits",), - grad_input_keys=("logits",), - output_key="loss", - reset_inputs=_reset_logits_grad, - ) - if TritonConfig.enabled(): - - def triton_fwd(inputs: dict) -> dict: - loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=None) - return {"loss": loss} - def triton_fwd_bwd(inputs: dict) -> dict: - loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0) - return {"loss": loss, "grad_logits": grad_logits} - variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) - return variants +def _z_loss_triton_fwd(inputs: dict) -> dict: + loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, 
grad_output=None) + return {"loss": loss} -# --------------------------------------------------------------------------- bytes / flops +def _z_loss_triton_fwd_bwd(inputs: dict) -> dict: + loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0) + return {"loss": loss, "grad_logits": grad_logits} def _label_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # fwd+bwd: read logits, read labels (int32), write grad_logits. return 2 * tokens * vocab * dtype.itemsize + tokens * 4 def _dist_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # fwd+bwd: read logits, read target_logits, write grad_logits. return 3 * tokens * vocab * dtype.itemsize def _entropy_loss_flops(tokens: int, vocab: int) -> int: - # fwd ≈ 3*vocab per token (max, sum_exp, CE); bwd ≈ vocab. Total ≈ 4*vocab. + # fwd ≈ 3*vocab per token, bwd ≈ vocab. return 4 * tokens * vocab -# --------------------------------------------------------------------------- entry point - - def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( @@ -201,13 +152,17 @@ def benchmarks( ( "entropy_loss: z_loss", make_cases("z_loss", dtypes, shapes, _make_label_inputs, _label_loss_bytes, _entropy_loss_flops), - _z_loss_variants(), + standard_fwd_bwd_pytorch_variants( + _z_loss_eager, + input_keys=("logits",), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + triton_fwd=_z_loss_triton_fwd, + triton_fwd_bwd=_z_loss_triton_fwd_bwd, + ), ), ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 7725e1ab2..65c7376a9 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -1,28 +1,10 @@ -""" -Benchmark the fused GRPO loss kernel. - -GRPO (Group Relative Policy Optimization) loss computes a clipped importance-weighted -policy gradient per token: loss = -min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv), -where ratio = exp(log_prob_new - log_prob_old). - -The Triton kernel fuses softmax, log-prob extraction, ratio computation, clipping, and -the backward gradient into a single pass over logits — same structure as the cross_entropy -kernel. - -Comparisons: -- fp32_reference: PyTorch GRPO in fp32 -- pytorch_eager: PyTorch GRPO in compute dtype -- pytorch_compiled / pytorch_compiled_max: torch.compile of the above -- fast_llm_triton: triton_grpo_loss_forward_backward - -Shapes match bench_entropy_loss: tokens=4096, vocab swept over 32K/64K/128K. -""" +"""Fused GRPO (Group Relative Policy Optimization) loss kernel. 
The Triton +kernel fuses softmax, log-prob extraction, ratio + clip, and backward into a +single pass over logits.""" import torch -from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants _SHAPES = [ @@ -30,7 +12,6 @@ (4096, 65536), (4096, 131072), ] -_DEFAULT_DTYPES = (torch.bfloat16,) _EPSILON_LOW = 0.2 _EPSILON_HIGH = 0.2 @@ -90,19 +71,6 @@ def _triton_fwd_bwd(inputs: dict) -> dict: return {"loss": loss, "grad_logits": grad_logits} -def _grpo_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _grpo_eager, - input_keys=("logits", "labels", "advantages", "old_log_probs"), - grad_input_keys=("logits",), - output_key="loss", - reset_inputs=_reset_logits_grad, - ) - if TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) - return variants - - def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: # fwd: read logits + bwd: read logits + write grad_logits logit_traffic = 3 * tokens * vocab * dtype.itemsize @@ -112,27 +80,30 @@ def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: def _grpo_flops(tokens: int, vocab: int) -> int: - # Similar to cross_entropy labels: softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element + # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element return 14 * tokens * vocab def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( "grpo_loss", make_cases("grpo_loss", dtypes, shapes, _make_grpo_inputs, _grpo_bytes, _grpo_flops), - _grpo_variants(), + standard_fwd_bwd_pytorch_variants( + _grpo_eager, + input_keys=("logits", "labels", "advantages", "old_log_probs"), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + triton_fwd=_triton_fwd, + triton_fwd_bwd=_triton_fwd_bwd, + ), ) ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 411dd3355..52bac15d9 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -1,32 +1,17 @@ -""" -Benchmark the fused MLP activation kernel. - -The Triton kernel (`triton_mlp_activation_autograd`) fuses the element-wise -activation and (for gated models) the gated multiply into a single pass. For -gated SiLU the fwd input is (tokens, 2*ffn_dim) — [gate_proj, up_proj] -concatenated — and the output is (tokens, ffn_dim). - -Comparisons: -- fp32_reference: torch_mlp_activation in fp32 with autograd -- pytorch_eager: torch_mlp_activation in compute dtype -- pytorch_compiled / pytorch_compiled_max: torch.compile of the above -- fast_llm_triton: triton_mlp_activation_autograd - -Shapes fix tokens=8192 and sweep ffn_dim across typical MLP widths. -""" +"""Fused MLP activation kernel. 
For gated SiLU the fwd input is (tokens, 2*ffn_dim) +— [gate_proj, up_proj] concatenated — and the output is (tokens, ffn_dim).""" import torch -from fast_llm.functional.config import ActivationType, TritonConfig +from fast_llm.functional.config import ActivationType from fast_llm.functional.triton.mlp import ( torch_mlp_activation, triton_mlp_activation_autograd, triton_mlp_activation_forward, ) -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants -# (tokens, ffn_dim) — input tensor has shape (tokens, 2*ffn_dim) for gated. +# (tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated. _SHAPES = [ (8192, 4096), # 7B/13B models (8192, 8192), # large @@ -34,7 +19,6 @@ (4096, 28672), # MoE up-projection ] _ACTIVATION = ActivationType.silu # standard for Llama-style gated models -_DEFAULT_DTYPES = (torch.bfloat16,) def _make_mlp_inputs(tokens: int, ffn_dim: int, dtype: torch.dtype) -> dict: @@ -57,35 +41,21 @@ def _triton_fwd_bwd(inputs: dict) -> dict: return {"output": output.detach(), "grad_input": inputs["input"].grad} -def _mlp_activation_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - torch_mlp_activation, - input_keys=("input", "gated", "activation_type"), - grad_input_keys=("input",), - grad_output_key="grad_output", - ) - if TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) - return variants - - def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: - """fwd: read input (2*ffn_dim) + write output (ffn_dim). - bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). - Total: 8 × tokens × ffn_dim × elem_size.""" + # fwd: read input (2*ffn_dim) + write output (ffn_dim). + # bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). return 8 * tokens * ffn_dim * dtype.itemsize def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: - # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element, total ≈ 14 per output element. + # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element. return 14 * tokens * ffn_dim def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( @@ -93,13 +63,16 @@ def benchmarks( make_cases( "mlp_activation", dtypes, shapes, _make_mlp_inputs, _mlp_activation_bytes, _mlp_activation_flops ), - _mlp_activation_variants(), + standard_fwd_bwd_pytorch_variants( + torch_mlp_activation, + input_keys=("input", "gated", "activation_type"), + grad_input_keys=("input",), + grad_output_key="grad_output", + triton_fwd=_triton_fwd, + triton_fwd_bwd=_triton_fwd_bwd, + ), ) ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index a7b904ce5..fd31a3c3e 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -1,23 +1,9 @@ -""" -Benchmark normalization kernels: LayerNorm and RMSNorm. - -Both are fwd+bwd kernels. The Triton implementation in -`fast_llm/functional/triton/normalization.py` handles both flavors via the +"""LayerNorm and RMSNorm. 
The Triton implementation handles both via the `bias` argument (LayerNorm when given, RMSNorm when None) and writes parameter -gradients to Fast-LLM's `grad_buffer` attribute rather than autograd's `.grad`. - -Comparisons: -- fp32_reference: torch.{layer,rms}_norm in fp32 (eager) -- pytorch_eager: torch.{layer,rms}_norm in the case dtype -- pytorch_compiled / pytorch_compiled_max: torch.compile of the above -- apex_fused: Apex fused_layer_norm_cuda (all widths, layer+rms norm) -- apex_fast: Apex fast_layer_norm contrib (layer norm only, restricted widths) -- fast_llm_triton: triton_normalization_autograd -""" +gradients to Fast-LLM's `grad_buffer` attribute rather than autograd's `.grad`.""" import torch -from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.normalization import ( FastLayerNorm, @@ -29,9 +15,8 @@ from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants -# Activation shape (batch*seq, hidden). Numel fixed at 32M to mimic a constant -# training memory budget across model widths; hidden swept from 1K to 16K covers -# small models through Llama-405B / wide-MoE territory. +# (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory +# budget across model widths; hidden swept from 1K to 16K. _SHAPES = [ (32768, 1024), (16384, 2048), @@ -39,14 +24,12 @@ (4096, 8192), (2048, 16384), ] -_DEFAULT_DTYPES = (torch.bfloat16,) _EPS = 1e-5 def _setup_param(tensor: torch.Tensor) -> torch.Tensor: """Triton's normalization backward writes weight/bias gradients to a - `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`. - Wire up the buffer + zero-flag the kernel expects.""" + `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`.""" tensor.grad_buffer = torch.zeros_like(tensor) tensor.param_grad_is_zero = True return tensor @@ -90,8 +73,7 @@ def _rms_norm_apex_fused(input_, weight): def _param_grad(param: torch.Tensor) -> torch.Tensor: - """Pull the parameter gradient from wherever the kernel wrote it. - Triton writes to `grad_buffer`; autograd writes to `.grad`.""" + """Triton writes to `grad_buffer`; autograd writes to `.grad`.""" return param.grad if param.grad is not None else param.grad_buffer @@ -131,43 +113,34 @@ def _layer_norm_variants() -> list[Variant]: if fused_normalization_available: extras["apex_fused"] = _layer_norm_apex_fused if fast_normalization_available: - # apex_fast only supports widths in _PERSIST_LN_SIZES; all shapes in _SHAPES qualify. 
extras["apex_fast"] = _layer_norm_apex_fast - variants = standard_fwd_bwd_pytorch_variants( + return standard_fwd_bwd_pytorch_variants( _layer_norm_eager, input_keys=("input", "weight", "bias"), grad_input_keys=("input", "weight", "bias"), grad_output_key="grad_output", extra_functions=extras, + triton_fwd=_layer_norm_triton_fwd, + triton_fwd_bwd=_layer_norm_triton_fwd_bwd, ) - if TritonConfig.enabled(): - variants.append( - Variant(name="fast_llm_triton", fwd=_layer_norm_triton_fwd, fwd_bwd=_layer_norm_triton_fwd_bwd) - ) - return variants def _rms_norm_variants() -> list[Variant]: extras: dict = {} if fused_normalization_available: extras["apex_fused"] = _rms_norm_apex_fused - variants = standard_fwd_bwd_pytorch_variants( + return standard_fwd_bwd_pytorch_variants( _rms_norm_eager, input_keys=("input", "weight"), grad_input_keys=("input", "weight"), grad_output_key="grad_output", extra_functions=extras, + triton_fwd=_rms_norm_triton_fwd, + triton_fwd_bwd=_rms_norm_triton_fwd_bwd, ) - if TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=_rms_norm_triton_fwd, fwd_bwd=_rms_norm_triton_fwd_bwd)) - return variants def _layer_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - """Approximate fwd+bwd memory traffic for LayerNorm. - fwd reads input + weight + bias and writes output (also stores inv_var). - bwd reads grad_output, output, weight, bias, inv_var; writes grad_input, - grad_weight, grad_bias. Activation tensors dominate.""" activations = 4 * rows * cols * dtype.itemsize # fwd in/out + bwd grad_in/out parameters = 6 * cols * dtype.itemsize # weight, bias × (read + grad write) twice return activations + parameters @@ -180,22 +153,18 @@ def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: def _layer_norm_flops(rows: int, cols: int) -> int: - """Approximate fwd+bwd FLOPs for LayerNorm. - fwd: mean (1), variance (2), normalize (2), scale+shift (2) ≈ 7 per element. - bwd: ~2x fwd.""" + # fwd ≈ 7 per element (mean, variance, normalize, scale+shift); bwd ≈ 2× fwd. return 21 * rows * cols def _rms_norm_flops(rows: int, cols: int) -> int: - """Same idea as LayerNorm but no mean subtraction or bias.""" return 15 * rows * cols def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( @@ -212,7 +181,3 @@ def benchmarks( run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index 0e62a2b2b..3679597b0 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -1,19 +1,11 @@ -""" -Benchmark pointwise kernels: copy, fill, add. - -These kernels are pure bandwidth-bound: runtime is dominated by reading inputs -and writing outputs, so GB/s and %-of-peak-BW are the headline metrics. The -Triton kernels live in `fast_llm/functional/triton/pointwise.py` and are -documented as being ~2x faster than the PyTorch equivalent on A100. 
-""" +"""Bandwidth-bound pointwise kernels: copy, fill, add.""" import torch from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_variants -# Sizes span from L2-resident to comfortably HBM-bound, in 4× steps so the -# regime transitions (L2 → HBM, mid-HBM → saturated-HBM) are visible. +# 4× steps so L2 → HBM and saturated-HBM regimes are visible. _SIZES_NUMEL = [ 1 << 20, # 1M — 2 MiB bf16 (L2-resident on most GPUs) 1 << 22, # 4M — 8 MiB bf16 (L2 boundary) @@ -21,10 +13,6 @@ 1 << 26, # 64M — 128 MiB bf16 (HBM) 1 << 28, # 256M — 512 MiB bf16 (large HBM, near-saturated) ] -_DEFAULT_DTYPES = (torch.bfloat16,) - - -# --------------------------------------------------------------------------- copy def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor: @@ -37,7 +25,6 @@ def _make_copy_inputs(numel: int, dtype: torch.dtype) -> dict: def _copy_bytes(numel: int, dtype: torch.dtype) -> int: - # Read input + write output. return 2 * numel * dtype.itemsize @@ -48,9 +35,6 @@ def _copy_bytes(numel: int, dtype: torch.dtype) -> int: ) -# --------------------------------------------------------------------------- fill - - def _fill_eager(input_: torch.Tensor, value: float) -> torch.Tensor: return input_.fill_(value) @@ -60,7 +44,6 @@ def _make_fill_inputs(numel: int, dtype: torch.dtype) -> dict: def _fill_bytes(numel: int, dtype: torch.dtype) -> int: - # Write only. return numel * dtype.itemsize @@ -71,9 +54,6 @@ def _fill_bytes(numel: int, dtype: torch.dtype) -> int: ) -# --------------------------------------------------------------------------- add - - def _add_eager(input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor) -> torch.Tensor: return torch.add(input_, other, out=out) @@ -87,12 +67,10 @@ def _make_add_inputs(numel: int, dtype: torch.dtype) -> dict: def _add_bytes(numel: int, dtype: torch.dtype) -> int: - # Read 2 inputs + write 1 output. return 3 * numel * dtype.itemsize def _add_flops(numel: int) -> int: - # One fp add per element. return numel @@ -103,28 +81,13 @@ def _add_flops(numel: int) -> int: ) -# --------------------------------------------------------------------------- entry point - - -def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, - shapes: list[int] | None = None, -) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES +def benchmarks(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SIZES_NUMEL return [ ("pointwise: copy", make_cases("copy", dtypes, shapes, _make_copy_inputs, _copy_bytes), _COPY_VARIANTS), ("pointwise: fill", make_cases("fill", dtypes, shapes, _make_fill_inputs, _fill_bytes), _FILL_VARIANTS), - ( - "pointwise: add", - make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), - _ADD_VARIANTS, - ), + ("pointwise: add", make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), _ADD_VARIANTS), ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index 8d4616097..b91918f0b 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -1,16 +1,5 @@ -""" -Benchmark rotary position embeddings. 
- -The Triton kernel (`triton_rotary_`) operates in-place on (tokens, num_heads, -head_size) tensors, loading pre-computed (cos, sin) frequencies from -(tokens, 2*rotary_dim). The backward is an identical rotation call with -conjugated frequencies — same cost — so only fwd is benchmarked. - -Shapes sweep (tokens, num_heads, head_size) across typical attention configs: -- 32 heads × 128 → 7B/13B models -- 64 heads × 128 → 70B / MoE models -- 8 heads × 128 → GQA key-value heads (Llama 3) -""" +"""Rotary position embeddings. The Triton kernel is in-place; backward is an +identical rotation with conjugated frequencies, so only fwd is benchmarked.""" import torch @@ -26,7 +15,6 @@ (4096, 64, 128), # 70B / MoE, 4K context (4096, 8, 128), # GQA key-value heads, 4K context ] -_DEFAULT_DTYPES = (torch.bfloat16,) def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: @@ -40,9 +28,8 @@ def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torc def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: - """Non-in-place full rotary (rotary_dim = head_size / 2).""" rotary_dim = frequencies.shape[-1] // 2 - freq_re = frequencies[:, :rotary_dim].unsqueeze(1) # (tokens, 1, rotary_dim) + freq_re = frequencies[:, :rotary_dim].unsqueeze(1) freq_im = frequencies[:, rotary_dim:].unsqueeze(1) x_re, x_im = input_.chunk(2, dim=-1) out_re = x_re * freq_re - x_im * freq_im @@ -55,7 +42,6 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: - # Read + write input tensor; frequencies are float32. return 2 * tokens * num_heads * head_size * dtype.itemsize + tokens * head_size * 4 @@ -96,10 +82,9 @@ def _rotary_variants() -> list[Variant]: def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( @@ -111,7 +96,3 @@ def benchmarks( run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index c991aa792..5c9c54f37 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -1,35 +1,14 @@ -""" -Benchmark MoE token dispatch and combine (sparse copy) kernels. - -Two operations are benchmarked separately: - -dispatch (dense → sparse): - Each token is copied to top_k expert slots in the sparse buffer. - copy_dense_to_sparse_autograd handles fwd+bwd (bwd = sparse-to-dense, no scores). - -combine (sparse → dense): - Expert outputs are gathered and weighted by routing scores back to token space. - copy_sparse_to_dense_autograd handles fwd+bwd (bwd = dense-to-sparse + score grad). - -Comparisons: -- pytorch_eager: index-based scatter/gather in compute dtype -- pytorch_compiled / pytorch_compiled_max: torch.compile of the above -- fast_llm_triton: copy_dense_to_sparse_autograd / copy_sparse_to_dense_autograd - -Shapes: (tokens, top_k, num_experts, hidden_size) matching Mixtral-8x7B and fine-grained MoE. -The SparseMap is pre-computed once per case (routing structure, not data). 
-""" +"""MoE token dispatch (dense → sparse) and combine (sparse → dense, weighted by +routing scores) kernels.""" import torch -from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import ( SparseMap, copy_dense_to_sparse_autograd, copy_sparse_to_dense_autograd, get_sparse_map, ) -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden_size) @@ -38,7 +17,6 @@ (4096, 2, 64, 4096), # fine-grained MoE (4096, 2, 8, 8192), # wide hidden ] -_DEFAULT_DTYPES = (torch.bfloat16,) def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: @@ -47,9 +25,8 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: - # Boolean mask shape (num_rows, 1): True for phantom rows (within-expert padding - # and the static tail beyond expert_ends[-1]). Precomputed once per case and - # used with masked_fill_ in output_postprocess — never inside the timed path. + # Boolean mask (num_rows, 1): True for phantom rows (within-expert padding + # and the static tail beyond expert_ends[-1]). Used in output_postprocess only. mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device()) for expert in range(sparse_map.num_experts): pad_begin = int(sparse_map.expert_pad_begins[expert]) @@ -83,9 +60,6 @@ def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, } -# --------------------------------------------------------------------------- dispatch - - def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: out = dense_input.new_zeros(sparse_map.num_rows, dense_input.shape[1]) sparse_rows = sparse_map.sparse_rows.long() @@ -109,28 +83,6 @@ def _dispatch_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict return output -def _dispatch_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _dispatch_pytorch, - input_keys=("dense", "sparse_map"), - grad_input_keys=("dense",), - grad_output_key="backward_grad", - ) - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_dispatch_triton_fwd, - fwd_bwd=_dispatch_triton_fwd_bwd, - output_postprocess=_dispatch_postprocess, - ) - ) - return variants - - -# --------------------------------------------------------------------------- combine - - def _combine_pytorch(sparse_input: torch.Tensor, scores: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: out = sparse_input.new_zeros(sparse_map.num_rows_dense, sparse_input.shape[1]) sparse_rows = sparse_map.sparse_rows.long() @@ -159,66 +111,49 @@ def _combine_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[ return output -def _combine_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _combine_pytorch, - input_keys=("sparse", "scores", "sparse_map"), - grad_input_keys=("sparse", "scores"), - grad_output_key="backward_grad", - ) - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_combine_triton_fwd, - fwd_bwd=_combine_triton_fwd_bwd, - output_postprocess=_combine_postprocess, - ) - ) - return variants - - -# --------------------------------------------------------------------------- bytes - - def _dispatch_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: - # fwd: read dense (tokens×h) + write sparse 
(top_k×tokens×h) - # bwd: read sparse grad + write dense grad → same traffic reversed + # fwd: read dense + write sparse; bwd: same traffic reversed. return 2 * (1 + top_k) * tokens * hidden * dtype.itemsize def _combine_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: sparse_rows = top_k * tokens - # fwd: read sparse (sparse×h) + read scores (tokens×top_k) + write dense (tokens×h) - # bwd: read dense grad + read scores + write sparse grad + write score grad return 2 * (sparse_rows + tokens) * hidden * dtype.itemsize + 4 * tokens * top_k * dtype.itemsize -# --------------------------------------------------------------------------- entry point - - def benchmarks( - dtypes: tuple[torch.dtype, ...] | None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( "sparse_copy: dispatch", make_cases("dispatch", dtypes, shapes, _make_dispatch_inputs, _dispatch_bytes), - _dispatch_variants(), + standard_fwd_bwd_pytorch_variants( + _dispatch_pytorch, + input_keys=("dense", "sparse_map"), + grad_input_keys=("dense",), + grad_output_key="backward_grad", + triton_fwd=_dispatch_triton_fwd, + triton_fwd_bwd=_dispatch_triton_fwd_bwd, + triton_output_postprocess=_dispatch_postprocess, + ), ), ( "sparse_copy: combine", make_cases("combine", dtypes, shapes, _make_combine_inputs, _combine_bytes), - _combine_variants(), + standard_fwd_bwd_pytorch_variants( + _combine_pytorch, + input_keys=("sparse", "scores", "sparse_map"), + grad_input_keys=("sparse", "scores"), + grad_output_key="backward_grad", + triton_fwd=_combine_triton_fwd, + triton_fwd_bwd=_combine_triton_fwd_bwd, + triton_output_postprocess=_combine_postprocess, + ), ), ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 5ee5833c8..2ffdc69f4 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -1,26 +1,10 @@ -""" -Benchmark MoE sparse grouped GEMM kernels. - -Two operations are benchmarked, corresponding to the two linear layers in a MoE FFN: +"""MoE sparse grouped GEMM kernels — the two linear layers in a MoE FFN. -output_sparse (layer 1 / up-proj): +output_sparse (layer 1 / up-proj): out[i, :] = lhs[i, :] @ rhs[:, expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert] - lhs: (sparse_tokens, hidden), rhs: (hidden, ffn_per_expert × num_experts) - Each token's output columns come from its assigned expert's slice of rhs. - OutputSparseLinear.apply handles fwd+bwd. -input_inner_sparse (layer 2 / down-proj): +input_inner_sparse (layer 2 / down-proj): out[i, :] = lhs[i, :] @ rhs[expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert, :] - lhs: (sparse_tokens, ffn_per_expert), rhs: (ffn_per_expert × num_experts, hidden) - Each token's inner dimension comes from its assigned expert's slice of rhs. - InputSparseLinear.apply handles fwd+bwd. - -Comparisons: -- pytorch_loop: loop over experts with torch.mm per expert (the obvious PyTorch approach) -- pytorch_compiled: torch.compile of the loop -- fast_llm_triton: OutputSparseLinear / InputSparseLinear autograd functions - -Shapes: (tokens, top_k, num_experts, hidden, ffn_per_expert) matching MoE FFN configs. 
""" import torch @@ -28,7 +12,6 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden, ffn_per_expert) @@ -37,11 +20,9 @@ (4096, 2, 64, 4096, 1792), # fine-grained MoE: 64 experts, same total capacity (4096, 2, 8, 8192, 28672), # large hidden / wide FFN ] -_DEFAULT_DTYPES = (torch.bfloat16,) -# Triton autotuning warmup only needs to run once per shape. make_inputs is -# called multiple times per case (per variant, per fwd/fwd_bwd/memory pass), -# so cache which shapes have already been warmed up. +# Triton autotuning warmup needs to run only once per shape; make_inputs is +# called many times per case (per variant, per fwd/fwd_bwd/memory pass). _output_sparse_warmed_up: set[tuple] = set() _input_inner_sparse_warmed_up: set[tuple] = set() @@ -52,13 +33,10 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: - # Two regions in the kernel's forward output and grad_lhs are by-design garbage that - # downstream consumers ignore: per-expert padding [pad_begin, expert_end) (where the - # kernel does a matmul on random padding inputs) and phantom rows [expert_ends[-1], - # num_rows) past the last expert (where the kernel early-returns and leaves the output - # buffer uninitialized). The loop reference produces zeros in both regions, so without - # masking those mismatches would dominate rel_rms. grad_rhs already excludes padded - # contributions in both the kernel and reference, so it needs no masking. + # Two regions in the kernel's forward output and grad_lhs are by-design garbage: + # per-expert padding [pad_begin, expert_end) and phantom rows past expert_ends[-1]. + # The loop reference produces zeros there; mask the kernel output to match so + # rel_rms reflects only the real rows. grad_rhs already excludes padding. 
sparse_map = inputs["sparse_map"] pad_begins = sparse_map.expert_pad_begins.tolist() pad_ends = sparse_map.expert_ends.tolist() @@ -123,9 +101,6 @@ def _make_input_inner_sparse_inputs( } -# --------------------------------------------------------------------------- output_sparse - - def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: ffn_per_expert = rhs.shape[1] // sparse_map.num_experts out = lhs.new_zeros(sparse_map.num_rows, ffn_per_expert) @@ -148,30 +123,6 @@ def _output_sparse_triton_fwd_bwd(inputs: dict) -> dict: return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} -def _output_sparse_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _output_sparse_loop, - input_keys=("lhs", "rhs", "sparse_map"), - grad_input_keys=("lhs", "rhs"), - grad_output_key="backward_grad", - eager_name="pytorch_loop", - enable_max_autotune=False, - ) - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_output_sparse_triton_fwd, - fwd_bwd=_output_sparse_triton_fwd_bwd, - output_postprocess=_mask_padded_rows, - ) - ) - return variants - - -# --------------------------------------------------------------------------- input_inner_sparse - - def _input_inner_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: ffn_per_expert = rhs.shape[0] // sparse_map.num_experts out = lhs.new_zeros(sparse_map.num_rows, rhs.shape[1]) @@ -194,36 +145,10 @@ def _input_inner_sparse_triton_fwd_bwd(inputs: dict) -> dict: return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} -def _input_inner_sparse_variants() -> list[Variant]: - variants = standard_fwd_bwd_pytorch_variants( - _input_inner_sparse_loop, - input_keys=("lhs", "rhs", "sparse_map"), - grad_input_keys=("lhs", "rhs"), - grad_output_key="backward_grad", - eager_name="pytorch_loop", - enable_max_autotune=False, - ) - if TritonConfig.enabled(): - variants.append( - Variant( - name="fast_llm_triton", - fwd=_input_inner_sparse_triton_fwd, - fwd_bwd=_input_inner_sparse_triton_fwd_bwd, - output_postprocess=_mask_padded_rows, - ) - ) - return variants - - -# --------------------------------------------------------------------------- bytes / flops - - def _sparse_linear_bytes( tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype ) -> int: - # fwd: read lhs + read rhs_full + write output - # bwd: read grad_output + read rhs_full + write grad_lhs + read lhs + read grad_output + write grad_rhs - # Simplification: 3× lhs traffic + 3× rhs traffic + 2× output traffic + # Approximation: 3× lhs + 3× rhs + 2× output traffic across fwd+bwd. sparse_tokens = tokens * top_k lhs_bytes = sparse_tokens * hidden * dtype.itemsize rhs_bytes = hidden * ffn_per_expert * num_experts * dtype.itemsize @@ -232,18 +157,14 @@ def _sparse_linear_bytes( def _sparse_linear_flops(tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int) -> int: - # fwd + bwd ≈ 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad) + # 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad). return 3 * 2 * tokens * top_k * hidden * ffn_per_expert -# --------------------------------------------------------------------------- entry point - - def benchmarks( - dtypes: tuple[torch.dtype, ...] 
| None = None, + dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: - dtypes = tuple(dtypes) if dtypes else _DEFAULT_DTYPES shapes = shapes if shapes is not None else _SHAPES return [ ( @@ -251,7 +172,17 @@ def benchmarks( make_cases( "output_sparse", dtypes, shapes, _make_output_sparse_inputs, _sparse_linear_bytes, _sparse_linear_flops ), - _output_sparse_variants(), + standard_fwd_bwd_pytorch_variants( + _output_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + triton_fwd=_output_sparse_triton_fwd, + triton_fwd_bwd=_output_sparse_triton_fwd_bwd, + triton_output_postprocess=_mask_padded_rows, + ), ), ( "sparse_linear: input_inner_sparse (layer 2 / down-proj)", @@ -263,13 +194,19 @@ def benchmarks( _sparse_linear_bytes, _sparse_linear_flops, ), - _input_inner_sparse_variants(), + standard_fwd_bwd_pytorch_variants( + _input_inner_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + triton_fwd=_input_inner_sparse_triton_fwd, + triton_fwd_bwd=_input_inner_sparse_triton_fwd_bwd, + triton_output_postprocess=_mask_padded_rows, + ), ), ] run = bench_main(benchmarks) - - -if __name__ == "__main__": - run() diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py index 2284c2497..63495fcc6 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/utils.py @@ -1,8 +1,6 @@ -""" -Convenience helpers for writing kernel benchmark files. Reduces the boilerplate -of building cases and variants so each `bench_*.py` can stay focused on -kernel-specific logic (input construction, expected_bytes/flops, special variants). -""" +"""Helpers shared by the `bench_*.py` modules — case construction, variant +builders for the canonical fp32_reference + pytorch_eager + pytorch_compiled +chunk, and the `run = bench_main(benchmarks)` glue.""" from collections.abc import Callable from functools import partial @@ -11,13 +9,12 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from tools.benchmark.runner import Case, Inputs, Variant, run_benchmark +from tools.benchmark.runner import Case, Inputs, Variant, VariantFn, run_benchmark -# --------------------------------------------------------------------------- formatting +DEFAULT_DTYPES: tuple[torch.dtype, ...] = (torch.bfloat16,) def format_size(n: int) -> str: - """Format an int with the largest binary prefix that divides it exactly: 1048576 → '1 Mi'.""" for unit, factor in (("Gi", 1 << 30), ("Mi", 1 << 20), ("Ki", 1 << 10)): if n >= factor and n % factor == 0: return f"{n // factor} {unit}" @@ -25,25 +22,18 @@ def format_size(n: int) -> str: def format_shape(shape: tuple[int, ...]) -> str: - """Format a shape tuple with human-readable sizes per dim: (16777216,) → '(16 Mi,)'.""" joined = ", ".join(format_size(n) for n in shape) return f"({joined},)" if len(shape) == 1 else f"({joined})" def case_name(kernel: str, shape: tuple[int, ...], dtype: torch.dtype) -> str: - """Build the standard case header: `[copy] (16 Mi,) bf16`.""" return f"[{kernel}] {format_shape(shape)} {DataType.from_torch(dtype).short}" def device() -> str: - """The device benchmarks should target. 
Falls back to CPU when CUDA is missing - so non-Triton variants can still run for local smoke testing.""" return "cuda" if torch.cuda.is_available() else "cpu" -# --------------------------------------------------------------------------- cases - - def make_cases( kernel_name: str, dtypes: tuple[torch.dtype, ...], @@ -52,11 +42,8 @@ def make_cases( bytes_fn: Callable | None = None, flops_fn: Callable | None = None, ) -> list[Case]: - """Build the standard `Case` list as the cross-product of `dtypes × shapes`. - - Each `shape` may be a tuple or a scalar; tuples are unpacked positionally - into `make_inputs(*shape, dtype)`, `bytes_fn(*shape, dtype)`, and `flops_fn(*shape)`. - """ + """Cross-product of `dtypes × shapes`. Each shape may be a tuple or scalar; + tuples are unpacked positionally into `make_inputs`, `bytes_fn`, `flops_fn`.""" cases = [] for dtype in dtypes: for shape in shapes: @@ -73,13 +60,8 @@ def make_cases( return cases -# --------------------------------------------------------------------------- run/main - - def bench_main(benchmarks_fn: Callable) -> Callable: - """Build the standard `run()` callable that loops `benchmarks_fn(dtypes, shapes)` - through `run_benchmark`. Each `bench_*.py` exports `run = bench_main(benchmarks)` - so the package CLI in `__main__.py` can dispatch to it.""" + """Build the `run()` callable each `bench_*.py` exports for `__main__.py` to dispatch to.""" def run( verbose: bool = False, @@ -89,7 +71,7 @@ def run( rep_ms: float = 100.0, min_reps: int = 5, ) -> None: - for name, cases, variants in benchmarks_fn(dtypes, shapes): + for name, cases, variants in benchmarks_fn(dtypes or DEFAULT_DTYPES, shapes): run_benchmark( name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps ) @@ -97,27 +79,13 @@ def run( return run -# --------------------------------------------------------------------------- variant builders - - def standard_fwd_variants( eager_function: Callable, triton_function: Callable | None, unpack: Callable[[Inputs], tuple], ) -> list[Variant]: - """Build the canonical 5-variant set for a forward-only kernel. - - Generates: fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, - and (if `TritonConfig.enabled()`) fast_llm_triton. - - `eager_function` is the plain PyTorch implementation taking positional tensor args. - `triton_function` is the Fast-LLM Triton wrapper; pass `None` if the kernel has no - Triton variant. Both are invoked with `unpack(inputs)` unpacked positionally; - `triton_function` is called with an extra `use_triton=True` kwarg. - - The fp32 reference upcasts every floating-point tensor in the unpacked - arguments to fp32 (non-tensor / non-float arguments are passed through). - """ + """Forward-only variant set: fp32_reference, pytorch_eager, pytorch_compiled, + pytorch_compiled_max, and (if `TritonConfig.enabled()`) fast_llm_triton.""" def fp32_unpack(inputs: Inputs) -> tuple: return tuple( @@ -190,24 +158,25 @@ def standard_fwd_bwd_pytorch_variants( output_key: str = "output", reset_inputs: Callable[[Inputs], None] | None = None, extra_functions: dict[str, Callable] | None = None, + triton_fwd: VariantFn | None = None, + triton_fwd_bwd: VariantFn | None = None, + triton_name: str = "fast_llm_triton", + triton_output_postprocess: Callable[[dict, Inputs], dict] | None = None, eager_name: str = "pytorch_eager", enable_max_autotune: bool = True, ) -> list[Variant]: - """Build the canonical pytorch variant chunk for a forward-backward kernel. 
+ """Forward+backward variant set for kernels driven by a dict-style input. - Generates: fp32_reference, , pytorch_compiled, [pytorch_compiled_max,] - plus any callables in `extra_functions` (e.g. apex implementations) appended - at the end with their dict-key as the variant name. + Generates fp32_reference, , pytorch_compiled, [pytorch_compiled_max,] + plus any callables in `extra_functions` (variant name = dict key). When + `triton_fwd`/`triton_fwd_bwd` are given and `TritonConfig.enabled()`, a + `` variant is appended; the callables receive the raw inputs dict. `eager_function(*[inputs[key] for key in input_keys])` computes the forward output. - `grad_input_keys` lists input dict keys whose `.grad` is collected and returned - as `grad_` in the output dict. `grad_output_key` is the input dict key for - `output.backward(grad_output)`; pass `None` for scalar-loss kernels (uses bare - `output.backward()`). `output_key` is the output dict key for the forward result. - - The fp32 reference upcasts every floating-point tensor in the input dict to - fp32, re-attaching `requires_grad=True` for `grad_input_keys`. Non-float and - non-tensor entries (e.g. ints, enums, SparseMap) are passed through. + `grad_input_keys` lists input keys whose `.grad` is collected as `grad_`. + `grad_output_key=None` triggers a scalar-loss `output.backward()`. + The fp32 reference upcasts every floating-point tensor in the input dict, re-attaching + `requires_grad=True` on `grad_input_keys`. """ fwd_kwargs = {"input_keys": input_keys, "output_key": output_key} fwd_bwd_kwargs = { @@ -245,4 +214,14 @@ def variant(name: str, function: Callable) -> Variant: variants.append(variant("pytorch_compiled_max", compiled_max)) for name, function in (extra_functions or {}).items(): variants.append(variant(name, function)) + if (triton_fwd is not None or triton_fwd_bwd is not None) and TritonConfig.enabled(): + variants.append( + Variant( + name=triton_name, + fwd=triton_fwd, + fwd_bwd=triton_fwd_bwd, + reset_inputs=reset_inputs, + output_postprocess=triton_output_postprocess, + ) + ) return variants From 46d727758c47c49859306728dec55b16d419c87d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 07:32:36 -0400 Subject: [PATCH 35/41] Trim docstrings, inline normalization variant wrappers Drop module docstrings that restate the kernel name (pointwise, mlp_activation, sparse_copy, sparse_linear, grpo_loss, entropy_loss) and per-helper docstrings that don't carry WHY content. Keep WHY comments justifying magic numbers in bytes/flops formulas, the rotary "only fwd benchmarked" rationale, and the normalization grad_buffer convention. Inline _layer_norm_variants and _rms_norm_variants by computing the apex extras dicts at module level. Privatize format helpers in utils.py. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/benchmark/bench_entropy_loss.py | 34 +++----- tools/benchmark/bench_grpo_loss.py | 14 +--- tools/benchmark/bench_mlp_activation.py | 14 ++-- tools/benchmark/bench_normalization.py | 102 +++++++++++------------- tools/benchmark/bench_pointwise.py | 8 +- tools/benchmark/bench_rotary.py | 3 +- tools/benchmark/bench_sparse_copy.py | 11 +-- tools/benchmark/bench_sparse_linear.py | 9 --- tools/benchmark/utils.py | 44 ++++------ 9 files changed, 91 insertions(+), 148 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index ba0de7a60..4bcb4095e 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -1,7 +1,3 @@ -"""Entropy loss kernels: cross_entropy (labels and logits target formats), -reverse_kl, and z_loss. All Triton kernels fuse fwd+bwd into a single -logits-tensor pass; `grad_output=1.0` triggers gradient computation.""" - import torch import torch.nn.functional as F @@ -12,7 +8,7 @@ # (tokens, vocab_size) _SHAPES = [ - (4096, 32768), # 7B / Llama-2 vocab + (4096, 32768), # Llama-2 vocab (4096, 65536), (4096, 131072), # Llama-3 vocab ] @@ -54,6 +50,7 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list: + """Variants for the 3 entropy_loss kernels that share `triton_entropy_loss_forward_backward`.""" target_key = input_keys[1] triton_kwargs = triton_kwargs or {} @@ -90,15 +87,17 @@ def _z_loss_triton_fwd_bwd(inputs: dict) -> dict: return {"loss": loss, "grad_logits": grad_logits} -def _label_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: +def _label_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: + # 2× logits + small labels traffic. return 2 * tokens * vocab * dtype.itemsize + tokens * 4 -def _dist_loss_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: +def _dist_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: + # 2× logits + 1× target_logits. return 3 * tokens * vocab * dtype.itemsize -def _entropy_loss_flops(tokens: int, vocab: int) -> int: +def _flops(tokens: int, vocab: int) -> int: # fwd ≈ 3*vocab per token, bwd ≈ vocab. 
return 4 * tokens * vocab @@ -111,21 +110,12 @@ def benchmarks( return [ ( "entropy_loss: cross_entropy (labels)", - make_cases( - "cross_entropy_labels", dtypes, shapes, _make_label_inputs, _label_loss_bytes, _entropy_loss_flops - ), + make_cases("cross_entropy_labels", dtypes, shapes, _make_label_inputs, _label_bytes, _flops), _entropy_variants(_ce_labels_eager, input_keys=("logits", "labels")), ), ( "entropy_loss: cross_entropy (logits)", - make_cases( - "cross_entropy_logits", - dtypes, - shapes, - _make_distribution_inputs, - _dist_loss_bytes, - _entropy_loss_flops, - ), + make_cases("cross_entropy_logits", dtypes, shapes, _make_distribution_inputs, _dist_bytes, _flops), _entropy_variants( _ce_dist_eager, input_keys=("logits", "target_logits"), @@ -137,9 +127,7 @@ def benchmarks( ), ( "entropy_loss: reverse_kl (logits)", - make_cases( - "reverse_kl_logits", dtypes, shapes, _make_distribution_inputs, _dist_loss_bytes, _entropy_loss_flops - ), + make_cases("reverse_kl_logits", dtypes, shapes, _make_distribution_inputs, _dist_bytes, _flops), _entropy_variants( _reverse_kl_eager, input_keys=("logits", "target_logits"), @@ -151,7 +139,7 @@ def benchmarks( ), ( "entropy_loss: z_loss", - make_cases("z_loss", dtypes, shapes, _make_label_inputs, _label_loss_bytes, _entropy_loss_flops), + make_cases("z_loss", dtypes, shapes, _make_label_inputs, _label_bytes, _flops), standard_fwd_bwd_pytorch_variants( _z_loss_eager, input_keys=("logits",), diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 65c7376a9..7334962d1 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -1,7 +1,3 @@ -"""Fused GRPO (Group Relative Policy Optimization) loss kernel. The Triton -kernel fuses softmax, log-prob extraction, ratio + clip, and backward into a -single pass over logits.""" - import torch from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward @@ -72,15 +68,13 @@ def _triton_fwd_bwd(inputs: dict) -> dict: def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # fwd: read logits + bwd: read logits + write grad_logits - logit_traffic = 3 * tokens * vocab * dtype.itemsize - # labels (int64), advantages (fp32), old_log_probs (fp32) - scalar_traffic = tokens * (8 + 4 + 4) - return logit_traffic + scalar_traffic + # 3× logits traffic (read fwd, read+write bwd) + per-token scalars: + # labels (int64 = 8B), advantages (fp32 = 4B), old_log_probs (fp32 = 4B). + return 3 * tokens * vocab * dtype.itemsize + tokens * 16 def _grpo_flops(tokens: int, vocab: int) -> int: - # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element + # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element. return 14 * tokens * vocab diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 52bac15d9..1b94351ad 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -1,6 +1,3 @@ -"""Fused MLP activation kernel. For gated SiLU the fwd input is (tokens, 2*ffn_dim) -— [gate_proj, up_proj] concatenated — and the output is (tokens, ffn_dim).""" - import torch from fast_llm.functional.config import ActivationType @@ -13,12 +10,12 @@ # (tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated. 
_SHAPES = [ - (8192, 4096), # 7B/13B models - (8192, 8192), # large - (8192, 14336), # 70B models + (8192, 4096), # 7B/13B + (8192, 8192), + (8192, 14336), # 70B (4096, 28672), # MoE up-projection ] -_ACTIVATION = ActivationType.silu # standard for Llama-style gated models +_ACTIVATION = ActivationType.silu def _make_mlp_inputs(tokens: int, ffn_dim: int, dtype: torch.dtype) -> dict: @@ -42,8 +39,7 @@ def _triton_fwd_bwd(inputs: dict) -> dict: def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: - # fwd: read input (2*ffn_dim) + write output (ffn_dim). - # bwd: read grad_output (ffn_dim) + read input (2*ffn_dim) + write grad_input (2*ffn_dim). + # fwd: 3*ffn_dim traffic; bwd: 5*ffn_dim. 8 elements/token total. return 8 * tokens * ffn_dim * dtype.itemsize diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index fd31a3c3e..a75d579ef 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -1,6 +1,5 @@ -"""LayerNorm and RMSNorm. The Triton implementation handles both via the -`bias` argument (LayerNorm when given, RMSNorm when None) and writes parameter -gradients to Fast-LLM's `grad_buffer` attribute rather than autograd's `.grad`.""" +"""LayerNorm and RMSNorm. The Triton kernel writes parameter gradients to a +`grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`.""" import torch @@ -12,7 +11,6 @@ fast_normalization_available, fused_normalization_available, ) -from tools.benchmark.runner import Variant from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants # (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory @@ -28,8 +26,6 @@ def _setup_param(tensor: torch.Tensor) -> torch.Tensor: - """Triton's normalization backward writes weight/bias gradients to a - `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`.""" tensor.grad_buffer = torch.zeros_like(tensor) tensor.param_grad_is_zero = True return tensor @@ -60,20 +56,7 @@ def _rms_norm_eager(input_, weight): return torch.rms_norm(input_, weight.shape, weight, _EPS) -def _layer_norm_apex_fused(input_, weight, bias): - return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) - - -def _layer_norm_apex_fast(input_, weight, bias): - return FastLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) - - -def _rms_norm_apex_fused(input_, weight): - return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS) - - def _param_grad(param: torch.Tensor) -> torch.Tensor: - """Triton writes to `grad_buffer`; autograd writes to `.grad`.""" return param.grad if param.grad is not None else param.grad_buffer @@ -108,48 +91,36 @@ def _rms_norm_triton_fwd_bwd(inputs: dict) -> dict: } -def _layer_norm_variants() -> list[Variant]: - extras: dict = {} - if fused_normalization_available: - extras["apex_fused"] = _layer_norm_apex_fused - if fast_normalization_available: - extras["apex_fast"] = _layer_norm_apex_fast - return standard_fwd_bwd_pytorch_variants( - _layer_norm_eager, - input_keys=("input", "weight", "bias"), - grad_input_keys=("input", "weight", "bias"), - grad_output_key="grad_output", - extra_functions=extras, - triton_fwd=_layer_norm_triton_fwd, - triton_fwd_bwd=_layer_norm_triton_fwd_bwd, - ) - - -def _rms_norm_variants() -> list[Variant]: - extras: dict = {} - if fused_normalization_available: - extras["apex_fused"] = _rms_norm_apex_fused - return standard_fwd_bwd_pytorch_variants( - _rms_norm_eager, - 
input_keys=("input", "weight"), - grad_input_keys=("input", "weight"), - grad_output_key="grad_output", - extra_functions=extras, - triton_fwd=_rms_norm_triton_fwd, - triton_fwd_bwd=_rms_norm_triton_fwd_bwd, - ) +def _layer_norm_apex_fused(input_, weight, bias): + return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) + + +def _layer_norm_apex_fast(input_, weight, bias): + return FastLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) + + +def _rms_norm_apex_fused(input_, weight): + return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS) + + +_LAYER_NORM_EXTRAS: dict = {} +if fused_normalization_available: + _LAYER_NORM_EXTRAS["apex_fused"] = _layer_norm_apex_fused +if fast_normalization_available: + _LAYER_NORM_EXTRAS["apex_fast"] = _layer_norm_apex_fast + +_RMS_NORM_EXTRAS: dict = {} +if fused_normalization_available: + _RMS_NORM_EXTRAS["apex_fused"] = _rms_norm_apex_fused def _layer_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - activations = 4 * rows * cols * dtype.itemsize # fwd in/out + bwd grad_in/out - parameters = 6 * cols * dtype.itemsize # weight, bias × (read + grad write) twice - return activations + parameters + # fwd in/out + bwd grad_in/out (4× activations) + weight & bias × (read + grad write). + return 4 * rows * cols * dtype.itemsize + 6 * cols * dtype.itemsize def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - activations = 4 * rows * cols * dtype.itemsize - parameters = 3 * cols * dtype.itemsize - return activations + parameters + return 4 * rows * cols * dtype.itemsize + 3 * cols * dtype.itemsize def _layer_norm_flops(rows: int, cols: int) -> int: @@ -158,6 +129,7 @@ def _layer_norm_flops(rows: int, cols: int) -> int: def _rms_norm_flops(rows: int, cols: int) -> int: + # No mean subtraction or bias. 
return 15 * rows * cols @@ -170,12 +142,28 @@ def benchmarks( ( "normalization: layer_norm", make_cases("layer_norm", dtypes, shapes, _make_layer_norm_inputs, _layer_norm_bytes, _layer_norm_flops), - _layer_norm_variants(), + standard_fwd_bwd_pytorch_variants( + _layer_norm_eager, + input_keys=("input", "weight", "bias"), + grad_input_keys=("input", "weight", "bias"), + grad_output_key="grad_output", + extra_functions=_LAYER_NORM_EXTRAS, + triton_fwd=_layer_norm_triton_fwd, + triton_fwd_bwd=_layer_norm_triton_fwd_bwd, + ), ), ( "normalization: rms_norm", make_cases("rms_norm", dtypes, shapes, _make_rms_norm_inputs, _rms_norm_bytes, _rms_norm_flops), - _rms_norm_variants(), + standard_fwd_bwd_pytorch_variants( + _rms_norm_eager, + input_keys=("input", "weight"), + grad_input_keys=("input", "weight"), + grad_output_key="grad_output", + extra_functions=_RMS_NORM_EXTRAS, + triton_fwd=_rms_norm_triton_fwd, + triton_fwd_bwd=_rms_norm_triton_fwd_bwd, + ), ), ] diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index 3679597b0..bc5c85468 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -1,5 +1,3 @@ -"""Bandwidth-bound pointwise kernels: copy, fill, add.""" - import torch from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill @@ -86,7 +84,11 @@ def benchmarks(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) return [ ("pointwise: copy", make_cases("copy", dtypes, shapes, _make_copy_inputs, _copy_bytes), _COPY_VARIANTS), ("pointwise: fill", make_cases("fill", dtypes, shapes, _make_fill_inputs, _fill_bytes), _FILL_VARIANTS), - ("pointwise: add", make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), _ADD_VARIANTS), + ( + "pointwise: add", + make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), + _ADD_VARIANTS, + ), ] diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index b91918f0b..fe9f0ff52 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -22,7 +22,7 @@ def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torc input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()) return { "input_": input_, - "work": input_.clone(), # pre-allocated work buffer for in-place variants + "work": input_.clone(), "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), } @@ -42,6 +42,7 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: + # frequencies are float32, hence the extra 4 bytes per token×head_size. return 2 * tokens * num_heads * head_size * dtype.itemsize + tokens * head_size * 4 diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 5c9c54f37..513d71ec8 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -1,6 +1,3 @@ -"""MoE token dispatch (dense → sparse) and combine (sparse → dense, weighted by -routing scores) kernels.""" - import torch from fast_llm.functional.triton.sparse_copy import ( @@ -25,8 +22,8 @@ def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: - # Boolean mask (num_rows, 1): True for phantom rows (within-expert padding - # and the static tail beyond expert_ends[-1]). 
Used in output_postprocess only. + # True for within-expert padding rows and the static tail past expert_ends[-1]; + # used only in output_postprocess, never in the timed path. mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device()) for expert in range(sparse_map.num_experts): pad_begin = int(sparse_map.expert_pad_begins[expert]) @@ -117,8 +114,8 @@ def _dispatch_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtyp def _combine_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: - sparse_rows = top_k * tokens - return 2 * (sparse_rows + tokens) * hidden * dtype.itemsize + 4 * tokens * top_k * dtype.itemsize + # 2× (sparse + dense) hidden traffic + scores read/write. + return 2 * (top_k + 1) * tokens * hidden * dtype.itemsize + 4 * tokens * top_k * dtype.itemsize def benchmarks( diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 2ffdc69f4..2836c7e6d 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -1,12 +1,3 @@ -"""MoE sparse grouped GEMM kernels — the two linear layers in a MoE FFN. - -output_sparse (layer 1 / up-proj): - out[i, :] = lhs[i, :] @ rhs[:, expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert] - -input_inner_sparse (layer 2 / down-proj): - out[i, :] = lhs[i, :] @ rhs[expert(i)*ffn_per_expert : (expert(i)+1)*ffn_per_expert, :] -""" - import torch from fast_llm.functional.config import TritonConfig diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py index 63495fcc6..23ebf1ab2 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/utils.py @@ -1,7 +1,3 @@ -"""Helpers shared by the `bench_*.py` modules — case construction, variant -builders for the canonical fp32_reference + pytorch_eager + pytorch_compiled -chunk, and the `run = bench_main(benchmarks)` glue.""" - from collections.abc import Callable from functools import partial @@ -14,20 +10,17 @@ DEFAULT_DTYPES: tuple[torch.dtype, ...] = (torch.bfloat16,) -def format_size(n: int) -> str: +def _format_size(n: int) -> str: for unit, factor in (("Gi", 1 << 30), ("Mi", 1 << 20), ("Ki", 1 << 10)): if n >= factor and n % factor == 0: return f"{n // factor} {unit}" return str(n) -def format_shape(shape: tuple[int, ...]) -> str: - joined = ", ".join(format_size(n) for n in shape) - return f"({joined},)" if len(shape) == 1 else f"({joined})" - - -def case_name(kernel: str, shape: tuple[int, ...], dtype: torch.dtype) -> str: - return f"[{kernel}] {format_shape(shape)} {DataType.from_torch(dtype).short}" +def _case_name(kernel: str, shape: tuple[int, ...], dtype: torch.dtype) -> str: + joined = ", ".join(_format_size(n) for n in shape) + shape_repr = f"({joined},)" if len(shape) == 1 else f"({joined})" + return f"[{kernel}] {shape_repr} {DataType.from_torch(dtype).short}" def device() -> str: @@ -50,7 +43,7 @@ def make_cases( shape_tuple = shape if isinstance(shape, tuple) else (shape,) cases.append( Case( - name=case_name(kernel_name, shape_tuple, dtype), + name=_case_name(kernel_name, shape_tuple, dtype), make_inputs=partial(make_inputs, *shape_tuple, dtype), expected_bytes=bytes_fn(*shape_tuple, dtype) if bytes_fn else None, expected_flops=flops_fn(*shape_tuple) if flops_fn else None, @@ -61,8 +54,6 @@ def make_cases( def bench_main(benchmarks_fn: Callable) -> Callable: - """Build the `run()` callable each `bench_*.py` exports for `__main__.py` to dispatch to.""" - def run( verbose: bool = False, dtypes: tuple[torch.dtype, ...] 
| None = None,
@@ -84,8 +75,8 @@ def standard_fwd_variants(
     triton_function: Callable | None,
     unpack: Callable[[Inputs], tuple],
 ) -> list[Variant]:
-    """Forward-only variant set: fp32_reference, pytorch_eager, pytorch_compiled,
-    pytorch_compiled_max, and (if `TritonConfig.enabled()`) fast_llm_triton."""
+    """fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max,
+    and (if `TritonConfig.enabled()`) fast_llm_triton."""
 
     def fp32_unpack(inputs: Inputs) -> tuple:
         return tuple(
@@ -165,18 +156,13 @@ def standard_fwd_bwd_pytorch_variants(
     eager_name: str = "pytorch_eager",
     enable_max_autotune: bool = True,
 ) -> list[Variant]:
-    """Forward+backward variant set for kernels driven by a dict-style input.
-
-    Generates fp32_reference, `eager_name`, pytorch_compiled, [pytorch_compiled_max,]
-    plus any callables in `extra_functions` (variant name = dict key). When
-    `triton_fwd`/`triton_fwd_bwd` are given and `TritonConfig.enabled()`, a
-    `fast_llm_triton` variant is appended; the callables receive the raw inputs dict.
-
-    `eager_function(*[inputs[key] for key in input_keys])` computes the forward output.
-    `grad_input_keys` lists input keys whose `.grad` is collected as `grad_<key>`.
-    `grad_output_key=None` triggers a scalar-loss `output.backward()`.
-    The fp32 reference upcasts every floating-point tensor in the input dict, re-attaching
-    `requires_grad=True` on `grad_input_keys`.
+    """fp32_reference + `eager_name` + pytorch_compiled + [pytorch_compiled_max]
+    + `extra_functions` + optional `fast_llm_triton` (gated on `TritonConfig.enabled()`).
+
+    Pytorch variants call `eager_function(*[inputs[k] for k in input_keys])`. For
+    fwd_bwd, `grad_input_keys` are read out as `grad_<key>` and `grad_output_key=None`
+    triggers a scalar `output.backward()`. The fp32 reference upcasts every
+    floating-point input, re-attaching `requires_grad=True` on `grad_input_keys`.
     """
     fwd_kwargs = {"input_keys": input_keys, "output_key": output_key}
     fwd_bwd_kwargs = {

From 4fb55f8acb573d088a7e6e9d2524ef95179d9375 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 1 May 2026 08:43:32 -0400
Subject: [PATCH 36/41] Restructure Case and Variant into class hierarchies

Each benchmark now defines a Case subclass that holds its shape
parameters and exposes name/expected_bytes/expected_flops/compute_dtype
as properties, co-locating per-kernel logic instead of scattering it
across module-level helpers.

Variant gains a PytorchVariant/Fp32ReferenceVariant pair that absorbs
the parameterized pytorch dispatch pattern; triton variants stay
flat-data Variant instances with module-level fwd/fwd_bwd functions,
since they're stateless and singleton per kernel.

make_cases and its name-formatting helpers are gone; each bench file
constructs cases via a list comprehension. device() is replaced by a
device argument on Case.make_inputs, supplied by the runner.
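As an illustration only (a condensed sketch, not the literal code in
this patch — the real pointwise benchmark factors the shared fields into
a _PointwiseCase base and also sets compute_dtype), a per-kernel case
now reduces to a dataclass with property overrides plus a
make_inputs(device) method, using Case/Inputs from
tools.benchmark.runner and dtype_short from tools.benchmark.utils:

    @dataclasses.dataclass
    class CopyCase(Case):
        numel: int
        dtype: torch.dtype

        @property
        def name(self) -> str:
            return f"({self.numel},) {dtype_short(self.dtype)}"

        @property
        def expected_bytes(self) -> int:
            # copy reads the input once and writes the output once.
            return 2 * self.numel * self.dtype.itemsize

        def make_inputs(self, device: str) -> Inputs:
            input_ = torch.randn(self.numel, dtype=self.dtype, device=device)
            return {"input_": input_, "out": torch.empty_like(input_)}

    cases = [CopyCase(numel=n, dtype=d) for d in dtypes for n in shapes]

The runner supplies the device and calls make_inputs once per variant
per mode after a seed reset, so every variant still sees identical
inputs.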
Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/benchmark/bench_entropy_loss.py | 114 +++++----- tools/benchmark/bench_grpo_loss.py | 77 ++++--- tools/benchmark/bench_mlp_activation.py | 76 ++++--- tools/benchmark/bench_normalization.py | 134 +++++++----- tools/benchmark/bench_pointwise.py | 105 +++++---- tools/benchmark/bench_rotary.py | 61 ++++-- tools/benchmark/bench_sparse_copy.py | 160 +++++++++----- tools/benchmark/bench_sparse_linear.py | 230 +++++++++++--------- tools/benchmark/runner.py | 31 +-- tools/benchmark/utils.py | 270 +++++++++++------------- 10 files changed, 719 insertions(+), 539 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index 4bcb4095e..bc14e8bf9 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -1,10 +1,13 @@ +import dataclasses + import torch import torch.nn.functional as F -from fast_llm.functional.config import EntropyLossType, TargetFormat +from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (tokens, vocab_size) _SHAPES = [ @@ -14,18 +17,50 @@ ] -def _make_label_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: - return { - "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), - "labels": torch.randint(0, vocab, (tokens,), dtype=torch.long, device=device()), - } +@dataclasses.dataclass +class _EntropyCase(Case): + tokens: int + vocab: int + dtype: torch.dtype + + @property + def name(self) -> str: + return f"({self.tokens}, {self.vocab}) {dtype_short(self.dtype)}" + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_flops(self) -> int: + # fwd ≈ 3*vocab per token, bwd ≈ vocab. + return 4 * self.tokens * self.vocab + + +class EntropyLabelCase(_EntropyCase): + @property + def expected_bytes(self) -> int: + # 2× logits + small labels traffic. + return 2 * self.tokens * self.vocab * self.dtype.itemsize + self.tokens * 4 + def make_inputs(self, device: str) -> Inputs: + return { + "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), + "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device), + } -def _make_distribution_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: - return { - "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), - "target_logits": torch.randn(tokens, vocab, dtype=dtype, device=device()), - } + +class EntropyDistCase(_EntropyCase): + @property + def expected_bytes(self) -> int: + # 2× logits + 1× target_logits. 
+ return 3 * self.tokens * self.vocab * self.dtype.itemsize + + def make_inputs(self, device: str) -> Inputs: + return { + "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), + "target_logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device), + } def _reset_logits_grad(inputs: dict) -> None: @@ -49,7 +84,7 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: return (log_z * log_z).mean() -def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list: +def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list[Variant]: """Variants for the 3 entropy_loss kernels that share `triton_entropy_loss_forward_backward`.""" target_key = input_keys[1] triton_kwargs = triton_kwargs or {} @@ -66,15 +101,16 @@ def triton_fwd_bwd(inputs: dict) -> dict: ) return {"loss": loss, "grad_logits": grad_logits} - return standard_fwd_bwd_pytorch_variants( + variants = standard_fwd_bwd_pytorch_variants( eager_function, input_keys=input_keys, grad_input_keys=("logits",), output_key="loss", reset_inputs=_reset_logits_grad, - triton_fwd=triton_fwd, - triton_fwd_bwd=triton_fwd_bwd, ) + if TritonConfig.enabled(): + variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) + return variants def _z_loss_triton_fwd(inputs: dict) -> dict: @@ -87,35 +123,31 @@ def _z_loss_triton_fwd_bwd(inputs: dict) -> dict: return {"loss": loss, "grad_logits": grad_logits} -def _label_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # 2× logits + small labels traffic. - return 2 * tokens * vocab * dtype.itemsize + tokens * 4 - - -def _dist_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # 2× logits + 1× target_logits. - return 3 * tokens * vocab * dtype.itemsize - - -def _flops(tokens: int, vocab: int) -> int: - # fwd ≈ 3*vocab per token, bwd ≈ vocab. 
- return 4 * tokens * vocab - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + label_cases = [EntropyLabelCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes] + dist_cases = [EntropyDistCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes] + z_loss_variants = standard_fwd_bwd_pytorch_variants( + _z_loss_eager, + input_keys=("logits",), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + ) + if TritonConfig.enabled(): + z_loss_variants.append(Variant(name="fast_llm_triton", fwd=_z_loss_triton_fwd, fwd_bwd=_z_loss_triton_fwd_bwd)) return [ ( "entropy_loss: cross_entropy (labels)", - make_cases("cross_entropy_labels", dtypes, shapes, _make_label_inputs, _label_bytes, _flops), + label_cases, _entropy_variants(_ce_labels_eager, input_keys=("logits", "labels")), ), ( "entropy_loss: cross_entropy (logits)", - make_cases("cross_entropy_logits", dtypes, shapes, _make_distribution_inputs, _dist_bytes, _flops), + dist_cases, _entropy_variants( _ce_dist_eager, input_keys=("logits", "target_logits"), @@ -127,7 +159,7 @@ def benchmarks( ), ( "entropy_loss: reverse_kl (logits)", - make_cases("reverse_kl_logits", dtypes, shapes, _make_distribution_inputs, _dist_bytes, _flops), + dist_cases, _entropy_variants( _reverse_kl_eager, input_keys=("logits", "target_logits"), @@ -137,19 +169,7 @@ def benchmarks( }, ), ), - ( - "entropy_loss: z_loss", - make_cases("z_loss", dtypes, shapes, _make_label_inputs, _label_bytes, _flops), - standard_fwd_bwd_pytorch_variants( - _z_loss_eager, - input_keys=("logits",), - grad_input_keys=("logits",), - output_key="loss", - reset_inputs=_reset_logits_grad, - triton_fwd=_z_loss_triton_fwd, - triton_fwd_bwd=_z_loss_triton_fwd_bwd, - ), - ), + ("entropy_loss: z_loss", label_cases, z_loss_variants), ] diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 7334962d1..4b7353d12 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -1,7 +1,11 @@ +import dataclasses + import torch +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants _SHAPES = [ (4096, 32768), @@ -12,13 +16,38 @@ _EPSILON_HIGH = 0.2 -def _make_grpo_inputs(tokens: int, vocab: int, dtype: torch.dtype) -> dict: - return { - "logits": torch.randn(tokens, vocab, dtype=dtype, device=device(), requires_grad=True), - "labels": torch.randint(0, vocab, (tokens,), dtype=torch.long, device=device()), - "advantages": torch.randn(tokens, dtype=torch.float32, device=device()), - "old_log_probs": torch.randn(tokens, dtype=torch.float32, device=device()) - 5.0, - } +@dataclasses.dataclass +class GrpoLossCase(Case): + tokens: int + vocab: int + dtype: torch.dtype + + @property + def name(self) -> str: + return f"({self.tokens}, {self.vocab}) {dtype_short(self.dtype)}" + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_bytes(self) -> int: + # 3× logits traffic (read fwd, read+write bwd) + per-token scalars: + # labels (int64 = 8B), advantages (fp32 = 4B), old_log_probs (fp32 = 
4B). + return 3 * self.tokens * self.vocab * self.dtype.itemsize + self.tokens * 16 + + @property + def expected_flops(self) -> int: + # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element. + return 14 * self.tokens * self.vocab + + def make_inputs(self, device: str) -> Inputs: + return { + "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), + "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device), + "advantages": torch.randn(self.tokens, dtype=torch.float32, device=device), + "old_log_probs": torch.randn(self.tokens, dtype=torch.float32, device=device) - 5.0, + } def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor): @@ -67,35 +96,25 @@ def _triton_fwd_bwd(inputs: dict) -> dict: return {"loss": loss, "grad_logits": grad_logits} -def _grpo_bytes(tokens: int, vocab: int, dtype: torch.dtype) -> int: - # 3× logits traffic (read fwd, read+write bwd) + per-token scalars: - # labels (int64 = 8B), advantages (fp32 = 4B), old_log_probs (fp32 = 4B). - return 3 * tokens * vocab * dtype.itemsize + tokens * 16 - - -def _grpo_flops(tokens: int, vocab: int) -> int: - # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element. - return 14 * tokens * vocab - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + variants = standard_fwd_bwd_pytorch_variants( + _grpo_eager, + input_keys=("logits", "labels", "advantages", "old_log_probs"), + grad_input_keys=("logits",), + output_key="loss", + reset_inputs=_reset_logits_grad, + ) + if TritonConfig.enabled(): + variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) return [ ( "grpo_loss", - make_cases("grpo_loss", dtypes, shapes, _make_grpo_inputs, _grpo_bytes, _grpo_flops), - standard_fwd_bwd_pytorch_variants( - _grpo_eager, - input_keys=("logits", "labels", "advantages", "old_log_probs"), - grad_input_keys=("logits",), - output_key="loss", - reset_inputs=_reset_logits_grad, - triton_fwd=_triton_fwd, - triton_fwd_bwd=_triton_fwd_bwd, - ), + [GrpoLossCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes], + variants, ) ] diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 1b94351ad..0106d0521 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -1,12 +1,15 @@ +import dataclasses + import torch -from fast_llm.functional.config import ActivationType +from fast_llm.functional.config import ActivationType, TritonConfig from fast_llm.functional.triton.mlp import ( torch_mlp_activation, triton_mlp_activation_autograd, triton_mlp_activation_forward, ) -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated. 
_SHAPES = [ @@ -18,13 +21,37 @@ _ACTIVATION = ActivationType.silu -def _make_mlp_inputs(tokens: int, ffn_dim: int, dtype: torch.dtype) -> dict: - return { - "input": torch.randn(tokens, 2 * ffn_dim, dtype=dtype, device=device(), requires_grad=True), - "grad_output": torch.randn(tokens, ffn_dim, dtype=dtype, device=device()), - "gated": True, - "activation_type": _ACTIVATION, - } +@dataclasses.dataclass +class MlpActivationCase(Case): + tokens: int + ffn_dim: int + dtype: torch.dtype + + @property + def name(self) -> str: + return f"({self.tokens}, {self.ffn_dim}) {dtype_short(self.dtype)}" + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_bytes(self) -> int: + # fwd: 3*ffn_dim traffic; bwd: 5*ffn_dim. 8 elements/token total. + return 8 * self.tokens * self.ffn_dim * self.dtype.itemsize + + @property + def expected_flops(self) -> int: + # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element. + return 14 * self.tokens * self.ffn_dim + + def make_inputs(self, device: str) -> Inputs: + return { + "input": torch.randn(self.tokens, 2 * self.ffn_dim, dtype=self.dtype, device=device, requires_grad=True), + "grad_output": torch.randn(self.tokens, self.ffn_dim, dtype=self.dtype, device=device), + "gated": True, + "activation_type": _ACTIVATION, + } def _triton_fwd(inputs: dict) -> dict: @@ -38,35 +65,24 @@ def _triton_fwd_bwd(inputs: dict) -> dict: return {"output": output.detach(), "grad_input": inputs["input"].grad} -def _mlp_activation_bytes(tokens: int, ffn_dim: int, dtype: torch.dtype) -> int: - # fwd: 3*ffn_dim traffic; bwd: 5*ffn_dim. 8 elements/token total. - return 8 * tokens * ffn_dim * dtype.itemsize - - -def _mlp_activation_flops(tokens: int, ffn_dim: int) -> int: - # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element. - return 14 * tokens * ffn_dim - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + variants = standard_fwd_bwd_pytorch_variants( + torch_mlp_activation, + input_keys=("input", "gated", "activation_type"), + grad_input_keys=("input",), + grad_output_key="grad_output", + ) + if TritonConfig.enabled(): + variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) return [ ( "mlp_activation (gated silu)", - make_cases( - "mlp_activation", dtypes, shapes, _make_mlp_inputs, _mlp_activation_bytes, _mlp_activation_flops - ), - standard_fwd_bwd_pytorch_variants( - torch_mlp_activation, - input_keys=("input", "gated", "activation_type"), - grad_input_keys=("input",), - grad_output_key="grad_output", - triton_fwd=_triton_fwd, - triton_fwd_bwd=_triton_fwd_bwd, - ), + [MlpActivationCase(tokens=t, ffn_dim=f, dtype=d) for d in dtypes for (t, f) in shapes], + variants, ) ] diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index a75d579ef..636e428a7 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -1,8 +1,11 @@ """LayerNorm and RMSNorm. 
The Triton kernel writes parameter gradients to a `grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`.""" +import dataclasses + import torch +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.normalization.normalization import ( FastLayerNorm, @@ -11,7 +14,8 @@ fast_normalization_available, fused_normalization_available, ) -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory # budget across model widths; hidden swept from 1K to 16K. @@ -31,21 +35,58 @@ def _setup_param(tensor: torch.Tensor) -> torch.Tensor: return tensor -def _make_layer_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: - return { - "input": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), - "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), - "bias": _setup_param(torch.zeros(cols, dtype=dtype, device=device(), requires_grad=True)), - "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), - } +@dataclasses.dataclass +class _NormalizationCase(Case): + rows: int + cols: int + dtype: torch.dtype + @property + def name(self) -> str: + return f"({self.rows}, {self.cols}) {dtype_short(self.dtype)}" -def _make_rms_norm_inputs(rows: int, cols: int, dtype: torch.dtype) -> dict: - return { - "input": torch.randn(rows, cols, dtype=dtype, device=device(), requires_grad=True), - "weight": _setup_param(torch.randn(cols, dtype=dtype, device=device(), requires_grad=True)), - "grad_output": torch.randn(rows, cols, dtype=dtype, device=device()), - } + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + +class LayerNormCase(_NormalizationCase): + @property + def expected_bytes(self) -> int: + # 4× activations (fwd+bwd in/out) + weight & bias × (read + grad write). + return 4 * self.rows * self.cols * self.dtype.itemsize + 6 * self.cols * self.dtype.itemsize + + @property + def expected_flops(self) -> int: + # fwd ≈ 7 FLOPs/elem (mean, variance, normalize, scale+shift); bwd ≈ 2× fwd. + return 21 * self.rows * self.cols + + def make_inputs(self, device: str) -> Inputs: + return { + "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True), + "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)), + "bias": _setup_param(torch.zeros(self.cols, dtype=self.dtype, device=device, requires_grad=True)), + "grad_output": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device), + } + + +class RmsNormCase(_NormalizationCase): + @property + def expected_bytes(self) -> int: + # No bias compared to LayerNorm. + return 4 * self.rows * self.cols * self.dtype.itemsize + 3 * self.cols * self.dtype.itemsize + + @property + def expected_flops(self) -> int: + # No mean subtraction or bias compared to LayerNorm. 
+ return 15 * self.rows * self.cols + + def make_inputs(self, device: str) -> Inputs: + return { + "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True), + "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)), + "grad_output": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device), + } def _layer_norm_eager(input_, weight, bias): @@ -114,56 +155,43 @@ def _rms_norm_apex_fused(input_, weight): _RMS_NORM_EXTRAS["apex_fused"] = _rms_norm_apex_fused -def _layer_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - # fwd in/out + bwd grad_in/out (4× activations) + weight & bias × (read + grad write). - return 4 * rows * cols * dtype.itemsize + 6 * cols * dtype.itemsize - - -def _rms_norm_bytes(rows: int, cols: int, dtype: torch.dtype) -> int: - return 4 * rows * cols * dtype.itemsize + 3 * cols * dtype.itemsize - - -def _layer_norm_flops(rows: int, cols: int) -> int: - # fwd ≈ 7 per element (mean, variance, normalize, scale+shift); bwd ≈ 2× fwd. - return 21 * rows * cols - - -def _rms_norm_flops(rows: int, cols: int) -> int: - # No mean subtraction or bias. - return 15 * rows * cols - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + layer_norm_variants = standard_fwd_bwd_pytorch_variants( + _layer_norm_eager, + input_keys=("input", "weight", "bias"), + grad_input_keys=("input", "weight", "bias"), + grad_output_key="grad_output", + extra_functions=_LAYER_NORM_EXTRAS, + ) + if TritonConfig.enabled(): + layer_norm_variants.append( + Variant(name="fast_llm_triton", fwd=_layer_norm_triton_fwd, fwd_bwd=_layer_norm_triton_fwd_bwd) + ) + rms_norm_variants = standard_fwd_bwd_pytorch_variants( + _rms_norm_eager, + input_keys=("input", "weight"), + grad_input_keys=("input", "weight"), + grad_output_key="grad_output", + extra_functions=_RMS_NORM_EXTRAS, + ) + if TritonConfig.enabled(): + rms_norm_variants.append( + Variant(name="fast_llm_triton", fwd=_rms_norm_triton_fwd, fwd_bwd=_rms_norm_triton_fwd_bwd) + ) return [ ( "normalization: layer_norm", - make_cases("layer_norm", dtypes, shapes, _make_layer_norm_inputs, _layer_norm_bytes, _layer_norm_flops), - standard_fwd_bwd_pytorch_variants( - _layer_norm_eager, - input_keys=("input", "weight", "bias"), - grad_input_keys=("input", "weight", "bias"), - grad_output_key="grad_output", - extra_functions=_LAYER_NORM_EXTRAS, - triton_fwd=_layer_norm_triton_fwd, - triton_fwd_bwd=_layer_norm_triton_fwd_bwd, - ), + [LayerNormCase(rows=r, cols=c, dtype=d) for d in dtypes for (r, c) in shapes], + layer_norm_variants, ), ( "normalization: rms_norm", - make_cases("rms_norm", dtypes, shapes, _make_rms_norm_inputs, _rms_norm_bytes, _rms_norm_flops), - standard_fwd_bwd_pytorch_variants( - _rms_norm_eager, - input_keys=("input", "weight"), - grad_input_keys=("input", "weight"), - grad_output_key="grad_output", - extra_functions=_RMS_NORM_EXTRAS, - triton_fwd=_rms_norm_triton_fwd, - triton_fwd_bwd=_rms_norm_triton_fwd_bwd, - ), + [RmsNormCase(rows=r, cols=c, dtype=d) for d in dtypes for (r, c) in shapes], + rms_norm_variants, ), ] diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py index bc5c85468..47cd23812 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/bench_pointwise.py @@ -1,7 +1,11 @@ +import dataclasses +import typing + import torch from fast_llm.functional.triton.pointwise 
import triton_add, triton_copy, triton_fill -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_variants +from tools.benchmark.runner import Case, Inputs +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_variants # 4× steps so L2 → HBM and saturated-HBM regimes are visible. _SIZES_NUMEL = [ @@ -13,65 +17,78 @@ ] -def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor: - return out.copy_(input_) - - -def _make_copy_inputs(numel: int, dtype: torch.dtype) -> dict: - input_ = torch.randn(numel, dtype=dtype, device=device()) - return {"input_": input_, "out": torch.empty_like(input_)} +@dataclasses.dataclass +class _PointwiseCase(Case): + numel: int + dtype: torch.dtype + # Bytes traffic = bytes_factor × numel × dtype.itemsize. + bytes_factor: typing.ClassVar[int] + @property + def name(self) -> str: + return f"({self.numel},) {dtype_short(self.dtype)}" -def _copy_bytes(numel: int, dtype: torch.dtype) -> int: - return 2 * numel * dtype.itemsize + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype - -_COPY_VARIANTS = standard_fwd_variants( - eager_function=_copy_eager, - triton_function=triton_copy, - unpack=lambda inputs: (inputs["input_"], inputs["out"]), -) + @property + def expected_bytes(self) -> int: + return self.bytes_factor * self.numel * self.dtype.itemsize -def _fill_eager(input_: torch.Tensor, value: float) -> torch.Tensor: - return input_.fill_(value) +class CopyCase(_PointwiseCase): + bytes_factor = 2 + def make_inputs(self, device: str) -> Inputs: + input_ = torch.randn(self.numel, dtype=self.dtype, device=device) + return {"input_": input_, "out": torch.empty_like(input_)} -def _make_fill_inputs(numel: int, dtype: torch.dtype) -> dict: - return {"input_": torch.empty(numel, dtype=dtype, device=device()), "value": 1.5} +class FillCase(_PointwiseCase): + bytes_factor = 1 -def _fill_bytes(numel: int, dtype: torch.dtype) -> int: - return numel * dtype.itemsize + def make_inputs(self, device: str) -> Inputs: + return {"input_": torch.empty(self.numel, dtype=self.dtype, device=device), "value": 1.5} -_FILL_VARIANTS = standard_fwd_variants( - eager_function=_fill_eager, - triton_function=triton_fill, - unpack=lambda inputs: (inputs["input_"], inputs["value"]), -) +class AddCase(_PointwiseCase): + bytes_factor = 3 + @property + def expected_flops(self) -> int: + return self.numel -def _add_eager(input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor) -> torch.Tensor: - return torch.add(input_, other, out=out) + def make_inputs(self, device: str) -> Inputs: + return { + "input_": torch.randn(self.numel, dtype=self.dtype, device=device), + "other": torch.randn(self.numel, dtype=self.dtype, device=device), + "out": torch.empty(self.numel, dtype=self.dtype, device=device), + } -def _make_add_inputs(numel: int, dtype: torch.dtype) -> dict: - return { - "input_": torch.randn(numel, dtype=dtype, device=device()), - "other": torch.randn(numel, dtype=dtype, device=device()), - "out": torch.empty(numel, dtype=dtype, device=device()), - } +def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + return out.copy_(input_) -def _add_bytes(numel: int, dtype: torch.dtype) -> int: - return 3 * numel * dtype.itemsize +def _fill_eager(input_: torch.Tensor, value: float) -> torch.Tensor: + return input_.fill_(value) -def _add_flops(numel: int) -> int: - return numel +def _add_eager(input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + return torch.add(input_, other, 
out=out) +_COPY_VARIANTS = standard_fwd_variants( + eager_function=_copy_eager, + triton_function=triton_copy, + unpack=lambda inputs: (inputs["input_"], inputs["out"]), +) +_FILL_VARIANTS = standard_fwd_variants( + eager_function=_fill_eager, + triton_function=triton_fill, + unpack=lambda inputs: (inputs["input_"], inputs["value"]), +) _ADD_VARIANTS = standard_fwd_variants( eager_function=_add_eager, triton_function=triton_add, @@ -82,13 +99,9 @@ def _add_flops(numel: int) -> int: def benchmarks(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SIZES_NUMEL return [ - ("pointwise: copy", make_cases("copy", dtypes, shapes, _make_copy_inputs, _copy_bytes), _COPY_VARIANTS), - ("pointwise: fill", make_cases("fill", dtypes, shapes, _make_fill_inputs, _fill_bytes), _FILL_VARIANTS), - ( - "pointwise: add", - make_cases("add", dtypes, shapes, _make_add_inputs, _add_bytes, _add_flops), - _ADD_VARIANTS, - ), + ("pointwise: copy", [CopyCase(numel=n, dtype=d) for d in dtypes for n in shapes], _COPY_VARIANTS), + ("pointwise: fill", [FillCase(numel=n, dtype=d) for d in dtypes for n in shapes], _FILL_VARIANTS), + ("pointwise: add", [AddCase(numel=n, dtype=d) for d in dtypes for n in shapes], _ADD_VARIANTS), ] diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py index fe9f0ff52..99f8651f8 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/bench_rotary.py @@ -1,12 +1,14 @@ """Rotary position embeddings. The Triton kernel is in-place; backward is an identical rotation with conjugated frequencies, so only fwd is benchmarked.""" +import dataclasses + import torch from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.rotary import triton_rotary_ -from tools.benchmark.runner import Variant -from tools.benchmark.utils import bench_main, device, make_cases +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short # (tokens, num_heads, head_size) — tokens = batch * seq_len _SHAPES = [ @@ -17,14 +19,41 @@ ] -def _make_rotary_inputs(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> dict: - rotary_dim = head_size // 2 - input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device()) - return { - "input_": input_, - "work": input_.clone(), - "frequencies": torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device()), - } +@dataclasses.dataclass +class RotaryCase(Case): + tokens: int + num_heads: int + head_size: int + dtype: torch.dtype + + @property + def name(self) -> str: + return f"({self.tokens}, {self.num_heads}, {self.head_size}) {dtype_short(self.dtype)}" + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_bytes(self) -> int: + # frequencies are float32, hence the extra 4 bytes per token×head_size. + return ( + 2 * self.tokens * self.num_heads * self.head_size * self.dtype.itemsize + self.tokens * self.head_size * 4 + ) + + @property + def expected_flops(self) -> int: + # 6 FLOPs per (re, im) element pair: 4 muls + 2 add/sub. 
+ return 6 * self.tokens * self.num_heads * (self.head_size // 2) + + def make_inputs(self, device: str) -> Inputs: + rotary_dim = self.head_size // 2 + input_ = torch.randn(self.tokens, self.num_heads, self.head_size, dtype=self.dtype, device=device) + return { + "input_": input_, + "work": input_.clone(), + "frequencies": torch.randn(self.tokens, 2 * rotary_dim, dtype=torch.float32, device=device), + } def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: @@ -41,16 +70,6 @@ def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tens _rotary_compiled_max = torch.compile(_rotary_eager, mode="max-autotune-no-cudagraphs", dynamic=False) -def _rotary_bytes(tokens: int, num_heads: int, head_size: int, dtype: torch.dtype) -> int: - # frequencies are float32, hence the extra 4 bytes per token×head_size. - return 2 * tokens * num_heads * head_size * dtype.itemsize + tokens * head_size * 4 - - -def _rotary_flops(tokens: int, num_heads: int, head_size: int) -> int: - # 6 FLOPs per (re, im) element pair: 4 muls + 2 add/sub. - return 6 * tokens * num_heads * (head_size // 2) - - def _rotary_variants() -> list[Variant]: variants = [ Variant( @@ -90,7 +109,7 @@ def benchmarks( return [ ( "rotary", - make_cases("rotary", dtypes, shapes, _make_rotary_inputs, _rotary_bytes, _rotary_flops), + [RotaryCase(tokens=t, num_heads=h, head_size=hs, dtype=d) for d in dtypes for (t, h, hs) in shapes], _rotary_variants(), ) ] diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 513d71ec8..62c165e3f 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -1,12 +1,16 @@ +import dataclasses + import torch +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import ( SparseMap, copy_dense_to_sparse_autograd, copy_sparse_to_dense_autograd, get_sparse_map, ) -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden_size) _SHAPES = [ @@ -16,15 +20,10 @@ ] -def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: - top_experts = torch.randint(0, num_experts, (tokens, top_k), device=device()) - return get_sparse_map(top_experts, num_experts) - - -def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: +def _make_phantom_mask(sparse_map: SparseMap, device: str) -> torch.Tensor: # True for within-expert padding rows and the static tail past expert_ends[-1]; # used only in output_postprocess, never in the timed path. 
- mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device()) + mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device) for expert in range(sparse_map.num_experts): pad_begin = int(sparse_map.expert_pad_begins[expert]) pad_end = int(sparse_map.expert_ends[expert]) @@ -36,25 +35,60 @@ def _make_phantom_mask(sparse_map: SparseMap) -> torch.Tensor: return mask -def _make_dispatch_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: - sparse_map = _make_sparse_map(tokens, top_k, num_experts) - return { - "dense": torch.randn(tokens, hidden, dtype=dtype, device=device(), requires_grad=True), - "sparse_map": sparse_map, - "phantom_mask": _make_phantom_mask(sparse_map), - "backward_grad": torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()), - } - - -def _make_combine_inputs(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> dict: - sparse_map = _make_sparse_map(tokens, top_k, num_experts) - return { - "sparse": torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device(), requires_grad=True), - "scores": torch.softmax(torch.randn(tokens, top_k, dtype=dtype, device=device()), dim=-1).requires_grad_(True), - "sparse_map": sparse_map, - "phantom_mask": _make_phantom_mask(sparse_map), - "backward_grad": torch.ones(tokens, hidden, dtype=dtype, device=device()), - } +@dataclasses.dataclass +class _SparseCopyCase(Case): + tokens: int + top_k: int + num_experts: int + hidden: int + dtype: torch.dtype + + @property + def name(self) -> str: + return f"({self.tokens}, {self.top_k}, {self.num_experts}, {self.hidden}) {dtype_short(self.dtype)}" + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_bytes(self) -> int: + # 2× (sparse + dense) hidden traffic; combine adds scores read/write. + return 2 * (1 + self.top_k) * self.tokens * self.hidden * self.dtype.itemsize + + +class DispatchCase(_SparseCopyCase): + def make_inputs(self, device: str) -> Inputs: + top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device) + sparse_map = get_sparse_map(top_experts, self.num_experts) + return { + "dense": torch.randn(self.tokens, self.hidden, dtype=self.dtype, device=device, requires_grad=True), + "sparse_map": sparse_map, + "phantom_mask": _make_phantom_mask(sparse_map, device), + "backward_grad": torch.ones(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device), + } + + +class CombineCase(_SparseCopyCase): + @property + def expected_bytes(self) -> int: + # Adds scores read/write on top of the dispatch traffic. 
+ return super().expected_bytes + 4 * self.tokens * self.top_k * self.dtype.itemsize + + def make_inputs(self, device: str) -> Inputs: + top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device) + sparse_map = get_sparse_map(top_experts, self.num_experts) + return { + "sparse": torch.randn( + sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device, requires_grad=True + ), + "scores": torch.softmax( + torch.randn(self.tokens, self.top_k, dtype=self.dtype, device=device), dim=-1 + ).requires_grad_(True), + "sparse_map": sparse_map, + "phantom_mask": _make_phantom_mask(sparse_map, device), + "backward_grad": torch.ones(self.tokens, self.hidden, dtype=self.dtype, device=device), + } def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: @@ -108,47 +142,59 @@ def _combine_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[ return output -def _dispatch_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: - # fwd: read dense + write sparse; bwd: same traffic reversed. - return 2 * (1 + top_k) * tokens * hidden * dtype.itemsize - - -def _combine_bytes(tokens: int, top_k: int, num_experts: int, hidden: int, dtype: torch.dtype) -> int: - # 2× (sparse + dense) hidden traffic + scores read/write. - return 2 * (top_k + 1) * tokens * hidden * dtype.itemsize + 4 * tokens * top_k * dtype.itemsize - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + dispatch_variants = standard_fwd_bwd_pytorch_variants( + _dispatch_pytorch, + input_keys=("dense", "sparse_map"), + grad_input_keys=("dense",), + grad_output_key="backward_grad", + ) + if TritonConfig.enabled(): + dispatch_variants.append( + Variant( + name="fast_llm_triton", + fwd=_dispatch_triton_fwd, + fwd_bwd=_dispatch_triton_fwd_bwd, + output_postprocess=_dispatch_postprocess, + ) + ) + combine_variants = standard_fwd_bwd_pytorch_variants( + _combine_pytorch, + input_keys=("sparse", "scores", "sparse_map"), + grad_input_keys=("sparse", "scores"), + grad_output_key="backward_grad", + ) + if TritonConfig.enabled(): + combine_variants.append( + Variant( + name="fast_llm_triton", + fwd=_combine_triton_fwd, + fwd_bwd=_combine_triton_fwd_bwd, + output_postprocess=_combine_postprocess, + ) + ) return [ ( "sparse_copy: dispatch", - make_cases("dispatch", dtypes, shapes, _make_dispatch_inputs, _dispatch_bytes), - standard_fwd_bwd_pytorch_variants( - _dispatch_pytorch, - input_keys=("dense", "sparse_map"), - grad_input_keys=("dense",), - grad_output_key="backward_grad", - triton_fwd=_dispatch_triton_fwd, - triton_fwd_bwd=_dispatch_triton_fwd_bwd, - triton_output_postprocess=_dispatch_postprocess, - ), + [ + DispatchCase(tokens=t, top_k=k, num_experts=e, hidden=h, dtype=d) + for d in dtypes + for (t, k, e, h) in shapes + ], + dispatch_variants, ), ( "sparse_copy: combine", - make_cases("combine", dtypes, shapes, _make_combine_inputs, _combine_bytes), - standard_fwd_bwd_pytorch_variants( - _combine_pytorch, - input_keys=("sparse", "scores", "sparse_map"), - grad_input_keys=("sparse", "scores"), - grad_output_key="backward_grad", - triton_fwd=_combine_triton_fwd, - triton_fwd_bwd=_combine_triton_fwd_bwd, - triton_output_postprocess=_combine_postprocess, - ), + [ + CombineCase(tokens=t, top_k=k, num_experts=e, hidden=h, dtype=d) + for d in dtypes + for (t, k, e, h) in shapes + ], + combine_variants, ), 
] diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 2836c7e6d..1e05eb038 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -1,9 +1,12 @@ +import dataclasses + import torch from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear -from tools.benchmark.utils import bench_main, device, make_cases, standard_fwd_bwd_pytorch_variants +from tools.benchmark.runner import Case, Inputs, Variant +from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden, ffn_per_expert) _SHAPES = [ @@ -18,11 +21,6 @@ _input_inner_sparse_warmed_up: set[tuple] = set() -def _make_sparse_map(tokens: int, top_k: int, num_experts: int) -> SparseMap: - top_experts = torch.randint(0, num_experts, (tokens, top_k), device=device()) - return get_sparse_map(top_experts, num_experts) - - def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]: # Two regions in the kernel's forward output and grad_lhs are by-design garbage: # per-expert padding [pad_begin, expert_end) and phantom rows past expert_ends[-1]. @@ -46,50 +44,87 @@ def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[ return masked -def _make_output_sparse_inputs( - tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype -) -> dict: - sparse_map = _make_sparse_map(tokens, top_k, num_experts) - lhs_data = torch.randn(sparse_map.num_rows, hidden, dtype=dtype, device=device()) - rhs_data = torch.randn(hidden, ffn_per_expert * num_experts, dtype=dtype, device=device()) - backward_grad = torch.ones(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) - warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) - if TritonConfig.enabled() and warmup_key not in _output_sparse_warmed_up: - warmup_lhs = lhs_data.detach().requires_grad_(True) - warmup_rhs = rhs_data.detach().requires_grad_(True) - warmup_out = OutputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) - warmup_out.backward(backward_grad) - del warmup_lhs, warmup_rhs, warmup_out - _output_sparse_warmed_up.add(warmup_key) - return { - "lhs": lhs_data.requires_grad_(True), - "rhs": rhs_data.requires_grad_(True), - "sparse_map": sparse_map, - "backward_grad": backward_grad, - } - - -def _make_input_inner_sparse_inputs( - tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype -) -> dict: - sparse_map = _make_sparse_map(tokens, top_k, num_experts) - lhs_data = torch.randn(sparse_map.num_rows, ffn_per_expert, dtype=dtype, device=device()) - rhs_data = torch.randn(ffn_per_expert * num_experts, hidden, dtype=dtype, device=device()) - backward_grad = torch.ones(sparse_map.num_rows, hidden, dtype=dtype, device=device()) - warmup_key = (tokens, top_k, num_experts, hidden, ffn_per_expert, dtype) - if TritonConfig.enabled() and warmup_key not in _input_inner_sparse_warmed_up: - warmup_lhs = lhs_data.detach().requires_grad_(True) - warmup_rhs = rhs_data.detach().requires_grad_(True) - warmup_out = InputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) - warmup_out.backward(backward_grad) - del warmup_lhs, warmup_rhs, warmup_out - _input_inner_sparse_warmed_up.add(warmup_key) - return { - "lhs": 
lhs_data.requires_grad_(True), - "rhs": rhs_data.requires_grad_(True), - "sparse_map": sparse_map, - "backward_grad": backward_grad, - } +@dataclasses.dataclass +class _SparseLinearCase(Case): + tokens: int + top_k: int + num_experts: int + hidden: int + ffn_per_expert: int + dtype: torch.dtype + + @property + def name(self) -> str: + return ( + f"({self.tokens}, {self.top_k}, {self.num_experts}, " + f"{self.hidden}, {self.ffn_per_expert}) {dtype_short(self.dtype)}" + ) + + @property + def compute_dtype(self) -> torch.dtype: + return self.dtype + + @property + def expected_bytes(self) -> int: + # Approximation: 3× lhs + 3× rhs + 2× output traffic across fwd+bwd. + sparse_tokens = self.tokens * self.top_k + lhs_bytes = sparse_tokens * self.hidden * self.dtype.itemsize + rhs_bytes = self.hidden * self.ffn_per_expert * self.num_experts * self.dtype.itemsize + output_bytes = sparse_tokens * self.ffn_per_expert * self.dtype.itemsize + return 3 * lhs_bytes + 3 * rhs_bytes + 2 * output_bytes + + @property + def expected_flops(self) -> int: + # 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad). + return 3 * 2 * self.tokens * self.top_k * self.hidden * self.ffn_per_expert + + def _make_sparse_map(self, device: str) -> SparseMap: + top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device) + return get_sparse_map(top_experts, self.num_experts) + + +class OutputSparseCase(_SparseLinearCase): + def make_inputs(self, device: str) -> Inputs: + sparse_map = self._make_sparse_map(device) + lhs_data = torch.randn(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device) + rhs_data = torch.randn(self.hidden, self.ffn_per_expert * self.num_experts, dtype=self.dtype, device=device) + backward_grad = torch.ones(sparse_map.num_rows, self.ffn_per_expert, dtype=self.dtype, device=device) + warmup_key = (self.tokens, self.top_k, self.num_experts, self.hidden, self.ffn_per_expert, self.dtype) + if TritonConfig.enabled() and warmup_key not in _output_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = OutputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _output_sparse_warmed_up.add(warmup_key) + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": sparse_map, + "backward_grad": backward_grad, + } + + +class InputInnerSparseCase(_SparseLinearCase): + def make_inputs(self, device: str) -> Inputs: + sparse_map = self._make_sparse_map(device) + lhs_data = torch.randn(sparse_map.num_rows, self.ffn_per_expert, dtype=self.dtype, device=device) + rhs_data = torch.randn(self.ffn_per_expert * self.num_experts, self.hidden, dtype=self.dtype, device=device) + backward_grad = torch.ones(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device) + warmup_key = (self.tokens, self.top_k, self.num_experts, self.hidden, self.ffn_per_expert, self.dtype) + if TritonConfig.enabled() and warmup_key not in _input_inner_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = InputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _input_inner_sparse_warmed_up.add(warmup_key) + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": 
sparse_map, + "backward_grad": backward_grad, + } def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: @@ -136,66 +171,63 @@ def _input_inner_sparse_triton_fwd_bwd(inputs: dict) -> dict: return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} -def _sparse_linear_bytes( - tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int, dtype: torch.dtype -) -> int: - # Approximation: 3× lhs + 3× rhs + 2× output traffic across fwd+bwd. - sparse_tokens = tokens * top_k - lhs_bytes = sparse_tokens * hidden * dtype.itemsize - rhs_bytes = hidden * ffn_per_expert * num_experts * dtype.itemsize - output_bytes = sparse_tokens * ffn_per_expert * dtype.itemsize - return 3 * lhs_bytes + 3 * rhs_bytes + 2 * output_bytes - - -def _sparse_linear_flops(tokens: int, top_k: int, num_experts: int, hidden: int, ffn_per_expert: int) -> int: - # 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad). - return 3 * 2 * tokens * top_k * hidden * ffn_per_expert - - def benchmarks( dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int, int, int, int]] | None = None, ) -> list[tuple[str, list, list]]: shapes = shapes if shapes is not None else _SHAPES + output_sparse_variants = standard_fwd_bwd_pytorch_variants( + _output_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) + if TritonConfig.enabled(): + output_sparse_variants.append( + Variant( + name="fast_llm_triton", + fwd=_output_sparse_triton_fwd, + fwd_bwd=_output_sparse_triton_fwd_bwd, + output_postprocess=_mask_padded_rows, + ) + ) + input_inner_sparse_variants = standard_fwd_bwd_pytorch_variants( + _input_inner_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) + if TritonConfig.enabled(): + input_inner_sparse_variants.append( + Variant( + name="fast_llm_triton", + fwd=_input_inner_sparse_triton_fwd, + fwd_bwd=_input_inner_sparse_triton_fwd_bwd, + output_postprocess=_mask_padded_rows, + ) + ) return [ ( "sparse_linear: output_sparse (layer 1 / up-proj)", - make_cases( - "output_sparse", dtypes, shapes, _make_output_sparse_inputs, _sparse_linear_bytes, _sparse_linear_flops - ), - standard_fwd_bwd_pytorch_variants( - _output_sparse_loop, - input_keys=("lhs", "rhs", "sparse_map"), - grad_input_keys=("lhs", "rhs"), - grad_output_key="backward_grad", - eager_name="pytorch_loop", - enable_max_autotune=False, - triton_fwd=_output_sparse_triton_fwd, - triton_fwd_bwd=_output_sparse_triton_fwd_bwd, - triton_output_postprocess=_mask_padded_rows, - ), + [ + OutputSparseCase(tokens=t, top_k=k, num_experts=e, hidden=h, ffn_per_expert=f, dtype=d) + for d in dtypes + for (t, k, e, h, f) in shapes + ], + output_sparse_variants, ), ( "sparse_linear: input_inner_sparse (layer 2 / down-proj)", - make_cases( - "input_inner_sparse", - dtypes, - shapes, - _make_input_inner_sparse_inputs, - _sparse_linear_bytes, - _sparse_linear_flops, - ), - standard_fwd_bwd_pytorch_variants( - _input_inner_sparse_loop, - input_keys=("lhs", "rhs", "sparse_map"), - grad_input_keys=("lhs", "rhs"), - grad_output_key="backward_grad", - eager_name="pytorch_loop", - enable_max_autotune=False, - triton_fwd=_input_inner_sparse_triton_fwd, - triton_fwd_bwd=_input_inner_sparse_triton_fwd_bwd, - 
triton_output_postprocess=_mask_padded_rows, - ), + [ + InputInnerSparseCase(tokens=t, top_k=k, num_experts=e, hidden=h, ffn_per_expert=f, dtype=d) + for d in dtypes + for (t, k, e, h, f) in shapes + ], + input_inner_sparse_variants, ), ] diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index a5f823563..de4d94a24 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -73,27 +73,34 @@ class Variant: reset_inputs: Callable[[Inputs], Any] | None = None -@dataclasses.dataclass class Case: - """A single input configuration for the kernel under test. `make_inputs` - builds fresh input tensors on demand. It is called once per variant per - mode, after a global seed reset, so every variant sees identical inputs.""" + """Base for a single input configuration. Subclasses are dataclasses holding + the kernel's shape parameters (e.g. rows, cols, dtype) and override `name` + and `make_inputs`; the throughput properties are optional.""" + # Subclasses must provide these. name: str - make_inputs: Callable[[], Inputs] - # Minimum bytes read+written by the op. Used for GB/s + %BW. Optional. - expected_bytes: int | None = None - # Minimum floating-point ops performed by the op. Used for TFLOP/s + %FLOPs. Optional. - expected_flops: int | None = None - # For %FLOPs: which peak column to use (dtype of the hot inputs). - compute_dtype: torch.dtype | None = None + # Subclasses may override these (defaults skip the corresponding columns). + expected_bytes: int | None = None # bytes read+written; enables GB/s + %BW. + expected_flops: int | None = None # FLOPs performed; enables TFLOP/s + %FLOPs. + compute_dtype: torch.dtype | None = None # dtype of hot inputs, picks peak column. + + def make_inputs(self, device: str) -> Inputs: + """Return a fresh dict of input tensors on `device`. Called once per + variant per mode, after a global seed reset, so every variant sees + identical inputs.""" + raise NotImplementedError + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" def _seeded_inputs(case: Case, seed: int = 0) -> Inputs: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - return case.make_inputs() + return case.make_inputs(_device()) @dataclasses.dataclass diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py index 23ebf1ab2..f18ee870a 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/utils.py @@ -1,56 +1,18 @@ +import dataclasses from collections.abc import Callable -from functools import partial +from typing import Any import torch from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from tools.benchmark.runner import Case, Inputs, Variant, VariantFn, run_benchmark +from tools.benchmark.runner import Inputs, Variant, run_benchmark DEFAULT_DTYPES: tuple[torch.dtype, ...] 
= (torch.bfloat16,) -def _format_size(n: int) -> str: - for unit, factor in (("Gi", 1 << 30), ("Mi", 1 << 20), ("Ki", 1 << 10)): - if n >= factor and n % factor == 0: - return f"{n // factor} {unit}" - return str(n) - - -def _case_name(kernel: str, shape: tuple[int, ...], dtype: torch.dtype) -> str: - joined = ", ".join(_format_size(n) for n in shape) - shape_repr = f"({joined},)" if len(shape) == 1 else f"({joined})" - return f"[{kernel}] {shape_repr} {DataType.from_torch(dtype).short}" - - -def device() -> str: - return "cuda" if torch.cuda.is_available() else "cpu" - - -def make_cases( - kernel_name: str, - dtypes: tuple[torch.dtype, ...], - shapes: list, - make_inputs: Callable, - bytes_fn: Callable | None = None, - flops_fn: Callable | None = None, -) -> list[Case]: - """Cross-product of `dtypes × shapes`. Each shape may be a tuple or scalar; - tuples are unpacked positionally into `make_inputs`, `bytes_fn`, `flops_fn`.""" - cases = [] - for dtype in dtypes: - for shape in shapes: - shape_tuple = shape if isinstance(shape, tuple) else (shape,) - cases.append( - Case( - name=_case_name(kernel_name, shape_tuple, dtype), - make_inputs=partial(make_inputs, *shape_tuple, dtype), - expected_bytes=bytes_fn(*shape_tuple, dtype) if bytes_fn else None, - expected_flops=flops_fn(*shape_tuple) if flops_fn else None, - compute_dtype=dtype, - ) - ) - return cases +def dtype_short(dtype: torch.dtype) -> str: + return DataType.from_torch(dtype).short def bench_main(benchmarks_fn: Callable) -> Callable: @@ -70,6 +32,91 @@ def run( return run +@dataclasses.dataclass(kw_only=True) +class PytorchVariant(Variant): + """Variant that calls a pytorch function on inputs picked by key. Used for + eager, torch.compile, and apex variants — each instance differs in `function` + while sharing the dispatch logic.""" + + function: Callable + input_keys: tuple[str, ...] + grad_input_keys: tuple[str, ...] = () + grad_output_key: str | None = None + output_key: str = "output" + + def __post_init__(self) -> None: + # Wire the inherited `fwd`/`fwd_bwd` callable fields to bound methods + # so subclasses can override the methods without touching the fields. + self.fwd = self._fwd + self.fwd_bwd = self._fwd_bwd + + def _fwd(self, inputs: Inputs) -> dict: + return {self.output_key: self.function(*(inputs[k] for k in self.input_keys))} + + def _fwd_bwd(self, inputs: Inputs) -> dict: + output = self.function(*(inputs[k] for k in self.input_keys)) + if self.grad_output_key is None: + output.backward() + else: + output.backward(inputs[self.grad_output_key]) + result = {self.output_key: output.detach()} + for key in self.grad_input_keys: + result[f"grad_{key}"] = inputs[key].grad + return result + + +@dataclasses.dataclass(kw_only=True) +class Fp32ReferenceVariant(PytorchVariant): + """Reference variant: upcasts every floating-point input to fp32 before + running the eager pytorch function. 
Re-attaches `requires_grad=True` on + `grad_input_keys` so backward sees a leaf tensor.""" + + name: str = "fp32_reference" + is_reference: bool = True + + def _fwd(self, inputs: Inputs) -> dict: + return super()._fwd(self._to_fp32(inputs)) + + def _fwd_bwd(self, inputs: Inputs) -> dict: + return super()._fwd_bwd(self._to_fp32(inputs)) + + def _to_fp32(self, inputs: Inputs) -> Inputs: + result = dict(inputs) + for key, value in inputs.items(): + if isinstance(value, torch.Tensor) and value.is_floating_point(): + float_value = value.float().detach() + result[key] = float_value.requires_grad_(True) if key in self.grad_input_keys else float_value + return result + + +@dataclasses.dataclass(kw_only=True) +class FwdOnlyPytorchVariant(Variant): + """Forward-only variant: calls a pytorch function with positional args + extracted via `unpack`. Used by bench_pointwise where there's no backward.""" + + function: Callable + unpack: Callable[[Inputs], tuple] + + def __post_init__(self) -> None: + self.fwd = self._fwd + + def _fwd(self, inputs: Inputs) -> Any: + return self.function(*self.unpack(inputs)) + + +@dataclasses.dataclass(kw_only=True) +class Fp32FwdOnlyReferenceVariant(FwdOnlyPytorchVariant): + name: str = "fp32_reference" + is_reference: bool = True + + def _fwd(self, inputs: Inputs) -> Any: + args = tuple( + arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg + for arg in self.unpack(inputs) + ) + return self.function(*args) + + def standard_fwd_variants( eager_function: Callable, triton_function: Callable | None, @@ -77,69 +124,31 @@ def standard_fwd_variants( ) -> list[Variant]: """fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, and (if `TritonConfig.enabled()`) fast_llm_triton.""" - - def fp32_unpack(inputs: Inputs) -> tuple: - return tuple( - arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg for arg in unpack(inputs) - ) - - compiled_default = torch.compile(eager_function, mode="default", dynamic=False) - compiled_max = torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False) - - variants = [ - Variant( - name="fp32_reference", - fwd=lambda inputs: eager_function(*fp32_unpack(inputs)), - is_reference=True, + variants: list[Variant] = [ + Fp32FwdOnlyReferenceVariant(function=eager_function, unpack=unpack), + FwdOnlyPytorchVariant(name="pytorch_eager", function=eager_function, unpack=unpack), + FwdOnlyPytorchVariant( + name="pytorch_compiled", + function=torch.compile(eager_function, mode="default", dynamic=False), + unpack=unpack, + ), + FwdOnlyPytorchVariant( + name="pytorch_compiled_max", + function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False), + unpack=unpack, ), - Variant(name="pytorch_eager", fwd=lambda inputs: eager_function(*unpack(inputs))), - Variant(name="pytorch_compiled", fwd=lambda inputs: compiled_default(*unpack(inputs))), - Variant(name="pytorch_compiled_max", fwd=lambda inputs: compiled_max(*unpack(inputs))), ] if triton_function is not None and TritonConfig.enabled(): variants.append( - Variant(name="fast_llm_triton", fwd=lambda inputs: triton_function(*unpack(inputs), use_triton=True)) + FwdOnlyPytorchVariant( + name="fast_llm_triton", + function=lambda *args: triton_function(*args, use_triton=True), + unpack=unpack, + ) ) return variants -def _run_pytorch_fwd( - inputs: Inputs, - function: Callable, - input_keys: tuple[str, ...], - output_key: str, -) -> dict: - return {output_key: function(*(inputs[key] for key in 
input_keys))} - - -def _run_pytorch_fwd_bwd( - inputs: Inputs, - function: Callable, - input_keys: tuple[str, ...], - grad_input_keys: tuple[str, ...], - grad_output_key: str | None, - output_key: str, -) -> dict: - output = function(*(inputs[key] for key in input_keys)) - if grad_output_key is None: - output.backward() - else: - output.backward(inputs[grad_output_key]) - result = {output_key: output.detach()} - for key in grad_input_keys: - result[f"grad_{key}"] = inputs[key].grad - return result - - -def _to_fp32_inputs(inputs: Inputs, grad_input_keys: tuple[str, ...]) -> Inputs: - result = dict(inputs) - for key, value in inputs.items(): - if isinstance(value, torch.Tensor) and value.is_floating_point(): - float_value = value.float().detach() - result[key] = float_value.requires_grad_(True) if key in grad_input_keys else float_value - return result - - def standard_fwd_bwd_pytorch_variants( eager_function: Callable, input_keys: tuple[str, ...], @@ -149,65 +158,36 @@ def standard_fwd_bwd_pytorch_variants( output_key: str = "output", reset_inputs: Callable[[Inputs], None] | None = None, extra_functions: dict[str, Callable] | None = None, - triton_fwd: VariantFn | None = None, - triton_fwd_bwd: VariantFn | None = None, - triton_name: str = "fast_llm_triton", - triton_output_postprocess: Callable[[dict, Inputs], dict] | None = None, eager_name: str = "pytorch_eager", enable_max_autotune: bool = True, ) -> list[Variant]: - """fp32_reference + + pytorch_compiled + [pytorch_compiled_max] + - `extra_functions` + optional `` (gated on `TritonConfig.enabled()`). - - Pytorch variants call `eager_function(*[inputs[k] for k in input_keys])`. For - fwd_bwd, `grad_input_keys` are read out as `grad_` and `grad_output_key=None` - triggers a scalar `output.backward()`. The fp32 reference upcasts every - floating-point input, re-attaching `requires_grad=True` on `grad_input_keys`. - """ - fwd_kwargs = {"input_keys": input_keys, "output_key": output_key} - fwd_bwd_kwargs = { + """fp32_reference + + pytorch_compiled + [pytorch_compiled_max] + + `extra_functions`. 
Triton variants are appended by the caller (so each + bench file owns its triton wiring explicitly).""" + common = { "input_keys": input_keys, "grad_input_keys": grad_input_keys, "grad_output_key": grad_output_key, "output_key": output_key, + "reset_inputs": reset_inputs, } - - def variant(name: str, function: Callable) -> Variant: - return Variant( - name=name, - fwd=partial(_run_pytorch_fwd, function=function, **fwd_kwargs), - fwd_bwd=partial(_run_pytorch_fwd_bwd, function=function, **fwd_bwd_kwargs), - reset_inputs=reset_inputs, - ) - - compiled_default = torch.compile(eager_function, mode="default", dynamic=False) - variants = [ - Variant( - name="fp32_reference", - fwd=lambda inputs: _run_pytorch_fwd( - _to_fp32_inputs(inputs, grad_input_keys), eager_function, **fwd_kwargs - ), - fwd_bwd=lambda inputs: _run_pytorch_fwd_bwd( - _to_fp32_inputs(inputs, grad_input_keys), eager_function, **fwd_bwd_kwargs - ), - is_reference=True, + variants: list[Variant] = [ + Fp32ReferenceVariant(function=eager_function, **common), + PytorchVariant(name=eager_name, function=eager_function, **common), + PytorchVariant( + name="pytorch_compiled", + function=torch.compile(eager_function, mode="default", dynamic=False), + **common, ), - variant(eager_name, eager_function), - variant("pytorch_compiled", compiled_default), ] if enable_max_autotune: - compiled_max = torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False) - variants.append(variant("pytorch_compiled_max", compiled_max)) - for name, function in (extra_functions or {}).items(): - variants.append(variant(name, function)) - if (triton_fwd is not None or triton_fwd_bwd is not None) and TritonConfig.enabled(): variants.append( - Variant( - name=triton_name, - fwd=triton_fwd, - fwd_bwd=triton_fwd_bwd, - reset_inputs=reset_inputs, - output_postprocess=triton_output_postprocess, + PytorchVariant( + name="pytorch_compiled_max", + function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False), + **common, ) ) + for name, function in (extra_functions or {}).items(): + variants.append(PytorchVariant(name=name, function=function, **common)) return variants From 1ca02cf2796a4d0a9ad66a4ef4be79dd894610f6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 09:28:52 -0400 Subject: [PATCH 37/41] Reset .grad between fwd_bwd reps for all bench variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four bench files (mlp_activation, normalization, sparse_copy, sparse_linear) were missing reset_inputs, so inputs[key].grad accumulated across reps — each rep added an extra read+write of the full grad tensor, biasing fwd_bwd timing. Peak memory is unaffected since it's measured on a fresh inputs allocation (runner.py:353). Add make_grad_reset(keys) helper in utils.py and default reset_inputs in standard_fwd_bwd_pytorch_variants to clear .grad on grad_input_keys. Wire the same reset into the manually-constructed triton Variants in the four affected files (their triton paths also go through .backward()). For normalization, the helper also resets param_grad_is_zero=True on params with a grad_buffer so the triton kernel writes fresh. Drop the now-redundant _reset_logits_grad in bench_entropy_loss.py and bench_grpo_loss.py — the new default produces identical behavior. 
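For illustration, a minimal sketch of the accumulation this removes (hypothetical tensor and shapes, assumes a CUDA device; not code from this patch):

    import torch

    x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    g = torch.ones_like(x)
    for rep in range(3):
        (x * 2.0).backward(g)  # without a reset, rep 2+ also reads+writes the full x.grad to accumulate
        x.grad = None          # per-rep reset (what make_grad_reset wires in) keeps every rep identical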
Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/benchmark/bench_entropy_loss.py | 6 ------ tools/benchmark/bench_grpo_loss.py | 5 ----- tools/benchmark/bench_mlp_activation.py | 11 +++++++++-- tools/benchmark/bench_normalization.py | 16 +++++++++++++--- tools/benchmark/bench_sparse_copy.py | 4 +++- tools/benchmark/bench_sparse_linear.py | 4 +++- tools/benchmark/utils.py | 22 +++++++++++++++++++++- 7 files changed, 49 insertions(+), 19 deletions(-) diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py index bc14e8bf9..4a5100915 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/bench_entropy_loss.py @@ -63,10 +63,6 @@ def make_inputs(self, device: str) -> Inputs: } -def _reset_logits_grad(inputs: dict) -> None: - inputs["logits"].grad = None - - def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: return F.cross_entropy(logits, labels) @@ -106,7 +102,6 @@ def triton_fwd_bwd(inputs: dict) -> dict: input_keys=input_keys, grad_input_keys=("logits",), output_key="loss", - reset_inputs=_reset_logits_grad, ) if TritonConfig.enabled(): variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd)) @@ -135,7 +130,6 @@ def benchmarks( input_keys=("logits",), grad_input_keys=("logits",), output_key="loss", - reset_inputs=_reset_logits_grad, ) if TritonConfig.enabled(): z_loss_variants.append(Variant(name="fast_llm_triton", fwd=_z_loss_triton_fwd, fwd_bwd=_z_loss_triton_fwd_bwd)) diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py index 4b7353d12..ccfc15a58 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/bench_grpo_loss.py @@ -66,10 +66,6 @@ def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Te return per_token_loss.mean() -def _reset_logits_grad(inputs: dict) -> None: - inputs["logits"].grad = None - - def _triton_fwd(inputs: dict) -> dict: loss, _, _ = triton_grpo_loss_forward_backward( inputs["logits"], @@ -106,7 +102,6 @@ def benchmarks( input_keys=("logits", "labels", "advantages", "old_log_probs"), grad_input_keys=("logits",), output_key="loss", - reset_inputs=_reset_logits_grad, ) if TritonConfig.enabled(): variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py index 0106d0521..1ce0c4466 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/bench_mlp_activation.py @@ -9,7 +9,7 @@ triton_mlp_activation_forward, ) from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants # (tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated. 
_SHAPES = [ @@ -77,7 +77,14 @@ def benchmarks( grad_output_key="grad_output", ) if TritonConfig.enabled(): - variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd)) + variants.append( + Variant( + name="fast_llm_triton", + fwd=_triton_fwd, + fwd_bwd=_triton_fwd_bwd, + reset_inputs=make_grad_reset(("input",)), + ) + ) return [ ( "mlp_activation (gated silu)", diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py index 636e428a7..149012ab2 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/bench_normalization.py @@ -15,7 +15,7 @@ fused_normalization_available, ) from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants # (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory # budget across model widths; hidden swept from 1K to 16K. @@ -169,7 +169,12 @@ def benchmarks( ) if TritonConfig.enabled(): layer_norm_variants.append( - Variant(name="fast_llm_triton", fwd=_layer_norm_triton_fwd, fwd_bwd=_layer_norm_triton_fwd_bwd) + Variant( + name="fast_llm_triton", + fwd=_layer_norm_triton_fwd, + fwd_bwd=_layer_norm_triton_fwd_bwd, + reset_inputs=make_grad_reset(("input", "weight", "bias")), + ) ) rms_norm_variants = standard_fwd_bwd_pytorch_variants( _rms_norm_eager, @@ -180,7 +185,12 @@ def benchmarks( ) if TritonConfig.enabled(): rms_norm_variants.append( - Variant(name="fast_llm_triton", fwd=_rms_norm_triton_fwd, fwd_bwd=_rms_norm_triton_fwd_bwd) + Variant( + name="fast_llm_triton", + fwd=_rms_norm_triton_fwd, + fwd_bwd=_rms_norm_triton_fwd_bwd, + reset_inputs=make_grad_reset(("input", "weight")), + ) ) return [ ( diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py index 62c165e3f..2da3773ae 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/bench_sparse_copy.py @@ -10,7 +10,7 @@ get_sparse_map, ) from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, hidden_size) _SHAPES = [ @@ -160,6 +160,7 @@ def benchmarks( fwd=_dispatch_triton_fwd, fwd_bwd=_dispatch_triton_fwd_bwd, output_postprocess=_dispatch_postprocess, + reset_inputs=make_grad_reset(("dense",)), ) ) combine_variants = standard_fwd_bwd_pytorch_variants( @@ -175,6 +176,7 @@ def benchmarks( fwd=_combine_triton_fwd, fwd_bwd=_combine_triton_fwd_bwd, output_postprocess=_combine_postprocess, + reset_inputs=make_grad_reset(("sparse", "scores")), ) ) return [ diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py index 1e05eb038..2bfd4863a 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/bench_sparse_linear.py @@ -6,7 +6,7 @@ from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants # (tokens, top_k, num_experts, 
hidden, ffn_per_expert) _SHAPES = [ @@ -191,6 +191,7 @@ def benchmarks( fwd=_output_sparse_triton_fwd, fwd_bwd=_output_sparse_triton_fwd_bwd, output_postprocess=_mask_padded_rows, + reset_inputs=make_grad_reset(("lhs", "rhs")), ) ) input_inner_sparse_variants = standard_fwd_bwd_pytorch_variants( @@ -208,6 +209,7 @@ def benchmarks( fwd=_input_inner_sparse_triton_fwd, fwd_bwd=_input_inner_sparse_triton_fwd_bwd, output_postprocess=_mask_padded_rows, + reset_inputs=make_grad_reset(("lhs", "rhs")), ) ) return [ diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py index f18ee870a..a7c63af95 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/utils.py @@ -149,6 +149,23 @@ def standard_fwd_variants( return variants +def make_grad_reset(keys: tuple[str, ...]) -> Callable[[Inputs], None]: + """Reset autograd `.grad` to None for the given input keys between reps. + `.backward()` accumulates into `.grad` on rep 2+, biasing fwd_bwd timing + via an extra read+write of the full grad tensor. Also resets + `param_grad_is_zero=True` on tensors with a `grad_buffer` (Fast-LLM + convention) so the next backward writes fresh instead of accumulating.""" + + def reset(inputs: Inputs) -> None: + for key in keys: + tensor = inputs[key] + tensor.grad = None + if hasattr(tensor, "grad_buffer"): + tensor.param_grad_is_zero = True + + return reset + + def standard_fwd_bwd_pytorch_variants( eager_function: Callable, input_keys: tuple[str, ...], @@ -163,7 +180,10 @@ def standard_fwd_bwd_pytorch_variants( ) -> list[Variant]: """fp32_reference + + pytorch_compiled + [pytorch_compiled_max] + `extra_functions`. Triton variants are appended by the caller (so each - bench file owns its triton wiring explicitly).""" + bench file owns its triton wiring explicitly). `reset_inputs` defaults to + clearing `.grad` on `grad_input_keys` between reps.""" + if reset_inputs is None: + reset_inputs = make_grad_reset(grad_input_keys) common = { "input_keys": input_keys, "grad_input_keys": grad_input_keys, From ec3e08bf41d1e6c383f194fd81ead73f97e564a8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 09:37:16 -0400 Subject: [PATCH 38/41] Validate bench skip lists; document reference output contract MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tests/tools/test_triton_benchmark.py: assert at module load that every entry in _INTERPRETER_SKIP and _SKIP_VARIANTS matches a real benchmark or variant name. Without this, a rename silently turns a skip into a no-op (and the test falls back to running the kernel that's known to hit a Triton interpreter bug). tools/benchmark/runner.py: add a comment to _collect_reference_outputs noting that the reference is expected to natively produce values matching what output_postprocess masks on candidates — sparse benches honor this by zeroing padded/phantom rows in their loop reference. Makes the load-bearing implicit contract visible at the call site. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/tools/test_triton_benchmark.py | 11 +++++++++++ tools/benchmark/runner.py | 5 +++++ 2 files changed, 16 insertions(+) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index cd2dd3db3..e27c9c188 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -67,6 +67,17 @@ def _build_params() -> list: _PARAMS = _build_params() +# Guard against silent drift if a benchmark or variant is renamed: every entry +# in _INTERPRETER_SKIP / _SKIP_VARIANTS must match at least one real name. +_actual_benchmark_names = {p.id for p in _PARAMS} +assert ( + _INTERPRETER_SKIP <= _actual_benchmark_names +), f"_INTERPRETER_SKIP entries don't match any benchmark: {_INTERPRETER_SKIP - _actual_benchmark_names}" +_actual_variant_names = {v.name for p in _PARAMS for v in p.values[2]} +assert ( + _SKIP_VARIANTS <= _actual_variant_names +), f"_SKIP_VARIANTS entries don't match any variant: {_SKIP_VARIANTS - _actual_variant_names}" + @pytest.fixture(autouse=True) def _patch_benchmark_env(monkeypatch): diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py index de4d94a24..efc37e3c9 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/runner.py @@ -378,6 +378,11 @@ def _collect_reference_outputs( variant: Variant, case: Case, ) -> dict[str, dict[str, torch.Tensor]]: + # Reference outputs are taken raw — output_postprocess is only applied to + # candidate variants. The reference is therefore expected to natively + # produce zeros (or whatever value) in regions that output_postprocess + # masks, so the comparison is symmetric. Sparse benches honor this by + # zeroing padded/phantom rows in their loop reference. out: dict[str, dict[str, torch.Tensor]] = {} if variant.fwd is not None: out["fwd"] = _as_output_dict(variant.fwd(_seeded_inputs(case))) From 70e41ff2701700d0d15f45e21f02a448dae68f74 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 09:50:55 -0400 Subject: [PATCH 39/41] Move triton kernel benchmarks to tools/benchmark/triton_kernels Namespace the suite under tools.benchmark.triton_kernels so future benchmark suites (for non-Triton subsystems) can sit alongside it under tools/benchmark/. All bench_*.py, runner.py, utils.py, gpu_specs.py, and __main__.py move into the new subpackage; tools/benchmark/__init__.py remains as the umbrella package marker. CLI entry point becomes `python -m tools.benchmark.triton_kernels`. 
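For example, on a CUDA node (kernel names and flags are unchanged, only the module path moves):

    python -m tools.benchmark.triton_kernels normalization rotary --dtypes bf16 float32 -v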
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/tools/test_triton_benchmark.py | 6 +++--- tools/benchmark/triton_kernels/__init__.py | 0 tools/benchmark/{ => triton_kernels}/__main__.py | 8 ++++---- .../benchmark/{ => triton_kernels}/bench_entropy_loss.py | 4 ++-- tools/benchmark/{ => triton_kernels}/bench_grpo_loss.py | 4 ++-- .../{ => triton_kernels}/bench_mlp_activation.py | 9 +++++++-- .../{ => triton_kernels}/bench_normalization.py | 9 +++++++-- tools/benchmark/{ => triton_kernels}/bench_pointwise.py | 4 ++-- tools/benchmark/{ => triton_kernels}/bench_rotary.py | 4 ++-- .../benchmark/{ => triton_kernels}/bench_sparse_copy.py | 9 +++++++-- .../{ => triton_kernels}/bench_sparse_linear.py | 9 +++++++-- tools/benchmark/{ => triton_kernels}/gpu_specs.py | 0 tools/benchmark/{ => triton_kernels}/runner.py | 2 +- tools/benchmark/{ => triton_kernels}/utils.py | 2 +- 14 files changed, 45 insertions(+), 25 deletions(-) create mode 100644 tools/benchmark/triton_kernels/__init__.py rename tools/benchmark/{ => triton_kernels}/__main__.py (92%) rename tools/benchmark/{ => triton_kernels}/bench_entropy_loss.py (97%) rename tools/benchmark/{ => triton_kernels}/bench_grpo_loss.py (95%) rename tools/benchmark/{ => triton_kernels}/bench_mlp_activation.py (92%) rename tools/benchmark/{ => triton_kernels}/bench_normalization.py (96%) rename tools/benchmark/{ => triton_kernels}/bench_pointwise.py (95%) rename tools/benchmark/{ => triton_kernels}/bench_rotary.py (96%) rename tools/benchmark/{ => triton_kernels}/bench_sparse_copy.py (97%) rename tools/benchmark/{ => triton_kernels}/bench_sparse_linear.py (97%) rename tools/benchmark/{ => triton_kernels}/gpu_specs.py (100%) rename tools/benchmark/{ => triton_kernels}/runner.py (99%) rename tools/benchmark/{ => triton_kernels}/utils.py (98%) diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index e27c9c188..910c2910c 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -18,10 +18,10 @@ import pytest import torch -import tools.benchmark.runner as _bench_runner +import tools.benchmark.triton_kernels.runner as _bench_runner from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import triton_interpret -from tools.benchmark import ( +from tools.benchmark.triton_kernels import ( bench_entropy_loss, bench_grpo_loss, bench_mlp_activation, @@ -31,7 +31,7 @@ bench_sparse_copy, bench_sparse_linear, ) -from tools.benchmark.runner import run_benchmark +from tools.benchmark.triton_kernels.runner import run_benchmark _DTYPES = (torch.float32,) diff --git a/tools/benchmark/triton_kernels/__init__.py b/tools/benchmark/triton_kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/benchmark/__main__.py b/tools/benchmark/triton_kernels/__main__.py similarity index 92% rename from tools/benchmark/__main__.py rename to tools/benchmark/triton_kernels/__main__.py index 102d72e63..e6ca0bb3b 100644 --- a/tools/benchmark/__main__.py +++ b/tools/benchmark/triton_kernels/__main__.py @@ -2,7 +2,7 @@ CLI entry point for the Fast-LLM Triton kernel benchmarking suite. Usage: - python -m tools.benchmark + python -m tools.benchmark.triton_kernels Available kernels are discovered dynamically from `bench_*.py` files in this package. Each such module must expose a `run(verbose: bool = False)` callable. @@ -19,7 +19,7 @@ # to eager. Bump it before any `torch.compile`-decorated code runs. 
import torch._dynamo -import tools.benchmark as _pkg +import tools.benchmark.triton_kernels as _pkg from fast_llm.engine.config_utils.data_type import DataType torch._dynamo.config.cache_size_limit = 64 @@ -35,14 +35,14 @@ def _list_benchmarks() -> dict[str, str]: names = {} for info in pkgutil.iter_modules(_pkg.__path__): if info.name.startswith("bench_"): - names[info.name.removeprefix("bench_")] = f"tools.benchmark.{info.name}" + names[info.name.removeprefix("bench_")] = f"tools.benchmark.triton_kernels.{info.name}" return names def main() -> None: benches = _list_benchmarks() parser = argparse.ArgumentParser( - prog="python -m tools.benchmark", + prog="python -m tools.benchmark.triton_kernels", description="Benchmark Fast-LLM Triton kernels against PyTorch alternatives.", ) parser.add_argument( diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/triton_kernels/bench_entropy_loss.py similarity index 97% rename from tools/benchmark/bench_entropy_loss.py rename to tools/benchmark/triton_kernels/bench_entropy_loss.py index 4a5100915..4e6c4073a 100644 --- a/tools/benchmark/bench_entropy_loss.py +++ b/tools/benchmark/triton_kernels/bench_entropy_loss.py @@ -6,8 +6,8 @@ from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants # (tokens, vocab_size) _SHAPES = [ diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/triton_kernels/bench_grpo_loss.py similarity index 95% rename from tools/benchmark/bench_grpo_loss.py rename to tools/benchmark/triton_kernels/bench_grpo_loss.py index ccfc15a58..ea998e050 100644 --- a/tools/benchmark/bench_grpo_loss.py +++ b/tools/benchmark/triton_kernels/bench_grpo_loss.py @@ -4,8 +4,8 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants _SHAPES = [ (4096, 32768), diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/triton_kernels/bench_mlp_activation.py similarity index 92% rename from tools/benchmark/bench_mlp_activation.py rename to tools/benchmark/triton_kernels/bench_mlp_activation.py index 1ce0c4466..52302c256 100644 --- a/tools/benchmark/bench_mlp_activation.py +++ b/tools/benchmark/triton_kernels/bench_mlp_activation.py @@ -8,8 +8,13 @@ triton_mlp_activation_autograd, triton_mlp_activation_forward, ) -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import ( + bench_main, + dtype_short, + make_grad_reset, + standard_fwd_bwd_pytorch_variants, +) # 
(tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated. _SHAPES = [ diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/triton_kernels/bench_normalization.py similarity index 96% rename from tools/benchmark/bench_normalization.py rename to tools/benchmark/triton_kernels/bench_normalization.py index 149012ab2..6368500fe 100644 --- a/tools/benchmark/bench_normalization.py +++ b/tools/benchmark/triton_kernels/bench_normalization.py @@ -14,8 +14,13 @@ fast_normalization_available, fused_normalization_available, ) -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import ( + bench_main, + dtype_short, + make_grad_reset, + standard_fwd_bwd_pytorch_variants, +) # (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory # budget across model widths; hidden swept from 1K to 16K. diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/triton_kernels/bench_pointwise.py similarity index 95% rename from tools/benchmark/bench_pointwise.py rename to tools/benchmark/triton_kernels/bench_pointwise.py index 47cd23812..682eea265 100644 --- a/tools/benchmark/bench_pointwise.py +++ b/tools/benchmark/triton_kernels/bench_pointwise.py @@ -4,8 +4,8 @@ import torch from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill -from tools.benchmark.runner import Case, Inputs -from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs +from tools.benchmark.triton_kernels.utils import bench_main, dtype_short, standard_fwd_variants # 4× steps so L2 → HBM and saturated-HBM regimes are visible. 
_SIZES_NUMEL = [ diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/triton_kernels/bench_rotary.py similarity index 96% rename from tools/benchmark/bench_rotary.py rename to tools/benchmark/triton_kernels/bench_rotary.py index 99f8651f8..c029e828c 100644 --- a/tools/benchmark/bench_rotary.py +++ b/tools/benchmark/triton_kernels/bench_rotary.py @@ -7,8 +7,8 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.rotary import triton_rotary_ -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import bench_main, dtype_short # (tokens, num_heads, head_size) — tokens = batch * seq_len _SHAPES = [ diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/triton_kernels/bench_sparse_copy.py similarity index 97% rename from tools/benchmark/bench_sparse_copy.py rename to tools/benchmark/triton_kernels/bench_sparse_copy.py index 2da3773ae..f70c76cbd 100644 --- a/tools/benchmark/bench_sparse_copy.py +++ b/tools/benchmark/triton_kernels/bench_sparse_copy.py @@ -9,8 +9,13 @@ copy_sparse_to_dense_autograd, get_sparse_map, ) -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import ( + bench_main, + dtype_short, + make_grad_reset, + standard_fwd_bwd_pytorch_variants, +) # (tokens, top_k, num_experts, hidden_size) _SHAPES = [ diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/triton_kernels/bench_sparse_linear.py similarity index 97% rename from tools/benchmark/bench_sparse_linear.py rename to tools/benchmark/triton_kernels/bench_sparse_linear.py index 2bfd4863a..011d035e5 100644 --- a/tools/benchmark/bench_sparse_linear.py +++ b/tools/benchmark/triton_kernels/bench_sparse_linear.py @@ -5,8 +5,13 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear -from tools.benchmark.runner import Case, Inputs, Variant -from tools.benchmark.utils import bench_main, dtype_short, make_grad_reset, standard_fwd_bwd_pytorch_variants +from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant +from tools.benchmark.triton_kernels.utils import ( + bench_main, + dtype_short, + make_grad_reset, + standard_fwd_bwd_pytorch_variants, +) # (tokens, top_k, num_experts, hidden, ffn_per_expert) _SHAPES = [ diff --git a/tools/benchmark/gpu_specs.py b/tools/benchmark/triton_kernels/gpu_specs.py similarity index 100% rename from tools/benchmark/gpu_specs.py rename to tools/benchmark/triton_kernels/gpu_specs.py diff --git a/tools/benchmark/runner.py b/tools/benchmark/triton_kernels/runner.py similarity index 99% rename from tools/benchmark/runner.py rename to tools/benchmark/triton_kernels/runner.py index efc37e3c9..c0b45f91b 100644 --- a/tools/benchmark/runner.py +++ b/tools/benchmark/triton_kernels/runner.py @@ -19,7 +19,7 @@ import torch from fast_llm.utils import header -from tools.benchmark.gpu_specs import GpuSpec, detect_gpu_spec +from tools.benchmark.triton_kernels.gpu_specs import GpuSpec, detect_gpu_spec # Before each timed CUDA-graph-backed call we must mark a new step so 
the graph # system knows the previous rep's output buffers are no longer live. Without diff --git a/tools/benchmark/utils.py b/tools/benchmark/triton_kernels/utils.py similarity index 98% rename from tools/benchmark/utils.py rename to tools/benchmark/triton_kernels/utils.py index a7c63af95..8887c4953 100644 --- a/tools/benchmark/utils.py +++ b/tools/benchmark/triton_kernels/utils.py @@ -6,7 +6,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from tools.benchmark.runner import Inputs, Variant, run_benchmark +from tools.benchmark.triton_kernels.runner import Inputs, Variant, run_benchmark DEFAULT_DTYPES: tuple[torch.dtype, ...] = (torch.bfloat16,) From 1f23bce091d097e5513599afcb5369952af674d4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 09:56:45 -0400 Subject: [PATCH 40/41] Remove ad-hoc inspect_rotary_compile.py debug script Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/inspect_rotary_compile.py | 60 --------------------------------- 1 file changed, 60 deletions(-) delete mode 100644 tools/inspect_rotary_compile.py diff --git a/tools/inspect_rotary_compile.py b/tools/inspect_rotary_compile.py deleted file mode 100644 index eac1624a2..000000000 --- a/tools/inspect_rotary_compile.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Dump the Triton kernel that torch.compile generates for the rotary embedding, -so we can compare it to our hand-written fast_llm kernel. - -Run on a GPU node: - python tools/inspect_rotary_compile.py -Output lands in /tmp/torchinductor_*/ (one subdir per compiled function). -This script also prints the path and first 200 lines of each .py file found. -""" - -import os -from pathlib import Path - -import torch -import torch._inductor.config as inductor_config - -# Route torch.compile output to a known directory. -_OUT = Path("/tmp/torchinductor_rotary_inspect") -_OUT.mkdir(parents=True, exist_ok=True) -os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(_OUT) - - -inductor_config.debug = True # writes generated Triton .py files alongside the cache - -tokens, num_heads, head_size = 4096, 32, 128 -rotary_dim = head_size // 2 -dtype = torch.bfloat16 -device = "cuda" - -input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device) -frequencies = torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device) - - -def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: - rotary_dim = frequencies.shape[-1] // 2 - freq_re = frequencies[:, :rotary_dim].unsqueeze(1) - freq_im = frequencies[:, rotary_dim:].unsqueeze(1) - x_re, x_im = input_.chunk(2, dim=-1) - out_re = x_re * freq_re - x_im * freq_im - out_im = x_im * freq_re + x_re * freq_im - return torch.cat([out_re, out_im], dim=-1) - - -compiled = torch.compile(_rotary_eager, mode="default", dynamic=False) - -# Trigger compilation. -out = compiled(input_, frequencies) -torch.cuda.synchronize() -print(f"Output shape: {out.shape}, dtype: {out.dtype}") -print(f"\nInductor cache / debug output dir: {_OUT}") - -# Find and print the generated Triton kernel files. -for path in sorted(_OUT.rglob("*.py")): - print(f"\n{'='*80}") - print(f"FILE: {path}") - print("=" * 80) - lines = path.read_text().splitlines(keepends=True) - print("".join(lines[:300])) - if len(lines) > 300: - print(f"... 
({len(lines) - 300} more lines)") From 6741db99ed0f09394fde11cace789b886f533777 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 May 2026 10:49:53 -0400 Subject: [PATCH 41/41] Address review items 1-15: style and type cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mostly cosmetic — no behavior changes. Highlights: - Replace the PytorchVariant / Fp32ReferenceVariant / FwdOnlyPytorchVariant / Fp32FwdOnlyReferenceVariant dataclass hierarchy with two factory functions (`pytorch_variant`, `fwd_only_variant`) returning plain Variant instances. Removes the unusual __post_init__ field-shadowing pattern. - Convert Case base attributes to @property (required ones raise NotImplementedError); add @dataclass to all Case subclasses for explicit __init__ generation. - Re-privatize _fused_/_fast_normalization_available in the layer module; the benchmark probes Apex availability locally instead of relying on the layer module's public surface. - Switch typing-import style to `import typing` consistently (runner, utils). - `device: torch.device` instead of `device: str` throughout make_inputs. - Boolean kwargs on triton_normalization_autograd; replace bare assert with Assert.custom in tests; document _WarmupKey set typing; rename `header` shadow in runner.py; misc docstrings. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../common/normalization/normalization.py | 20 +- tests/tools/test_triton_benchmark.py | 9 +- .../triton_kernels/bench_entropy_loss.py | 16 +- .../triton_kernels/bench_grpo_loss.py | 9 +- .../triton_kernels/bench_mlp_activation.py | 5 +- .../triton_kernels/bench_normalization.py | 66 +++-- .../triton_kernels/bench_pointwise.py | 13 +- .../benchmark/triton_kernels/bench_rotary.py | 2 +- .../triton_kernels/bench_sparse_copy.py | 11 +- .../triton_kernels/bench_sparse_linear.py | 19 +- tools/benchmark/triton_kernels/runner.py | 86 ++++--- tools/benchmark/triton_kernels/utils.py | 234 ++++++++---------- 12 files changed, 270 insertions(+), 220 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 4f92223d2..2858b9370 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -22,16 +22,16 @@ try: import fused_layer_norm_cuda # noqa - fused_normalization_available = torch.cuda.is_available() + _fused_normalization_available = torch.cuda.is_available() except ImportError: - fused_normalization_available = False + _fused_normalization_available = False try: import fast_layer_norm # noqa - fast_normalization_available = torch.cuda.is_available() + _fast_normalization_available = torch.cuda.is_available() except ImportError: - fast_normalization_available = False + _fast_normalization_available = False try: @@ -79,7 +79,7 @@ class FastLayerNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float ) -> torch.Tensor: # noqa - assert fast_normalization_available + assert _fast_normalization_available Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES) output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps) ctx.save_for_backward(output, weight, bias, inv_var) @@ -110,7 +110,7 @@ class FusedLayerNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float ) -> torch.Tensor: # 
noqa - assert fused_normalization_available + assert _fused_normalization_available ctx.eps = eps ctx.normalized_shape = normalized_shape output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps) @@ -136,7 +136,7 @@ class FusedRMSNorm(torch.autograd.Function): def forward( ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float ) -> torch.Tensor: # noqa - assert fused_normalization_available + assert _fused_normalization_available ctx.eps = eps ctx.normalized_shape = normalized_shape output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps) @@ -185,7 +185,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | implementation = self._config.implementation if implementation == NormalizationImplementation.auto: if ( - fast_normalization_available + _fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._config.zero_centered ): @@ -193,7 +193,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton - elif fused_normalization_available: + elif _fused_normalization_available: log_main_rank("Fast layer norm unavailable, using backup fused implementation.") implementation = NormalizationImplementation.fused else: @@ -261,7 +261,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | if implementation == NormalizationImplementation.auto: if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: implementation = NormalizationImplementation.triton - elif fused_normalization_available: + elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") implementation = NormalizationImplementation.fused else: diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py index 910c2910c..e35266a1c 100644 --- a/tests/tools/test_triton_benchmark.py +++ b/tests/tools/test_triton_benchmark.py @@ -21,6 +21,7 @@ import tools.benchmark.triton_kernels.runner as _bench_runner from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton import triton_interpret +from fast_llm.utils import Assert from tools.benchmark.triton_kernels import ( bench_entropy_loss, bench_grpo_loss, @@ -70,13 +71,9 @@ def _build_params() -> list: # Guard against silent drift if a benchmark or variant is renamed: every entry # in _INTERPRETER_SKIP / _SKIP_VARIANTS must match at least one real name. 
_actual_benchmark_names = {p.id for p in _PARAMS} -assert ( - _INTERPRETER_SKIP <= _actual_benchmark_names -), f"_INTERPRETER_SKIP entries don't match any benchmark: {_INTERPRETER_SKIP - _actual_benchmark_names}" _actual_variant_names = {v.name for p in _PARAMS for v in p.values[2]} -assert ( - _SKIP_VARIANTS <= _actual_variant_names -), f"_SKIP_VARIANTS entries don't match any variant: {_SKIP_VARIANTS - _actual_variant_names}" +Assert.custom(set.issubset, _INTERPRETER_SKIP, _actual_benchmark_names) +Assert.custom(set.issubset, _SKIP_VARIANTS, _actual_variant_names) @pytest.fixture(autouse=True) diff --git a/tools/benchmark/triton_kernels/bench_entropy_loss.py b/tools/benchmark/triton_kernels/bench_entropy_loss.py index 4e6c4073a..851995e6c 100644 --- a/tools/benchmark/triton_kernels/bench_entropy_loss.py +++ b/tools/benchmark/triton_kernels/bench_entropy_loss.py @@ -1,4 +1,8 @@ +"""Cross-entropy and z-loss kernels: label-target CE, logit-target CE, reverse +KL, and z-loss (logsumexp²).""" + import dataclasses +import typing import torch import torch.nn.functional as F @@ -37,26 +41,28 @@ def expected_flops(self) -> int: return 4 * self.tokens * self.vocab +@dataclasses.dataclass class EntropyLabelCase(_EntropyCase): @property def expected_bytes(self) -> int: # 2× logits + small labels traffic. return 2 * self.tokens * self.vocab * self.dtype.itemsize + self.tokens * 4 - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device), } +@dataclasses.dataclass class EntropyDistCase(_EntropyCase): @property def expected_bytes(self) -> int: # 2× logits + 1× target_logits. return 3 * self.tokens * self.vocab * self.dtype.itemsize - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), "target_logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device), @@ -80,7 +86,11 @@ def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor: return (log_z * log_z).mean() -def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list[Variant]: +def _entropy_variants( + eager_function: typing.Callable, + input_keys: tuple[str, ...], + triton_kwargs: dict | None = None, +) -> list[Variant]: """Variants for the 3 entropy_loss kernels that share `triton_entropy_loss_forward_backward`.""" target_key = input_keys[1] triton_kwargs = triton_kwargs or {} diff --git a/tools/benchmark/triton_kernels/bench_grpo_loss.py b/tools/benchmark/triton_kernels/bench_grpo_loss.py index ea998e050..30e6901ab 100644 --- a/tools/benchmark/triton_kernels/bench_grpo_loss.py +++ b/tools/benchmark/triton_kernels/bench_grpo_loss.py @@ -1,3 +1,6 @@ +"""GRPO (Group Relative Policy Optimization) loss: clipped policy ratio with +fused softmax + gather + clipped advantage in a single kernel.""" + import dataclasses import torch @@ -41,7 +44,7 @@ def expected_flops(self) -> int: # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element. 
return 14 * self.tokens * self.vocab - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True), "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device), @@ -50,7 +53,9 @@ def make_inputs(self, device: str) -> Inputs: } -def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor): +def _grpo_eager( + logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor +) -> torch.Tensor: log_probs = logits.float().log_softmax(-1) # clamp + labels>=0 guards mirror production code that handles ignore_index=-100; # labels here are always non-negative (randint), so the masks are dead in this benchmark. diff --git a/tools/benchmark/triton_kernels/bench_mlp_activation.py b/tools/benchmark/triton_kernels/bench_mlp_activation.py index 52302c256..b18dbcb21 100644 --- a/tools/benchmark/triton_kernels/bench_mlp_activation.py +++ b/tools/benchmark/triton_kernels/bench_mlp_activation.py @@ -1,3 +1,6 @@ +"""Gated MLP activation (e.g. SiLU): splits the input into (linear, gate), +applies the activation to gate, multiplies them, in a single kernel.""" + import dataclasses import torch @@ -50,7 +53,7 @@ def expected_flops(self) -> int: # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element. return 14 * self.tokens * self.ffn_dim - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "input": torch.randn(self.tokens, 2 * self.ffn_dim, dtype=self.dtype, device=device, requires_grad=True), "grad_output": torch.randn(self.tokens, self.ffn_dim, dtype=self.dtype, device=device), diff --git a/tools/benchmark/triton_kernels/bench_normalization.py b/tools/benchmark/triton_kernels/bench_normalization.py index 6368500fe..70e63a039 100644 --- a/tools/benchmark/triton_kernels/bench_normalization.py +++ b/tools/benchmark/triton_kernels/bench_normalization.py @@ -7,13 +7,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.normalization.normalization import ( - FastLayerNorm, - FusedLayerNorm, - FusedRMSNorm, - fast_normalization_available, - fused_normalization_available, -) +from fast_llm.layers.common.normalization.normalization import FastLayerNorm, FusedLayerNorm, FusedRMSNorm from tools.benchmark.triton_kernels.runner import Case, Inputs, Variant from tools.benchmark.triton_kernels.utils import ( bench_main, @@ -22,6 +16,20 @@ standard_fwd_bwd_pytorch_variants, ) +try: + import fused_layer_norm_cuda # noqa: F401 + + _fused_normalization_available = torch.cuda.is_available() +except ImportError: + _fused_normalization_available = False + +try: + import fast_layer_norm # noqa: F401 + + _fast_normalization_available = torch.cuda.is_available() +except ImportError: + _fast_normalization_available = False + # (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory # budget across model widths; hidden swept from 1K to 16K. _SHAPES = [ @@ -55,6 +63,7 @@ def compute_dtype(self) -> torch.dtype: return self.dtype +@dataclasses.dataclass class LayerNormCase(_NormalizationCase): @property def expected_bytes(self) -> int: @@ -66,7 +75,7 @@ def expected_flops(self) -> int: # fwd ≈ 7 FLOPs/elem (mean, variance, normalize, scale+shift); bwd ≈ 2× fwd. 
return 21 * self.rows * self.cols - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True), "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)), @@ -75,6 +84,7 @@ def make_inputs(self, device: str) -> Inputs: } +@dataclasses.dataclass class RmsNormCase(_NormalizationCase): @property def expected_bytes(self) -> int: @@ -86,7 +96,7 @@ def expected_flops(self) -> int: # No mean subtraction or bias compared to LayerNorm. return 15 * self.rows * self.cols - def make_inputs(self, device: str) -> Inputs: + def make_inputs(self, device: torch.device) -> Inputs: return { "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True), "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)), @@ -94,11 +104,11 @@ def make_inputs(self, device: str) -> Inputs: } -def _layer_norm_eager(input_, weight, bias): +def _layer_norm_eager(input_: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: return torch.layer_norm(input_, weight.shape, weight, bias, _EPS) -def _rms_norm_eager(input_, weight): +def _rms_norm_eager(input_: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_, weight.shape, weight, _EPS) @@ -106,14 +116,24 @@ def _param_grad(param: torch.Tensor) -> torch.Tensor: return param.grad if param.grad is not None else param.grad_buffer +def _layer_norm_triton(inputs: dict) -> torch.Tensor: + return triton_normalization_autograd( + inputs["input"], inputs["weight"], inputs["bias"], eps=_EPS, training=True, zero_centered=False + ) + + +def _rms_norm_triton(inputs: dict) -> torch.Tensor: + return triton_normalization_autograd( + inputs["input"], inputs["weight"], None, eps=_EPS, training=True, zero_centered=False + ) + + def _layer_norm_triton_fwd(inputs: dict) -> dict: - return { - "output": triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False) - } + return {"output": _layer_norm_triton(inputs)} def _layer_norm_triton_fwd_bwd(inputs: dict) -> dict: - output = triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False) + output = _layer_norm_triton(inputs) output.backward(inputs["grad_output"]) return { "output": output.detach(), @@ -124,11 +144,11 @@ def _layer_norm_triton_fwd_bwd(inputs: dict) -> dict: def _rms_norm_triton_fwd(inputs: dict) -> dict: - return {"output": triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False)} + return {"output": _rms_norm_triton(inputs)} def _rms_norm_triton_fwd_bwd(inputs: dict) -> dict: - output = triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False) + output = _rms_norm_triton(inputs) output.backward(inputs["grad_output"]) return { "output": output.detach(), @@ -137,26 +157,26 @@ def _rms_norm_triton_fwd_bwd(inputs: dict) -> dict: } -def _layer_norm_apex_fused(input_, weight, bias): +def _layer_norm_apex_fused(input_: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) -def _layer_norm_apex_fast(input_, weight, bias): +def _layer_norm_apex_fast(input_: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: return FastLayerNorm.apply(input_, weight.shape, weight, bias, _EPS) -def 
+def _rms_norm_apex_fused(input_: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS)
 
 
 _LAYER_NORM_EXTRAS: dict = {}
-if fused_normalization_available:
+if _fused_normalization_available:
     _LAYER_NORM_EXTRAS["apex_fused"] = _layer_norm_apex_fused
-if fast_normalization_available:
+if _fast_normalization_available:
     _LAYER_NORM_EXTRAS["apex_fast"] = _layer_norm_apex_fast
 
 _RMS_NORM_EXTRAS: dict = {}
-if fused_normalization_available:
+if _fused_normalization_available:
     _RMS_NORM_EXTRAS["apex_fused"] = _rms_norm_apex_fused
 
diff --git a/tools/benchmark/triton_kernels/bench_pointwise.py b/tools/benchmark/triton_kernels/bench_pointwise.py
index 682eea265..91db26ddc 100644
--- a/tools/benchmark/triton_kernels/bench_pointwise.py
+++ b/tools/benchmark/triton_kernels/bench_pointwise.py
@@ -1,3 +1,7 @@
+"""Memory-bound pointwise ops (copy, fill, add). Sweeps numel from L2-resident
+to HBM-saturated to surface the bandwidth-bound regime where Triton wins or
+loses against PyTorch."""
+
 import dataclasses
 import typing
 
@@ -37,21 +41,24 @@ def expected_bytes(self) -> int:
         return self.bytes_factor * self.numel * self.dtype.itemsize
 
 
+@dataclasses.dataclass
 class CopyCase(_PointwiseCase):
     bytes_factor = 2
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         input_ = torch.randn(self.numel, dtype=self.dtype, device=device)
         return {"input_": input_, "out": torch.empty_like(input_)}
 
 
+@dataclasses.dataclass
 class FillCase(_PointwiseCase):
     bytes_factor = 1
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         return {"input_": torch.empty(self.numel, dtype=self.dtype, device=device), "value": 1.5}
 
 
+@dataclasses.dataclass
 class AddCase(_PointwiseCase):
     bytes_factor = 3
 
@@ -59,7 +66,7 @@ class AddCase(_PointwiseCase):
     def expected_flops(self) -> int:
         return self.numel
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         return {
             "input_": torch.randn(self.numel, dtype=self.dtype, device=device),
             "other": torch.randn(self.numel, dtype=self.dtype, device=device),
diff --git a/tools/benchmark/triton_kernels/bench_rotary.py b/tools/benchmark/triton_kernels/bench_rotary.py
index c029e828c..f2f6a08ed 100644
--- a/tools/benchmark/triton_kernels/bench_rotary.py
+++ b/tools/benchmark/triton_kernels/bench_rotary.py
@@ -46,7 +46,7 @@ def expected_flops(self) -> int:
         # 6 FLOPs per (re, im) element pair: 4 muls + 2 add/sub.
         return 6 * self.tokens * self.num_heads * (self.head_size // 2)
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         rotary_dim = self.head_size // 2
         input_ = torch.randn(self.tokens, self.num_heads, self.head_size, dtype=self.dtype, device=device)
         return {
diff --git a/tools/benchmark/triton_kernels/bench_sparse_copy.py b/tools/benchmark/triton_kernels/bench_sparse_copy.py
index f70c76cbd..dfa0b2e65 100644
--- a/tools/benchmark/triton_kernels/bench_sparse_copy.py
+++ b/tools/benchmark/triton_kernels/bench_sparse_copy.py
@@ -1,3 +1,6 @@
+"""MoE dispatch and combine: scatter dense rows into sparse expert-grouped
+buffers and gather them back with score weighting."""
+
 import dataclasses
 
 import torch
@@ -25,7 +28,7 @@
 ]
 
 
-def _make_phantom_mask(sparse_map: SparseMap, device: str) -> torch.Tensor:
+def _make_phantom_mask(sparse_map: SparseMap, device: torch.device) -> torch.Tensor:
     # True for within-expert padding rows and the static tail past expert_ends[-1];
     # used only in output_postprocess, never in the timed path.
     mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device)
@@ -62,8 +65,9 @@ def expected_bytes(self) -> int:
         return 2 * (1 + self.top_k) * self.tokens * self.hidden * self.dtype.itemsize
 
 
+@dataclasses.dataclass
 class DispatchCase(_SparseCopyCase):
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device)
         sparse_map = get_sparse_map(top_experts, self.num_experts)
         return {
@@ -74,13 +78,14 @@ def make_inputs(self, device: str) -> Inputs:
         }
 
 
+@dataclasses.dataclass
 class CombineCase(_SparseCopyCase):
     @property
     def expected_bytes(self) -> int:
         # Adds scores read/write on top of the dispatch traffic.
         return super().expected_bytes + 4 * self.tokens * self.top_k * self.dtype.itemsize
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device)
         sparse_map = get_sparse_map(top_experts, self.num_experts)
         return {
diff --git a/tools/benchmark/triton_kernels/bench_sparse_linear.py b/tools/benchmark/triton_kernels/bench_sparse_linear.py
index 011d035e5..a5d91d91e 100644
--- a/tools/benchmark/triton_kernels/bench_sparse_linear.py
+++ b/tools/benchmark/triton_kernels/bench_sparse_linear.py
@@ -1,3 +1,7 @@
+"""MoE expert-grouped matmul: layer-1 (output_sparse, up-proj) and layer-2
+(input_inner_sparse, down-proj). Compares the Triton sparse kernel against a
+per-expert pytorch loop reference."""
+
 import dataclasses
 
 import torch
@@ -22,8 +26,11 @@
 
 # Triton autotuning warmup needs to run only once per shape; make_inputs is
 # called many times per case (per variant, per fwd/fwd_bwd/memory pass).
-_output_sparse_warmed_up: set[tuple] = set()
-_input_inner_sparse_warmed_up: set[tuple] = set()
+# Process-local — a fresh interpreter starts with empty caches, which is
+# desired (autotuning state may differ across processes).
+_WarmupKey = tuple[int, int, int, int, int, torch.dtype]
+_output_sparse_warmed_up: set[_WarmupKey] = set()
+_input_inner_sparse_warmed_up: set[_WarmupKey] = set()
 
 
 def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]:
@@ -83,13 +90,14 @@ def expected_flops(self) -> int:
         # 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad).
         return 3 * 2 * self.tokens * self.top_k * self.hidden * self.ffn_per_expert
 
-    def _make_sparse_map(self, device: str) -> SparseMap:
+    def _make_sparse_map(self, device: torch.device) -> SparseMap:
         top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device)
         return get_sparse_map(top_experts, self.num_experts)
 
 
+@dataclasses.dataclass
 class OutputSparseCase(_SparseLinearCase):
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         sparse_map = self._make_sparse_map(device)
         lhs_data = torch.randn(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device)
         rhs_data = torch.randn(self.hidden, self.ffn_per_expert * self.num_experts, dtype=self.dtype, device=device)
@@ -110,8 +118,9 @@ def make_inputs(self, device: str) -> Inputs:
         }
 
 
+@dataclasses.dataclass
 class InputInnerSparseCase(_SparseLinearCase):
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         sparse_map = self._make_sparse_map(device)
         lhs_data = torch.randn(sparse_map.num_rows, self.ffn_per_expert, dtype=self.dtype, device=device)
         rhs_data = torch.randn(self.ffn_per_expert * self.num_experts, self.hidden, dtype=self.dtype, device=device)
diff --git a/tools/benchmark/triton_kernels/runner.py b/tools/benchmark/triton_kernels/runner.py
index c0b45f91b..649a33a7c 100644
--- a/tools/benchmark/triton_kernels/runner.py
+++ b/tools/benchmark/triton_kernels/runner.py
@@ -13,8 +13,7 @@
 import math
 import statistics
 import time
-from collections.abc import Callable
-from typing import Any
+import typing
 
 import torch
 
@@ -25,12 +24,12 @@
 # system knows the previous rep's output buffers are no longer live. Without
 # this, max-autotune compiled functions raise "tensor output of CUDAGraphs has
 # been overwritten by a subsequent run" on the second call.
-_cudagraph_mark_step_begin: Callable[[], None] | None = getattr(
+_cudagraph_mark_step_begin: typing.Callable[[], None] | None = getattr(
     getattr(torch, "compiler", None), "cudagraph_mark_step_begin", None
 )
 
 
-def _guarded(fn: Callable[[], Any]) -> Callable[[], Any]:
+def _guarded(fn: typing.Callable[[], typing.Any]) -> typing.Callable[[], typing.Any]:
     """Wrap fn so cudagraph_mark_step_begin() is called before each invocation.
     This tells the CUDA graph system that previous outputs are no longer live
     and can be overwritten, preventing 'overwritten by subsequent run' errors
@@ -39,15 +38,15 @@ def _guarded(fn: Callable[[], Any]) -> Callable[[], Any]:
     if _cudagraph_mark_step_begin is None:
         return fn
 
-    def _wrapped() -> Any:
+    def _wrapped() -> typing.Any:
         _cudagraph_mark_step_begin()
         return fn()
 
     return _wrapped
 
 
-Inputs = dict[str, Any]
-VariantFn = Callable[[Inputs], Any]
+Inputs = dict[str, typing.Any]
+VariantFn = typing.Callable[[Inputs], typing.Any]
 
 
 @dataclasses.dataclass
@@ -66,11 +65,11 @@ class Variant:
     # timing. Receives {name: tensor} and the full inputs dict; returns the
    # (possibly modified) dict. Use this to mask out don't-care regions so they
     # don't inflate RMS errors (e.g. uninitialized phantom rows in sparse buffers).
-    output_postprocess: Callable[[dict[str, torch.Tensor], Inputs], dict[str, torch.Tensor]] | None = None
+    output_postprocess: typing.Callable[[dict[str, torch.Tensor], Inputs], dict[str, torch.Tensor]] | None = None
     # Called between timing reps (outside the timed region) to restore any
     # input tensors the variant mutates in-place. Use this instead of cloning
     # inside the timed callable so the mutation cost is not measured.
-    reset_inputs: Callable[[Inputs], Any] | None = None
+    reset_inputs: typing.Callable[[Inputs], typing.Any] | None = None
 
 
 class Case:
@@ -78,22 +77,35 @@ class Case:
     the kernel's shape parameters (e.g. rows, cols, dtype) and override `name`
     and `make_inputs`; the throughput properties are optional."""
 
-    # Subclasses must provide these.
-    name: str
-    # Subclasses may override these (defaults skip the corresponding columns).
-    expected_bytes: int | None = None  # bytes read+written; enables GB/s + %BW.
-    expected_flops: int | None = None  # FLOPs performed; enables TFLOP/s + %FLOPs.
-    compute_dtype: torch.dtype | None = None  # dtype of hot inputs, picks peak column.
+    @property
+    def name(self) -> str:
+        raise NotImplementedError()
+
+    # Optional — defaults skip the corresponding columns.
+    @property
+    def expected_bytes(self) -> int | None:
+        """Bytes read+written; enables GB/s + %BW columns."""
+        return None
+
+    @property
+    def expected_flops(self) -> int | None:
+        """FLOPs performed; enables TFLOP/s + %FLOPs columns."""
+        return None
+
+    @property
+    def compute_dtype(self) -> torch.dtype | None:
+        """Dtype of hot inputs; picks the peak column for the %FLOPs computation."""
+        return None
 
-    def make_inputs(self, device: str) -> Inputs:
+    def make_inputs(self, device: torch.device) -> Inputs:
         """Return a fresh dict of input tensors on `device`. Called once per
         variant per mode, after a global seed reset, so every variant sees
         identical inputs."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
-def _device() -> str:
-    return "cuda" if torch.cuda.is_available() else "cpu"
+def _device() -> torch.device:
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def _seeded_inputs(case: Case, seed: int = 0) -> Inputs:
@@ -133,7 +145,7 @@ class VariantResult:
 # --------------------------------------------------------------------------- timing
 
 
-def _make_cache_flusher(size_bytes: int = 256 * 1024 * 1024) -> Callable[[], None]:
+def _make_cache_flusher(size_bytes: int = 256 * 1024 * 1024) -> typing.Callable[[], None]:
     """Allocate a scratch buffer larger than any GPU L2 and zero it between
     reps to invalidate cached values (avoids over-optimistic timings)."""
     if not torch.cuda.is_available():
@@ -147,8 +159,8 @@ def flush() -> None:
 
 
 def bench_fn(
-    fn: Callable[[], Any],
-    reset: Callable[[], None] | None = None,
+    fn: typing.Callable[[], typing.Any],
+    reset: typing.Callable[[], None] | None = None,
     warmup_ms: float = 25.0,
     rep_ms: float = 100.0,
     min_reps: int = 5,
@@ -232,7 +244,7 @@ def bench_fn(
 # --------------------------------------------------------------------------- memory
 
 
-def measure_memory(fn: Callable[[], Any]) -> MemoryStats:
+def measure_memory(fn: typing.Callable[[], typing.Any]) -> MemoryStats:
     """Run `fn` once and capture peak and final device memory.
    Must be called on a fresh GPU state (the caller resets stats before constructing inputs)."""
    if not torch.cuda.is_available():
@@ -266,7 +278,7 @@ def rms_relative_error(candidate: torch.Tensor, reference: torch.Tensor) -> floa
     return diff_rms / max(ref_rms, 1e-30)
 
 
-def _as_output_dict(output: Any) -> dict[str, torch.Tensor]:
+def _as_output_dict(output: typing.Any) -> dict[str, torch.Tensor]:
     """Normalize a variant's output into a {name: tensor} dict for comparison."""
     if isinstance(output, torch.Tensor):
         return {"out": output}
@@ -297,7 +309,7 @@ def _run_one_variant(
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-        def _fwd_once() -> Any:
+        def _fwd_once() -> typing.Any:
             return variant.fwd(inputs)
 
         _guarded_fwd = _guarded(_fwd_once)
@@ -330,7 +342,7 @@ def _fwd_once() -> Any:
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-        def _fwd_bwd_once() -> Any:
+        def _fwd_bwd_once() -> typing.Any:
             return variant.fwd_bwd(inputs)
 
         _guarded_fwd_bwd = _guarded(_fwd_bwd_once)
@@ -369,7 +381,9 @@ def _fwd_bwd_once() -> Any:
                 torch.cuda.empty_cache()
             result.memory = measure_memory(_guarded(lambda: variant.fwd(fresh_inputs)))
             del fresh_inputs
-    except Exception as exc:  # noqa: BLE001
+    except (
+        Exception
+    ) as exc:  # noqa: BLE001 — variant failures are reported in the result column, not propagated, so a single broken kernel doesn't kill the rest of the sweep.
         result.error = f"{type(exc).__name__}: {exc}"
     return result
 
@@ -467,8 +481,8 @@ def _unit_column(
         ((unit_label, unit_scale) for (unit_label, unit_scale) in units if unit_scale == 1.0), units[0]
     )
     scaled = [v * scale if v is not None else None for v in canonical_values]
-    header = f"{prefix} {label}" if prefix else label
-    return header, _format_aligned(scaled)
+    column_header = f"{prefix} {label}" if prefix else label
+    return column_header, _format_aligned(scaled)
 
 
 def _percent_column(values: list[float | None]) -> list[str]:
@@ -531,8 +545,8 @@ def _render_table(
     # variant-name column are merged into one (avoids a redundant title line).
     columns: list[tuple[str, list[str]]] = [(case.name, [r.variant_name for r in results])]
 
-    def _add(header: str, values: list[str]) -> None:
-        columns.append((header, values))
+    def _add(column_header: str, values: list[str]) -> None:
+        columns.append((column_header, values))
 
     if has_fwd:
         _add(*_unit_column("fwd", [r.fwd_timing.median_ms if r.fwd_timing else None for r in results], _TIME_UNITS))
@@ -576,8 +590,7 @@ def _time_for_throughput(r: VariantResult) -> float | None:
         for r in results:
             t_ms = _time_for_throughput(r)
             bandwidths.append(case.expected_bytes / (t_ms / 1000) if t_ms is not None else None)
-        header, values = _unit_column("", bandwidths, _BANDWIDTH_UNITS)
-        _add(header, values)
+        _add(*_unit_column("", bandwidths, _BANDWIDTH_UNITS))
         if gpu_spec is not None:
             peak_bytes_per_s = gpu_spec.peak_bandwidth_gbps * 1e9
             pct = [bw / peak_bytes_per_s if bw is not None else None for bw in bandwidths]
@@ -588,8 +601,7 @@ def _time_for_throughput(r: VariantResult) -> float | None:
         for r in results:
             t_ms = _time_for_throughput(r)
             flop_rates.append(case.expected_flops / (t_ms / 1000) if t_ms is not None else None)
-        header, values = _unit_column("", flop_rates, _FLOPS_UNITS)
-        _add(header, values)
+        _add(*_unit_column("", flop_rates, _FLOPS_UNITS))
         peak_tflops = gpu_spec.peak_tflops(case.compute_dtype) if gpu_spec and case.compute_dtype else None
         if peak_tflops is not None:
             peak_flops_per_s = peak_tflops * 1e12
@@ -609,7 +621,7 @@ def _time_for_throughput(r: VariantResult) -> float | None:
     if not any(r.error for r in results):
         columns.pop()
 
-    widths = [max(len(header), *(len(v) for v in values)) for header, values in columns]
+    widths = [max(len(column_header), *(len(v) for v in values)) for column_header, values in columns]
     separator = "  "
 
     # First column (case name + variant names) is text — left-justify. All other
@@ -639,7 +651,7 @@ def run_benchmark(
     rep_ms: float = 100.0,
     min_reps: int = 5,
     verbose: bool = False,
-    print_fn: Callable[[str], None] = print,
+    print_fn: typing.Callable[[str], None] = print,
 ) -> list[tuple[Case, list[VariantResult]]]:
     """Run all (case, variant) combinations and print one table per case.
diff --git a/tools/benchmark/triton_kernels/utils.py b/tools/benchmark/triton_kernels/utils.py
index 8887c4953..47997e031 100644
--- a/tools/benchmark/triton_kernels/utils.py
+++ b/tools/benchmark/triton_kernels/utils.py
@@ -1,6 +1,7 @@
-import dataclasses
-from collections.abc import Callable
-from typing import Any
+"""Variant builders shared across bench files: pytorch eager / compiled /
+fp32-reference factories and a `make_grad_reset` helper for fwd_bwd resets."""
+
+import typing
 
 import torch
 
@@ -15,7 +16,7 @@ def dtype_short(dtype: torch.dtype) -> str:
     return DataType.from_torch(dtype).short
 
 
-def bench_main(benchmarks_fn: Callable) -> Callable:
+def bench_main(benchmarks_fn: typing.Callable) -> typing.Callable:
     def run(
         verbose: bool = False,
         dtypes: tuple[torch.dtype, ...] | None = None,
@@ -32,149 +33,134 @@ def run(
     return run
 
 
-@dataclasses.dataclass(kw_only=True)
-class PytorchVariant(Variant):
-    """Variant that calls a pytorch function on inputs picked by key. Used for
-    eager, torch.compile, and apex variants — each instance differs in `function`
-    while sharing the dispatch logic."""
-
-    function: Callable
-    input_keys: tuple[str, ...]
-    grad_input_keys: tuple[str, ...] = ()
-    grad_output_key: str | None = None
-    output_key: str = "output"
-
-    def __post_init__(self) -> None:
-        # Wire the inherited `fwd`/`fwd_bwd` callable fields to bound methods
-        # so subclasses can override the methods without touching the fields.
-        self.fwd = self._fwd
-        self.fwd_bwd = self._fwd_bwd
-
-    def _fwd(self, inputs: Inputs) -> dict:
-        return {self.output_key: self.function(*(inputs[k] for k in self.input_keys))}
-
-    def _fwd_bwd(self, inputs: Inputs) -> dict:
-        output = self.function(*(inputs[k] for k in self.input_keys))
-        if self.grad_output_key is None:
-            output.backward()
-        else:
-            output.backward(inputs[self.grad_output_key])
-        result = {self.output_key: output.detach()}
-        for key in self.grad_input_keys:
-            result[f"grad_{key}"] = inputs[key].grad
-        return result
+def make_grad_reset(keys: tuple[str, ...]) -> typing.Callable[[Inputs], None]:
+    """Reset autograd `.grad` to None for the given input keys between reps.
+    `.backward()` accumulates into `.grad` on rep 2+, biasing fwd_bwd timing
+    via an extra read+write of the full grad tensor. Also resets
+    `param_grad_is_zero=True` on tensors with a `grad_buffer` (Fast-LLM
+    convention) so the next backward writes fresh instead of accumulating."""
+    def reset(inputs: Inputs) -> None:
+        for key in keys:
+            tensor = inputs[key]
+            tensor.grad = None
+            if hasattr(tensor, "grad_buffer"):
+                tensor.param_grad_is_zero = True
 
-@dataclasses.dataclass(kw_only=True)
-class Fp32ReferenceVariant(PytorchVariant):
-    """Reference variant: upcasts every floating-point input to fp32 before
-    running the eager pytorch function. Re-attaches `requires_grad=True` on
-    `grad_input_keys` so backward sees a leaf tensor."""
+    return reset
 
-    name: str = "fp32_reference"
-    is_reference: bool = True
 
-    def _fwd(self, inputs: Inputs) -> dict:
-        return super()._fwd(self._to_fp32(inputs))
+def _to_fp32(inputs: Inputs, grad_input_keys: tuple[str, ...]) -> Inputs:
+    """Upcast every floating-point input to fp32. Re-attach `requires_grad=True`
+    on `grad_input_keys` so backward sees a leaf tensor."""
+    result = dict(inputs)
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.is_floating_point():
+            float_value = value.float().detach()
+            result[key] = float_value.requires_grad_(True) if key in grad_input_keys else float_value
+    return result
 
-    def _fwd_bwd(self, inputs: Inputs) -> dict:
-        return super()._fwd_bwd(self._to_fp32(inputs))
 
-    def _to_fp32(self, inputs: Inputs) -> Inputs:
-        result = dict(inputs)
-        for key, value in inputs.items():
-            if isinstance(value, torch.Tensor) and value.is_floating_point():
-                float_value = value.float().detach()
-                result[key] = float_value.requires_grad_(True) if key in self.grad_input_keys else float_value
+def pytorch_variant(
+    name: str,
+    function: typing.Callable,
+    input_keys: tuple[str, ...],
+    *,
+    grad_input_keys: tuple[str, ...] = (),
+    grad_output_key: str | None = None,
+    output_key: str = "output",
+    is_reference: bool = False,
+    convert_to_fp32: bool = False,
+    reset_inputs: typing.Callable[[Inputs], typing.Any] | None = None,
+) -> Variant:
+    """Build a Variant that calls `function(*[inputs[k] for k in input_keys])`.
+    Supports backward when `grad_input_keys` is non-empty; if `convert_to_fp32`
+    is set, all floating-point inputs are upcast to fp32 first (used by the
+    reference variant)."""
+
+    def _prepare(inputs: Inputs) -> Inputs:
+        return _to_fp32(inputs, grad_input_keys) if convert_to_fp32 else inputs
+
+    def fwd(inputs: Inputs) -> dict:
+        prepared = _prepare(inputs)
+        return {output_key: function(*(prepared[k] for k in input_keys))}
+
+    def fwd_bwd(inputs: Inputs) -> dict:
+        prepared = _prepare(inputs)
+        output = function(*(prepared[k] for k in input_keys))
+        if grad_output_key is None:
+            output.backward()
+        else:
+            output.backward(prepared[grad_output_key])
+        result = {output_key: output.detach()}
+        for key in grad_input_keys:
+            result[f"grad_{key}"] = prepared[key].grad
         return result
 
+    return Variant(
+        name=name,
+        fwd=fwd,
+        fwd_bwd=fwd_bwd if grad_input_keys else None,
+        is_reference=is_reference,
+        reset_inputs=reset_inputs,
+    )
 
-@dataclasses.dataclass(kw_only=True)
-class FwdOnlyPytorchVariant(Variant):
-    """Forward-only variant: calls a pytorch function with positional args
-    extracted via `unpack`. Used by bench_pointwise where there's no backward."""
-    function: Callable
-    unpack: Callable[[Inputs], tuple]
-
-    def __post_init__(self) -> None:
-        self.fwd = self._fwd
-
-    def _fwd(self, inputs: Inputs) -> Any:
-        return self.function(*self.unpack(inputs))
-
-
-@dataclasses.dataclass(kw_only=True)
-class Fp32FwdOnlyReferenceVariant(FwdOnlyPytorchVariant):
-    name: str = "fp32_reference"
-    is_reference: bool = True
+def fwd_only_variant(
+    name: str,
+    function: typing.Callable,
+    unpack: typing.Callable[[Inputs], tuple],
+    *,
+    is_reference: bool = False,
+    convert_to_fp32: bool = False,
+) -> Variant:
+    """Build a forward-only Variant. Used by bench_pointwise where there's no
+    backward; `unpack` extracts positional args from the inputs dict."""
+
+    def fwd(inputs: Inputs) -> typing.Any:
+        args = unpack(inputs)
+        if convert_to_fp32:
+            args = tuple(
+                arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg for arg in args
+            )
+        return function(*args)
 
-    def _fwd(self, inputs: Inputs) -> Any:
-        args = tuple(
-            arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg
-            for arg in self.unpack(inputs)
-        )
-        return self.function(*args)
+    return Variant(name=name, fwd=fwd, is_reference=is_reference)
 
 
 def standard_fwd_variants(
-    eager_function: Callable,
-    triton_function: Callable | None,
-    unpack: Callable[[Inputs], tuple],
+    eager_function: typing.Callable,
+    triton_function: typing.Callable | None,
+    unpack: typing.Callable[[Inputs], tuple],
 ) -> list[Variant]:
     """fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max,
     and (if `TritonConfig.enabled()`) fast_llm_triton."""
     variants: list[Variant] = [
-        Fp32FwdOnlyReferenceVariant(function=eager_function, unpack=unpack),
-        FwdOnlyPytorchVariant(name="pytorch_eager", function=eager_function, unpack=unpack),
-        FwdOnlyPytorchVariant(
-            name="pytorch_compiled",
-            function=torch.compile(eager_function, mode="default", dynamic=False),
-            unpack=unpack,
-        ),
-        FwdOnlyPytorchVariant(
-            name="pytorch_compiled_max",
-            function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False),
-            unpack=unpack,
+        fwd_only_variant("fp32_reference", eager_function, unpack, is_reference=True, convert_to_fp32=True),
+        fwd_only_variant("pytorch_eager", eager_function, unpack),
+        fwd_only_variant("pytorch_compiled", torch.compile(eager_function, mode="default", dynamic=False), unpack),
+        fwd_only_variant(
+            "pytorch_compiled_max",
+            torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False),
+            unpack,
         ),
     ]
     if triton_function is not None and TritonConfig.enabled():
         variants.append(
-            FwdOnlyPytorchVariant(
-                name="fast_llm_triton",
-                function=lambda *args: triton_function(*args, use_triton=True),
-                unpack=unpack,
-            )
+            fwd_only_variant("fast_llm_triton", lambda *args: triton_function(*args, use_triton=True), unpack)
         )
     return variants
 
 
-def make_grad_reset(keys: tuple[str, ...]) -> Callable[[Inputs], None]:
-    """Reset autograd `.grad` to None for the given input keys between reps.
-    `.backward()` accumulates into `.grad` on rep 2+, biasing fwd_bwd timing
-    via an extra read+write of the full grad tensor. Also resets
-    `param_grad_is_zero=True` on tensors with a `grad_buffer` (Fast-LLM
-    convention) so the next backward writes fresh instead of accumulating."""
-
-    def reset(inputs: Inputs) -> None:
-        for key in keys:
-            tensor = inputs[key]
-            tensor.grad = None
-            if hasattr(tensor, "grad_buffer"):
-                tensor.param_grad_is_zero = True
-
-    return reset
-
-
 def standard_fwd_bwd_pytorch_variants(
-    eager_function: Callable,
+    eager_function: typing.Callable,
     input_keys: tuple[str, ...],
     grad_input_keys: tuple[str, ...],
     *,
     grad_output_key: str | None = None,
     output_key: str = "output",
-    reset_inputs: Callable[[Inputs], None] | None = None,
-    extra_functions: dict[str, Callable] | None = None,
+    reset_inputs: typing.Callable[[Inputs], None] | None = None,
+    extra_functions: dict[str, typing.Callable] | None = None,
     eager_name: str = "pytorch_eager",
     enable_max_autotune: bool = True,
 ) -> list[Variant]:
@@ -192,22 +178,18 @@ def standard_fwd_bwd_pytorch_variants(
         "reset_inputs": reset_inputs,
     }
     variants: list[Variant] = [
-        Fp32ReferenceVariant(function=eager_function, **common),
-        PytorchVariant(name=eager_name, function=eager_function, **common),
-        PytorchVariant(
-            name="pytorch_compiled",
-            function=torch.compile(eager_function, mode="default", dynamic=False),
-            **common,
-        ),
+        pytorch_variant("fp32_reference", eager_function, is_reference=True, convert_to_fp32=True, **common),
+        pytorch_variant(eager_name, eager_function, **common),
+        pytorch_variant("pytorch_compiled", torch.compile(eager_function, mode="default", dynamic=False), **common),
     ]
     if enable_max_autotune:
         variants.append(
-            PytorchVariant(
-                name="pytorch_compiled_max",
-                function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False),
+            pytorch_variant(
+                "pytorch_compiled_max",
+                torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False),
                 **common,
             )
         )
     for name, function in (extra_functions or {}).items():
-        variants.append(PytorchVariant(name=name, function=function, **common))
+        variants.append(pytorch_variant(name, function, **common))
     return variants
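
To make the Case/Variant split above concrete, here is a minimal sketch of the pattern the bench modules follow. The `scale` op, its Case fields, and the byte estimate are illustrative assumptions rather than code from this patch; only the `Case`, `Inputs`, `make_grad_reset`, and `standard_fwd_bwd_pytorch_variants` interfaces are taken from the files changed here.

import dataclasses

import torch

from tools.benchmark.triton_kernels.runner import Case, Inputs
from tools.benchmark.triton_kernels.utils import make_grad_reset, standard_fwd_bwd_pytorch_variants


def _scale_eager(input_: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical op, kept trivial on purpose: per-column scaling.
    return input_ * weight


@dataclasses.dataclass
class ScaleCase(Case):
    rows: int
    cols: int
    dtype: torch.dtype

    @property
    def name(self) -> str:
        return f"scale_{self.rows}x{self.cols}"

    @property
    def expected_bytes(self) -> int:
        # Read input + weight, write output (forward only, rough lower bound).
        return (2 * self.rows * self.cols + self.cols) * self.dtype.itemsize

    def make_inputs(self, device: torch.device) -> Inputs:
        # Fresh tensors on every call, as the runner expects after its seed reset.
        return {
            "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True),
            "weight": torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True),
            "grad_output": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device),
        }


def _variants() -> list:
    # fp32_reference / eager / compiled / compiled_max, sharing the same key wiring.
    return standard_fwd_bwd_pytorch_variants(
        _scale_eager,
        ("input", "weight"),
        ("input", "weight"),
        grad_output_key="grad_output",
        reset_inputs=make_grad_reset(("input", "weight")),
    )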
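
The `pytorch_variant` factory replaces the old `PytorchVariant` dataclass with plain closures; the sketch below exercises one such Variant by hand, outside the runner, to show what its `fwd` and `fwd_bwd` callables return. The SiLU op and the toy shapes are assumptions for illustration.

import torch

from tools.benchmark.triton_kernels.utils import make_grad_reset, pytorch_variant

inputs = {
    "input": torch.randn(8, 16, requires_grad=True),
    "grad_output": torch.randn(8, 16),
}
variant = pytorch_variant(
    "pytorch_eager",
    torch.nn.functional.silu,
    ("input",),
    grad_input_keys=("input",),
    grad_output_key="grad_output",
    reset_inputs=make_grad_reset(("input",)),
)

fwd_out = variant.fwd(inputs)      # {"output": tensor of shape (8, 16)}
bwd_out = variant.fwd_bwd(inputs)  # adds "grad_input" next to the detached "output"
variant.reset_inputs(inputs)       # clears input.grad so the next rep does not accumulate
print(sorted(bwd_out))             # ['grad_input', 'output']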
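
Forward-only kernels go through `fwd_only_variant` / `standard_fwd_variants` instead. The sketch below mirrors how bench_pointwise wires a copy through `unpack`, with `triton_function=None` so only the pytorch variants are built; the `_copy_eager` helper is an assumption for illustration.

import torch

from tools.benchmark.triton_kernels.runner import Inputs
from tools.benchmark.triton_kernels.utils import standard_fwd_variants


def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    return out.copy_(input_)


def _unpack(inputs: Inputs) -> tuple:
    # Positional args in the order _copy_eager expects them.
    return inputs["input_"], inputs["out"]


# No triton_function here, so the fast_llm_triton variant is skipped.
variants = standard_fwd_variants(_copy_eager, None, _unpack)
print([v.name for v in variants])
# ['fp32_reference', 'pytorch_eager', 'pytorch_compiled', 'pytorch_compiled_max']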
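
The correctness column compares each variant against the fp32 reference with `rms_relative_error`; only its final `diff_rms / max(ref_rms, 1e-30)` expression is visible in this hunk, so the sketch below re-derives the metric under the assumption that `diff_rms` and `ref_rms` are the root-mean-square of the difference and of the reference, respectively.

import torch


def rms_relative_error_sketch(candidate: torch.Tensor, reference: torch.Tensor) -> float:
    # Assumed definition: RMS of the error, normalized by the RMS of the reference.
    diff_rms = (candidate.float() - reference.float()).square().mean().sqrt().item()
    ref_rms = reference.float().square().mean().sqrt().item()
    return diff_rms / max(ref_rms, 1e-30)


# A plain bf16 round-trip of random fp32 data lands on the order of 1e-3.
x = torch.randn(1024, 1024)
print(rms_relative_error_sketch(x.bfloat16().float(), x))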
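
`bench_fn` can also be used on its own to time an arbitrary callable with the same warmup/rep policy the runner uses. The attribute read at the end assumes the timing result exposes `median_ms`, which is how the table code above consumes per-variant timings; treat the call as a sketch, not a documented API.

import torch

from tools.benchmark.triton_kernels.runner import bench_fn

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(4096, 4096, device=device)
b = torch.randn_like(a)

# Warmup then repeat until ~100 ms of samples are collected (at least 5 reps).
timing = bench_fn(lambda: a @ b, warmup_ms=25.0, rep_ms=100.0, min_reps=5)
print(timing.median_ms)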