diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py
index 27709a8bb..c8543b782 100644
--- a/fast_llm/engine/config_utils/data_type.py
+++ b/fast_llm/engine/config_utils/data_type.py
@@ -83,6 +83,11 @@ def triton(self) -> "tl.dtype":
         _set_triton_dtype_map()
         return _TRITON_DTYPE_MAP[self]
 
+    @property
+    def short(self) -> str:
+        """Abbreviated name (bf16, fp32, ...) when one is defined, else the full name."""
+        return _DTYPE_SHORT_NAME_MAP.get(self, self.value)
+
 
 _KNOWN_DATA_TYPE_PREFIXES = {"DataType", "numpy", "np", "torch", "triton.language", "tl"}
@@ -92,6 +97,7 @@
     "fp16": DataType.float16,
     "bf16": DataType.bfloat16,
 }
+_DTYPE_SHORT_NAME_MAP = {v: k for k, v in _DTYPE_ALT_NAME_MAP_INV.items()}
 
 _TORCH_DTYPE_MAP: dict[DataType, "torch.dtype"] = {}
 _TORCH_DTYPE_MAP_INV: dict["torch.dtype", DataType] = {}
diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py
index 15af789d7..14b15b319 100644
--- a/fast_llm/functional/triton/sparse_linear.py
+++ b/fast_llm/functional/triton/sparse_linear.py
@@ -232,6 +232,7 @@ def output_sparse_matmul_kernel(
     expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim)
     sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64))  # noqa
     if sparse_index == sparse_dim:
+        # Phantom block past the last expert; the caller is expected to ignore these rows.
         return
 
     col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim
@@ -248,7 +249,7 @@ def output_sparse_matmul_kernel(
     for k in range(1, inner_dim // block_size_inner):
         lhs_ptr += block_size_inner * lhs_stride_inner
         rhs_ptr += block_size_inner * rhs_stride_inner
-        out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr))
+        out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32)
 
     if accumulate:
         out += tl.load(out_ptr)
@@ -355,6 +356,7 @@ def input_inner_sparse_matmul_kernel(
     expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim)
     sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64))  # noqa
     if sparse_index == sparse_dim:
+        # Phantom block past the last expert; the caller is expected to ignore these rows.
         return
 
     inner_dense_offset = sparse_index * inner_sparse_dim
     col_offset = pid_col * block_size_col
@@ -374,7 +376,7 @@ def input_inner_sparse_matmul_kernel(
     for k in range(1, inner_sparse_dim // block_size_inner):
         lhs_ptr += block_size_inner * lhs_stride_inner
         rhs_ptr += block_size_inner * rhs_stride_inner
-        out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr))
+        out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32)
 
     if accumulate:
         out += tl.load(out_ptr)
@@ -497,6 +499,7 @@ def input_row_sparse_matmul_kernel(
     out = tl.dot(
         tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0),
         tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0),
+        out_dtype=tl.float32,
     )
     for i in range(1, tl.cdiv(inner_end - inner_offset, block_size_inner)):
         inner_range += block_size_inner
@@ -504,6 +507,7 @@ def input_row_sparse_matmul_kernel(
         out += tl.dot(
             tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0),
             tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0),
+            out_dtype=tl.float32,
         )
 
     if accumulate:
@@ -578,7 +582,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N
         grad_out = grad_out.contiguous()
         lhs, rhs = ctx.saved_tensors
         grad_lhs = input_inner_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map)
-        grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map).t()
+        grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map).t()
         return grad_lhs, grad_rhs, None
@@ -597,7 +601,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N
         grad_out = grad_out.contiguous()
         lhs, rhs = ctx.saved_tensors
         grad_lhs = output_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map)
-        grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map)
+        grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map)
         return grad_lhs, grad_rhs, None
diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py
index 2858b9370..4f92223d2 100644
--- a/fast_llm/layers/common/normalization/normalization.py
+++ b/fast_llm/layers/common/normalization/normalization.py
@@ -22,16 +22,16 @@
 try:
     import fused_layer_norm_cuda  # noqa
 
-    _fused_normalization_available = torch.cuda.is_available()
+    fused_normalization_available = torch.cuda.is_available()
 except ImportError:
-    _fused_normalization_available = False
+    fused_normalization_available = False
 
 try:
     import fast_layer_norm  # noqa
 
-    _fast_normalization_available = torch.cuda.is_available()
+    fast_normalization_available = torch.cuda.is_available()
 except ImportError:
-    _fast_normalization_available = False
+    fast_normalization_available = False
 
 try:
@@ -79,7 +79,7 @@ class FastLayerNorm(torch.autograd.Function):
     def forward(
         ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float
     ) -> torch.Tensor:  # noqa
-        assert _fast_normalization_available
+        assert fast_normalization_available
         Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES)
         output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps)
         ctx.save_for_backward(output, weight, bias, inv_var)
@@ -110,7 +110,7 @@ class FusedLayerNorm(torch.autograd.Function):
     def forward(
         ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float
     ) -> torch.Tensor:  # noqa
-        assert _fused_normalization_available
+        assert fused_normalization_available
         ctx.eps = eps
         ctx.normalized_shape = normalized_shape
         output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps)
@@ -136,7 +136,7 @@ class FusedRMSNorm(torch.autograd.Function):
     def forward(
         ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float
     ) -> torch.Tensor:  # noqa
-        assert _fused_normalization_available
+        assert fused_normalization_available
         ctx.eps = eps
         ctx.normalized_shape = normalized_shape
         output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps)
@@ -185,7 +185,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
         implementation = self._config.implementation
         if implementation == NormalizationImplementation.auto:
             if (
-                _fast_normalization_available
+                fast_normalization_available
                 and hidden_dim.size in _PERSIST_LN_SIZES
                 and not self._config.zero_centered
             ):
@@ -193,7 +193,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
             elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered:
                 log_main_rank("Fast layer norm unavailable, using backup triton implementation.")
                 implementation = NormalizationImplementation.triton
-            elif _fused_normalization_available:
+            elif fused_normalization_available:
                 log_main_rank("Fast layer norm unavailable, using backup fused implementation.")
                 implementation = NormalizationImplementation.fused
             else:
@@ -261,7 +261,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
         if implementation == NormalizationImplementation.auto:
             if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered:
                 implementation = NormalizationImplementation.triton
-            elif _fused_normalization_available:
+            elif fused_normalization_available:
                 log_main_rank("Triton RMS norm unavailable, using fused implementation.")
                 implementation = NormalizationImplementation.fused
             else:
diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py
index 0ebf9c5a5..7269e893f 100644
--- a/tests/functional/test_sparse_matmul.py
+++ b/tests/functional/test_sparse_matmul.py
@@ -6,6 +6,8 @@
 from fast_llm.functional.triton.sparse_copy import SparseMap
 from fast_llm.functional.triton.sparse_linear import (
+    InputSparseLinear,
+    OutputSparseLinear,
     dense_matmul,
     input_inner_sparse_matmul,
     input_row_sparse_matmul,
@@ -74,7 +76,24 @@ def normal(self, dim_0: int, dim_1: int, device: torch.device) -> torch.Tensor:
         dense_dim=256,
         sparse_dim=512,
         expert_ends=(128, 256, 256, 384),
-        tokens_per_expert=(52, 125, 0, 97),
+        tokens_per_expert=(52, 125, 0, 97),  # expert 2 has zero real tokens
     ),
+    # Single expert — the simplest non-trivial case (200 real tokens padded to 256,
+    # so per-expert padding is still exercised).
+    _SparseTestData(
+        dense_dim=512,
+        sparse_dim=256,
+        expert_ends=(256,),
+        tokens_per_expert=(200,),
+    ),
+    # Four experts, fully packed (no padding rows) — exercises the pad_begin == expert_end path.
+    # Expert sizes must be multiples of the largest autotune block_size_row (128); otherwise
+    # blocks straddle expert boundaries and the kernel's "sparse_index constant within a block"
+    # assumption breaks.
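+    # (Illustrative, not part of the kernel contract: with block_size_row=128 the row
+    # blocks are [0,128), [128,256), ...; ends at multiples of 128 keep each block inside
+    # one expert, whereas e.g. expert_ends=(100, 228, ...) would split the block [0,128)
+    # across experts 0 and 1.)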
+    _SparseTestData(
+        dense_dim=384,
+        sparse_dim=128,
+        expert_ends=(128, 256, 384, 512),
+        tokens_per_expert=(128, 128, 128, 128),
+    ),
 )
@@ -151,3 +170,64 @@ def test_input_row_sparse_matmul(sparse_test_data, testing_device):
     )
 
     Assert.rms_close(output, output_ref, 1e-3)
+
+
+# --------------------------------------------------------------------------- autograd wrappers
+
+
+def _sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData, expert_axis: int) -> torch.Tensor:
+    """Per-expert matmul reference; rows past `expert_pad_begins` are zero in the output."""
+    rhs_per_expert = rhs.chunk(data.num_experts, dim=expert_axis)
+    out = lhs.new_zeros(data.token_dim, rhs_per_expert[0].shape[1])
+    for i, (begin, end) in enumerate(zip(data.expert_begins, data.expert_pad_begins, strict=True)):
+        out[begin:end] = lhs[begin:end] @ rhs_per_expert[i]
+    return out
+
+
+def _zero_padded_rows(tensor: torch.Tensor, data: _SparseTestData) -> torch.Tensor:
+    # The autograd kernels treat padded tokens as regular ones; forward output and grad_lhs
+    # contain matmul-of-random garbage in [pad_begin, expert_end). Zero those rows so the
+    # comparison vs the reference (which produces zeros there) reflects only real-row error.
+    masked = tensor.clone()
+    for begin, end in zip(data.expert_pad_begins, data.expert_ends, strict=True):
+        if end > begin:
+            masked[begin:end] = 0
+    return masked
+
+
+@requires_triton
+@pytest.mark.slow
+@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
+@pytest.mark.parametrize(
+    "autograd_class,expert_axis",
+    [(OutputSparseLinear, 1), (InputSparseLinear, 0)],
+    ids=["output_sparse", "input_sparse"],
+)
+def test_sparse_linear_autograd(sparse_test_data, testing_device, autograd_class, expert_axis):
+    # `expert_axis` is the rhs axis split per expert. Matmul contracts rhs's axis 0, so
+    # OutputSparseLinear (expert_axis=1) splits the output dim, InputSparseLinear (expert_axis=0)
+    # splits the contracting dim.
+    if expert_axis == 1:
+        lhs_features, out_features = sparse_test_data.dense_dim, sparse_test_data.sparse_dim
+        rhs_shape = (sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded)
+    else:
+        lhs_features, out_features = sparse_test_data.sparse_dim, sparse_test_data.dense_dim
+        rhs_shape = (sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim)
+
+    lhs = sparse_test_data.normal(sparse_test_data.token_dim, lhs_features, testing_device)
+    rhs = sparse_test_data.normal(*rhs_shape, testing_device)
+    grad_output = sparse_test_data.normal(sparse_test_data.token_dim, out_features, testing_device)
+
+    lhs_ref = lhs.detach().requires_grad_(True)
+    rhs_ref = rhs.detach().requires_grad_(True)
+    out_ref = _sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data, expert_axis)
+    out_ref.backward(grad_output)
+
+    lhs_t = lhs.detach().requires_grad_(True)
+    rhs_t = rhs.detach().requires_grad_(True)
+    out_t = autograd_class.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device))
+    out_t.backward(grad_output)
+
+    Assert.rms_close(_zero_padded_rows(out_t, sparse_test_data), out_ref, 1e-3)
+    Assert.rms_close(_zero_padded_rows(lhs_t.grad, sparse_test_data), lhs_ref.grad, 1e-3)
+    Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3)
diff --git a/tests/tools/test_triton_benchmark.py b/tests/tools/test_triton_benchmark.py
new file mode 100644
index 000000000..cd2dd3db3
--- /dev/null
+++ b/tests/tools/test_triton_benchmark.py
@@ -0,0 +1,94 @@
+"""
+Smoke tests for all benchmark modules.
+
+One test per sub-benchmark (kernel): inputs are tiny so the runner code path is
+exercised quickly without requiring a full benchmark run.
+
+Patches applied to keep each test under ~100 ms:
+- torch.compile disabled (avoids JIT cold-start).
+- fast_llm_triton variants replaced with fp32 reference (no Triton compilation;
+  kernel correctness is covered by the main test suite).
+- TritonConfig.enabled → False (prevents make_inputs warmup in sparse_linear).
+- _cudagraph_mark_step_begin → None and synchronize → no-op (both otherwise
+  trigger C-level CUDA syncs on every fn() call, which would dominate the wall time).
+"""
+
+import dataclasses
+
+import pytest
+import torch
+
+import tools.benchmark.runner as _bench_runner
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton import triton_interpret
+from tools.benchmark import (
+    bench_entropy_loss,
+    bench_grpo_loss,
+    bench_mlp_activation,
+    bench_normalization,
+    bench_pointwise,
+    bench_rotary,
+    bench_sparse_copy,
+    bench_sparse_linear,
+)
+from tools.benchmark.runner import run_benchmark
+
+_DTYPES = (torch.float32,)
+
+# sparse_copy and sparse_linear use tl.histogram, which has unfixed bugs in the
+# Triton interpreter. Skip them in interpreter mode; they're covered on GPU.
+_INTERPRETER_SKIP = {
+    "sparse_copy: dispatch",
+    "sparse_copy: combine",
+    "sparse_linear: output_sparse (layer 1 / up-proj)",
+    "sparse_linear: input_inner_sparse (layer 2 / down-proj)",
+}
+
+_SKIP_VARIANTS = {"pytorch_compiled", "pytorch_compiled_max"}
+
+
+def _build_params() -> list:
+    modules_and_shapes = [
+        (bench_entropy_loss, {"shapes": [(64, 256)]}),
+        (bench_grpo_loss, {"shapes": [(64, 256)]}),
+        (bench_mlp_activation, {"shapes": [(64, 128)]}),
+        (bench_normalization, {"shapes": [(64, 128)]}),
+        (bench_pointwise, {"shapes": [1024]}),
+        (bench_rotary, {"shapes": [(64, 4, 64)]}),
+        (bench_sparse_copy, {"shapes": [(64, 2, 4, 128)]}),
+        (bench_sparse_linear, {"shapes": [(64, 2, 4, 256, 256)]}),
+    ]
+    params = []
+    for module, kwargs in modules_and_shapes:
+        for name, cases, variants in module.benchmarks(dtypes=_DTYPES, **kwargs):
+            params.append(pytest.param(name, cases, variants, id=name))
+    return params
+
+
+_PARAMS = _build_params()
+
+
+@pytest.fixture(autouse=True)
+def _patch_benchmark_env(monkeypatch):
+    import torch._dynamo
+
+    monkeypatch.setattr(torch._dynamo.config, "disable", True)
+    monkeypatch.setattr(TritonConfig, "enabled", lambda *a, **kw: False)
+    monkeypatch.setattr(_bench_runner, "_cudagraph_mark_step_begin", None)
+    monkeypatch.setattr(torch.cuda, "synchronize", lambda: None)
+
+
+@pytest.mark.parametrize("name,cases,variants", _PARAMS)
+def test_triton_benchmark(name, cases, variants):
+    if triton_interpret and name in _INTERPRETER_SKIP:
+        pytest.skip("tl.histogram is broken in the Triton interpreter")
+
+    # Replace fast_llm_triton fwd/fwd_bwd with the fp32 reference so no Triton
+    # kernels are compiled. The runner still exercises the full variant code path.
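+    # (Here dataclasses.replace keeps each variant's name and bookkeeping while swapping
+    # the callables: Variant(name="fast_llm_triton", fwd=<triton fwd>) becomes
+    # Variant(name="fast_llm_triton", fwd=ref.fwd). Timings are not meaningful in this
+    # smoke test; it only checks that the runner completes.)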
+    variants = [v for v in variants if v.name not in _SKIP_VARIANTS]
+    ref = next(v for v in variants if v.is_reference)
+    variants = [
+        dataclasses.replace(v, fwd=ref.fwd, fwd_bwd=ref.fwd_bwd) if v.name == "fast_llm_triton" else v
+        for v in variants
+    ]
+    run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1)
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tools/benchmark/__init__.py b/tools/benchmark/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tools/benchmark/__main__.py b/tools/benchmark/__main__.py
new file mode 100644
index 000000000..102d72e63
--- /dev/null
+++ b/tools/benchmark/__main__.py
@@ -0,0 +1,79 @@
+"""
+CLI entry point for the Fast-LLM Triton kernel benchmarking suite.
+
+Usage:
+    python -m tools.benchmark
+
+Available kernels are discovered dynamically from `bench_*.py` files in this
+package. Each such module must expose a `run(verbose: bool = False)` callable.
+"""
+
+import argparse
+import importlib
+import logging
+import pkgutil
+import warnings
+
+# Each bench file compiles the same function with multiple shapes/dtypes; the
+# default cache size (8) is too small, causing dynamo to give up and fall back
+# to eager. Bump it before any `torch.compile`-decorated code runs.
+import torch._dynamo
+
+import tools.benchmark as _pkg
+from fast_llm.engine.config_utils.data_type import DataType
+
+torch._dynamo.config.cache_size_limit = 64
+
+# In-place ops (copy_, fill_, add with out=) emit "skipping cudagraphs due to
+# mutated inputs" when using max-autotune. The fallback is correct; suppress noise.
+warnings.filterwarnings("ignore", message=".*[Ss]kipping (cuda|CUDA)[Gg]raphs.*")
+logging.getLogger("torch._inductor.cudagraph_trees").setLevel(logging.ERROR)
+
+
+def _list_benchmarks() -> dict[str, str]:
+    """Map short kernel name → fully-qualified bench module name."""
+    names = {}
+    for info in pkgutil.iter_modules(_pkg.__path__):
+        if info.name.startswith("bench_"):
+            names[info.name.removeprefix("bench_")] = f"tools.benchmark.{info.name}"
+    return names
+
+
+def main() -> None:
+    benches = _list_benchmarks()
+    parser = argparse.ArgumentParser(
+        prog="python -m tools.benchmark",
+        description="Benchmark Fast-LLM Triton kernels against PyTorch alternatives.",
+    )
+    parser.add_argument(
+        "kernels",
+        nargs="*",
+        choices=sorted(benches),
+        metavar="kernel",
+        help="Which kernels to benchmark. If omitted, run all kernels.",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Show additional timing columns (mean, min, max).",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtypes",
+        nargs="+",
+        default=["bfloat16"],
+        metavar="DTYPE",
+        help="Fast-LLM DataType names to sweep (default: bfloat16). "
+        "Examples: bfloat16 float32 fp16. Accepts alternate names like bf16.",
+    )
+    args = parser.parse_args()
+    dtypes = [DataType(d).torch for d in args.dtypes]
+
+    selected = args.kernels or sorted(benches)
+    for kernel in selected:
+        importlib.import_module(benches[kernel]).run(verbose=args.verbose, dtypes=dtypes)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/benchmark/bench_entropy_loss.py b/tools/benchmark/bench_entropy_loss.py
new file mode 100644
index 000000000..bc14e8bf9
--- /dev/null
+++ b/tools/benchmark/bench_entropy_loss.py
@@ -0,0 +1,176 @@
+import dataclasses
+
+import torch
+import torch.nn.functional as F
+
+from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig
+from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward
+from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+# (tokens, vocab_size)
+_SHAPES = [
+    (4096, 32768),  # Llama-2 vocab
+    (4096, 65536),
+    (4096, 131072),  # Llama-3 vocab
+]
+
+
+@dataclasses.dataclass
+class _EntropyCase(Case):
+    tokens: int
+    vocab: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.tokens}, {self.vocab}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_flops(self) -> int:
+        # fwd ≈ 3*vocab per token, bwd ≈ vocab.
+        return 4 * self.tokens * self.vocab
+
+
+class EntropyLabelCase(_EntropyCase):
+    @property
+    def expected_bytes(self) -> int:
+        # 2× logits + small labels traffic.
+        return 2 * self.tokens * self.vocab * self.dtype.itemsize + self.tokens * 4
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True),
+            "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device),
+        }
+
+
+class EntropyDistCase(_EntropyCase):
+    @property
+    def expected_bytes(self) -> int:
+        # 2× logits + 1× target_logits.
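+        # Worked example (illustrative): tokens=4096, vocab=131072, bf16 →
+        # 3 * 4096 * 131072 * 2 B = 3.0 GiB of traffic per fwd+bwd pass.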
+        return 3 * self.tokens * self.vocab * self.dtype.itemsize
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True),
+            "target_logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device),
+        }
+
+
+def _reset_logits_grad(inputs: dict) -> None:
+    inputs["logits"].grad = None
+
+
+def _ce_labels_eager(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    return F.cross_entropy(logits, labels)
+
+
+def _ce_dist_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
+    return F.cross_entropy(logits, target_logits.softmax(dim=-1))
+
+
+def _reverse_kl_eager(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
+    return F.kl_div(target_logits.log_softmax(dim=-1), logits.softmax(dim=-1), reduction="batchmean")
+
+
+def _z_loss_eager(logits: torch.Tensor) -> torch.Tensor:
+    log_z = torch.logsumexp(logits.float(), dim=-1)
+    return (log_z * log_z).mean()
+
+
+def _entropy_variants(eager_function, input_keys, triton_kwargs=None) -> list[Variant]:
+    """Variants for the 3 entropy_loss kernels that share `triton_entropy_loss_forward_backward`."""
+    target_key = input_keys[1]
+    triton_kwargs = triton_kwargs or {}
+
+    def triton_fwd(inputs: dict) -> dict:
+        loss, _ = triton_entropy_loss_forward_backward(
+            inputs["logits"], inputs[target_key], loss_mask=None, grad_output=None, **triton_kwargs
+        )
+        return {"loss": loss}
+
+    def triton_fwd_bwd(inputs: dict) -> dict:
+        loss, grad_logits = triton_entropy_loss_forward_backward(
+            inputs["logits"], inputs[target_key], loss_mask=None, grad_output=1.0, **triton_kwargs
+        )
+        return {"loss": loss, "grad_logits": grad_logits}
+
+    variants = standard_fwd_bwd_pytorch_variants(
+        eager_function,
+        input_keys=input_keys,
+        grad_input_keys=("logits",),
+        output_key="loss",
+        reset_inputs=_reset_logits_grad,
+    )
+    if TritonConfig.enabled():
+        variants.append(Variant(name="fast_llm_triton", fwd=triton_fwd, fwd_bwd=triton_fwd_bwd))
+    return variants
+
+
+def _z_loss_triton_fwd(inputs: dict) -> dict:
+    loss, _ = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=None)
+    return {"loss": loss}
+
+
+def _z_loss_triton_fwd_bwd(inputs: dict) -> dict:
+    loss, grad_logits = triton_z_loss_forward_backward(inputs["logits"], loss_mask=None, grad_output=1.0)
+    return {"loss": loss, "grad_logits": grad_logits}
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    label_cases = [EntropyLabelCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes]
+    dist_cases = [EntropyDistCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes]
+    z_loss_variants = standard_fwd_bwd_pytorch_variants(
+        _z_loss_eager,
+        input_keys=("logits",),
+        grad_input_keys=("logits",),
+        output_key="loss",
+        reset_inputs=_reset_logits_grad,
+    )
+    if TritonConfig.enabled():
+        z_loss_variants.append(Variant(name="fast_llm_triton", fwd=_z_loss_triton_fwd, fwd_bwd=_z_loss_triton_fwd_bwd))
+    return [
+        (
+            "entropy_loss: cross_entropy (labels)",
+            label_cases,
+            _entropy_variants(_ce_labels_eager, input_keys=("logits", "labels")),
+        ),
+        (
+            "entropy_loss: cross_entropy (logits)",
+            dist_cases,
+            _entropy_variants(
+                _ce_dist_eager,
+                input_keys=("logits", "target_logits"),
+                triton_kwargs={
+                    "target_format": TargetFormat.logits,
+                    "entropy_loss_type": EntropyLossType.cross_entropy,
+                },
+            ),
+        ),
+        (
+            "entropy_loss: reverse_kl (logits)",
+            dist_cases,
+            _entropy_variants(
+                _reverse_kl_eager,
+                input_keys=("logits", "target_logits"),
+                triton_kwargs={
+                    "target_format": TargetFormat.logits,
+                    "entropy_loss_type": EntropyLossType.reverse_kl,
+                },
+            ),
+        ),
+        ("entropy_loss: z_loss", label_cases, z_loss_variants),
+    ]


run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_grpo_loss.py b/tools/benchmark/bench_grpo_loss.py
new file mode 100644
index 000000000..4b7353d12
--- /dev/null
+++ b/tools/benchmark/bench_grpo_loss.py
@@ -0,0 +1,122 @@
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+_SHAPES = [
+    (4096, 32768),
+    (4096, 65536),
+    (4096, 131072),
+]
+_EPSILON_LOW = 0.2
+_EPSILON_HIGH = 0.2
+
+
+@dataclasses.dataclass
+class GrpoLossCase(Case):
+    tokens: int
+    vocab: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.tokens}, {self.vocab}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        # 3× logits traffic (read fwd, read+write bwd) + per-token scalars:
+        # labels (int64 = 8B), advantages (fp32 = 4B), old_log_probs (fp32 = 4B).
+        return 3 * self.tokens * self.vocab * self.dtype.itemsize + self.tokens * 16
+
+    @property
+    def expected_flops(self) -> int:
+        # softmax (fwd) + grad (bwd) ≈ 14 FLOPs/element.
+        return 14 * self.tokens * self.vocab
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "logits": torch.randn(self.tokens, self.vocab, dtype=self.dtype, device=device, requires_grad=True),
+            "labels": torch.randint(0, self.vocab, (self.tokens,), dtype=torch.long, device=device),
+            "advantages": torch.randn(self.tokens, dtype=torch.float32, device=device),
+            "old_log_probs": torch.randn(self.tokens, dtype=torch.float32, device=device) - 5.0,
+        }
+
+
+def _grpo_eager(logits: torch.Tensor, labels: torch.Tensor, advantages: torch.Tensor, old_log_probs: torch.Tensor):
+    log_probs = logits.float().log_softmax(-1)
+    # clamp + labels>=0 guards mirror production code that handles ignore_index=-100;
+    # labels here are always non-negative (randint), so the masks are dead in this benchmark.
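+    # (E.g. a hypothetical labels[i] = -100 would yield new_log_probs[i] = 0 and
+    # per_token_loss[i] = 0 through the two torch.where guards below.)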
+    new_log_probs = log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+    new_log_probs = torch.where(labels >= 0, new_log_probs, torch.zeros_like(new_log_probs))
+    ratio = (new_log_probs - old_log_probs).exp()
+    clipped_ratio = ratio.clamp(1.0 - _EPSILON_LOW, 1.0 + _EPSILON_HIGH)
+    per_token_loss = torch.where(
+        labels >= 0,
+        -torch.minimum(ratio * advantages, clipped_ratio * advantages),
+        torch.zeros_like(ratio),
+    )
+    return per_token_loss.mean()
+
+
+def _reset_logits_grad(inputs: dict) -> None:
+    inputs["logits"].grad = None
+
+
+def _triton_fwd(inputs: dict) -> dict:
+    loss, _, _ = triton_grpo_loss_forward_backward(
+        inputs["logits"],
+        inputs["labels"],
+        inputs["advantages"],
+        inputs["old_log_probs"],
+        grad_output=None,
+        epsilon_low=_EPSILON_LOW,
+        epsilon_high=_EPSILON_HIGH,
+    )
+    return {"loss": loss}
+
+
+def _triton_fwd_bwd(inputs: dict) -> dict:
+    loss, grad_logits, _ = triton_grpo_loss_forward_backward(
+        inputs["logits"],
+        inputs["labels"],
+        inputs["advantages"],
+        inputs["old_log_probs"],
+        grad_output=1.0,
+        epsilon_low=_EPSILON_LOW,
+        epsilon_high=_EPSILON_HIGH,
+    )
+    return {"loss": loss, "grad_logits": grad_logits}
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    variants = standard_fwd_bwd_pytorch_variants(
+        _grpo_eager,
+        input_keys=("logits", "labels", "advantages", "old_log_probs"),
+        grad_input_keys=("logits",),
+        output_key="loss",
+        reset_inputs=_reset_logits_grad,
+    )
+    if TritonConfig.enabled():
+        variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd))
+    return [
+        (
+            "grpo_loss",
+            [GrpoLossCase(tokens=t, vocab=v, dtype=d) for d in dtypes for (t, v) in shapes],
+            variants,
+        )
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_mlp_activation.py b/tools/benchmark/bench_mlp_activation.py
new file mode 100644
index 000000000..0106d0521
--- /dev/null
+++ b/tools/benchmark/bench_mlp_activation.py
@@ -0,0 +1,90 @@
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import ActivationType, TritonConfig
+from fast_llm.functional.triton.mlp import (
+    torch_mlp_activation,
+    triton_mlp_activation_autograd,
+    triton_mlp_activation_forward,
+)
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+# (tokens, ffn_dim) — input has shape (tokens, 2*ffn_dim) for gated.
+_SHAPES = [
+    (8192, 4096),  # 7B/13B
+    (8192, 8192),
+    (8192, 14336),  # 70B
+    (4096, 28672),  # MoE up-projection
+]
+_ACTIVATION = ActivationType.silu
+
+
+@dataclasses.dataclass
+class MlpActivationCase(Case):
+    tokens: int
+    ffn_dim: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.tokens}, {self.ffn_dim}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        # fwd: 3*ffn_dim traffic; bwd: 5*ffn_dim. 8 elements/token total.
+        return 8 * self.tokens * self.ffn_dim * self.dtype.itemsize
+
+    @property
+    def expected_flops(self) -> int:
+        # gated silu: fwd ≈ 6 FLOPs/element, bwd ≈ 8 FLOPs/element.
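+        # (Rough per-element split, an estimate rather than an exact count: the sigmoid
+        # inside silu(gate) * up costs several FLOPs by itself, so 6 + 8 = 14 is indicative.)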
+        return 14 * self.tokens * self.ffn_dim
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "input": torch.randn(self.tokens, 2 * self.ffn_dim, dtype=self.dtype, device=device, requires_grad=True),
+            "grad_output": torch.randn(self.tokens, self.ffn_dim, dtype=self.dtype, device=device),
+            "gated": True,
+            "activation_type": _ACTIVATION,
+        }
+
+
+def _triton_fwd(inputs: dict) -> dict:
+    output, _ = triton_mlp_activation_forward(inputs["input"], inputs["gated"], inputs["activation_type"])
+    return {"output": output}
+
+
+def _triton_fwd_bwd(inputs: dict) -> dict:
+    output = triton_mlp_activation_autograd(inputs["input"], inputs["gated"], inputs["activation_type"])
+    output.backward(inputs["grad_output"])
+    return {"output": output.detach(), "grad_input": inputs["input"].grad}
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    variants = standard_fwd_bwd_pytorch_variants(
+        torch_mlp_activation,
+        input_keys=("input", "gated", "activation_type"),
+        grad_input_keys=("input",),
+        grad_output_key="grad_output",
+    )
+    if TritonConfig.enabled():
+        variants.append(Variant(name="fast_llm_triton", fwd=_triton_fwd, fwd_bwd=_triton_fwd_bwd))
+    return [
+        (
+            "mlp_activation (gated silu)",
+            [MlpActivationCase(tokens=t, ffn_dim=f, dtype=d) for d in dtypes for (t, f) in shapes],
+            variants,
+        )
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_normalization.py b/tools/benchmark/bench_normalization.py
new file mode 100644
index 000000000..636e428a7
--- /dev/null
+++ b/tools/benchmark/bench_normalization.py
@@ -0,0 +1,199 @@
+"""LayerNorm and RMSNorm. The Triton kernel writes parameter gradients to a
+`grad_buffer` attribute (Fast-LLM convention) instead of autograd's `.grad`."""
+
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton.normalization import triton_normalization_autograd
+from fast_llm.layers.common.normalization.normalization import (
+    FastLayerNorm,
+    FusedLayerNorm,
+    FusedRMSNorm,
+    fast_normalization_available,
+    fused_normalization_available,
+)
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+# (batch*seq, hidden). Numel fixed at 32M to mimic a constant training memory
+# budget across model widths; hidden swept from 1K to 16K.
+_SHAPES = [
+    (32768, 1024),
+    (16384, 2048),
+    (8192, 4096),
+    (4096, 8192),
+    (2048, 16384),
+]
+_EPS = 1e-5
+
+
+def _setup_param(tensor: torch.Tensor) -> torch.Tensor:
+    tensor.grad_buffer = torch.zeros_like(tensor)
+    tensor.param_grad_is_zero = True
+    return tensor
+
+
+@dataclasses.dataclass
+class _NormalizationCase(Case):
+    rows: int
+    cols: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.rows}, {self.cols}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+
+class LayerNormCase(_NormalizationCase):
+    @property
+    def expected_bytes(self) -> int:
+        # 4× activations (fwd+bwd in/out) + weight & bias, each read in fwd and bwd
+        # plus one grad write (3× each, hence 6 * cols).
+        return 4 * self.rows * self.cols * self.dtype.itemsize + 6 * self.cols * self.dtype.itemsize
+
+    @property
+    def expected_flops(self) -> int:
+        # fwd ≈ 7 FLOPs/elem (mean, variance, normalize, scale+shift); bwd ≈ 2× fwd.
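+        # Illustrative arithmetic-intensity check: (8192, 4096) in bf16 moves
+        # ~4 * 8192 * 4096 * 2 B ≈ 0.27 GB for ~21 * 8192 * 4096 ≈ 0.70 GFLOP,
+        # i.e. ~2.6 FLOP/byte, so the kernel is bandwidth-bound on any GPU in gpu_specs.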
+        return 21 * self.rows * self.cols
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True),
+            "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)),
+            "bias": _setup_param(torch.zeros(self.cols, dtype=self.dtype, device=device, requires_grad=True)),
+            "grad_output": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device),
+        }
+
+
+class RmsNormCase(_NormalizationCase):
+    @property
+    def expected_bytes(self) -> int:
+        # No bias compared to LayerNorm.
+        return 4 * self.rows * self.cols * self.dtype.itemsize + 3 * self.cols * self.dtype.itemsize
+
+    @property
+    def expected_flops(self) -> int:
+        # No mean subtraction or bias compared to LayerNorm.
+        return 15 * self.rows * self.cols
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "input": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device, requires_grad=True),
+            "weight": _setup_param(torch.randn(self.cols, dtype=self.dtype, device=device, requires_grad=True)),
+            "grad_output": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device),
+        }
+
+
+def _layer_norm_eager(input_, weight, bias):
+    return torch.layer_norm(input_, weight.shape, weight, bias, _EPS)
+
+
+def _rms_norm_eager(input_, weight):
+    return torch.rms_norm(input_, weight.shape, weight, _EPS)
+
+
+def _param_grad(param: torch.Tensor) -> torch.Tensor:
+    return param.grad if param.grad is not None else param.grad_buffer
+
+
+def _layer_norm_triton_fwd(inputs: dict) -> dict:
+    return {
+        "output": triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False)
+    }
+
+
+def _layer_norm_triton_fwd_bwd(inputs: dict) -> dict:
+    output = triton_normalization_autograd(inputs["input"], inputs["weight"], inputs["bias"], _EPS, True, False)
+    output.backward(inputs["grad_output"])
+    return {
+        "output": output.detach(),
+        "grad_input": inputs["input"].grad,
+        "grad_weight": _param_grad(inputs["weight"]),
+        "grad_bias": _param_grad(inputs["bias"]),
+    }
+
+
+def _rms_norm_triton_fwd(inputs: dict) -> dict:
+    return {"output": triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False)}
+
+
+def _rms_norm_triton_fwd_bwd(inputs: dict) -> dict:
+    output = triton_normalization_autograd(inputs["input"], inputs["weight"], None, _EPS, True, False)
+    output.backward(inputs["grad_output"])
+    return {
+        "output": output.detach(),
+        "grad_input": inputs["input"].grad,
+        "grad_weight": _param_grad(inputs["weight"]),
+    }
+
+
+def _layer_norm_apex_fused(input_, weight, bias):
+    return FusedLayerNorm.apply(input_, weight.shape, weight, bias, _EPS)
+
+
+def _layer_norm_apex_fast(input_, weight, bias):
+    return FastLayerNorm.apply(input_, weight.shape, weight, bias, _EPS)
+
+
+def _rms_norm_apex_fused(input_, weight):
+    return FusedRMSNorm.apply(input_, weight.shape, weight, _EPS)
+
+
+_LAYER_NORM_EXTRAS: dict = {}
+if fused_normalization_available:
+    _LAYER_NORM_EXTRAS["apex_fused"] = _layer_norm_apex_fused
+if fast_normalization_available:
+    _LAYER_NORM_EXTRAS["apex_fast"] = _layer_norm_apex_fast
+
+_RMS_NORM_EXTRAS: dict = {}
+if fused_normalization_available:
+    _RMS_NORM_EXTRAS["apex_fused"] = _rms_norm_apex_fused
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    layer_norm_variants = standard_fwd_bwd_pytorch_variants(
+        _layer_norm_eager,
+        input_keys=("input", "weight", "bias"),
+        grad_input_keys=("input", "weight", "bias"),
+        grad_output_key="grad_output",
+        extra_functions=_LAYER_NORM_EXTRAS,
+    )
+    if TritonConfig.enabled():
+        layer_norm_variants.append(
+            Variant(name="fast_llm_triton", fwd=_layer_norm_triton_fwd, fwd_bwd=_layer_norm_triton_fwd_bwd)
+        )
+    rms_norm_variants = standard_fwd_bwd_pytorch_variants(
+        _rms_norm_eager,
+        input_keys=("input", "weight"),
+        grad_input_keys=("input", "weight"),
+        grad_output_key="grad_output",
+        extra_functions=_RMS_NORM_EXTRAS,
+    )
+    if TritonConfig.enabled():
+        rms_norm_variants.append(
+            Variant(name="fast_llm_triton", fwd=_rms_norm_triton_fwd, fwd_bwd=_rms_norm_triton_fwd_bwd)
+        )
+    return [
+        (
+            "normalization: layer_norm",
+            [LayerNormCase(rows=r, cols=c, dtype=d) for d in dtypes for (r, c) in shapes],
+            layer_norm_variants,
+        ),
+        (
+            "normalization: rms_norm",
+            [RmsNormCase(rows=r, cols=c, dtype=d) for d in dtypes for (r, c) in shapes],
+            rms_norm_variants,
+        ),
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_pointwise.py b/tools/benchmark/bench_pointwise.py
new file mode 100644
index 000000000..47cd23812
--- /dev/null
+++ b/tools/benchmark/bench_pointwise.py
@@ -0,0 +1,108 @@
+import dataclasses
+import typing
+
+import torch
+
+from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill
+from tools.benchmark.runner import Case, Inputs
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_variants
+
+# 4× steps so L2 → HBM and saturated-HBM regimes are visible.
+_SIZES_NUMEL = [
+    1 << 20,  # 1M — 2 MiB bf16 (L2-resident on most GPUs)
+    1 << 22,  # 4M — 8 MiB bf16 (L2 boundary)
+    1 << 24,  # 16M — 32 MiB bf16 (HBM)
+    1 << 26,  # 64M — 128 MiB bf16 (HBM)
+    1 << 28,  # 256M — 512 MiB bf16 (large HBM, near-saturated)
+]
+
+
+@dataclasses.dataclass
+class _PointwiseCase(Case):
+    numel: int
+    dtype: torch.dtype
+    # Bytes traffic = bytes_factor × numel × dtype.itemsize.
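+    # (The subclasses below set the factor from their access pattern: copy reads the
+    # input and writes out (2), fill only writes (1), add reads two and writes one (3).)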
+    bytes_factor: typing.ClassVar[int]
+
+    @property
+    def name(self) -> str:
+        return f"({self.numel},) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        return self.bytes_factor * self.numel * self.dtype.itemsize
+
+
+class CopyCase(_PointwiseCase):
+    bytes_factor = 2
+
+    def make_inputs(self, device: str) -> Inputs:
+        input_ = torch.randn(self.numel, dtype=self.dtype, device=device)
+        return {"input_": input_, "out": torch.empty_like(input_)}
+
+
+class FillCase(_PointwiseCase):
+    bytes_factor = 1
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {"input_": torch.empty(self.numel, dtype=self.dtype, device=device), "value": 1.5}
+
+
+class AddCase(_PointwiseCase):
+    bytes_factor = 3
+
+    @property
+    def expected_flops(self) -> int:
+        return self.numel
+
+    def make_inputs(self, device: str) -> Inputs:
+        return {
+            "input_": torch.randn(self.numel, dtype=self.dtype, device=device),
+            "other": torch.randn(self.numel, dtype=self.dtype, device=device),
+            "out": torch.empty(self.numel, dtype=self.dtype, device=device),
+        }
+
+
+def _copy_eager(input_: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
+    return out.copy_(input_)
+
+
+def _fill_eager(input_: torch.Tensor, value: float) -> torch.Tensor:
+    return input_.fill_(value)
+
+
+def _add_eager(input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
+    return torch.add(input_, other, out=out)
+
+
+_COPY_VARIANTS = standard_fwd_variants(
+    eager_function=_copy_eager,
+    triton_function=triton_copy,
+    unpack=lambda inputs: (inputs["input_"], inputs["out"]),
+)
+_FILL_VARIANTS = standard_fwd_variants(
+    eager_function=_fill_eager,
+    triton_function=triton_fill,
+    unpack=lambda inputs: (inputs["input_"], inputs["value"]),
+)
+_ADD_VARIANTS = standard_fwd_variants(
+    eager_function=_add_eager,
+    triton_function=triton_add,
+    unpack=lambda inputs: (inputs["input_"], inputs["other"], inputs["out"]),
+)
+
+
+def benchmarks(dtypes: tuple[torch.dtype, ...], shapes: list[int] | None = None) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SIZES_NUMEL
+    return [
+        ("pointwise: copy", [CopyCase(numel=n, dtype=d) for d in dtypes for n in shapes], _COPY_VARIANTS),
+        ("pointwise: fill", [FillCase(numel=n, dtype=d) for d in dtypes for n in shapes], _FILL_VARIANTS),
+        ("pointwise: add", [AddCase(numel=n, dtype=d) for d in dtypes for n in shapes], _ADD_VARIANTS),
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_rotary.py b/tools/benchmark/bench_rotary.py
new file mode 100644
index 000000000..99f8651f8
--- /dev/null
+++ b/tools/benchmark/bench_rotary.py
@@ -0,0 +1,118 @@
+"""Rotary position embeddings. The Triton kernel is in-place; backward is an
+identical rotation with conjugated frequencies, so only fwd is benchmarked."""
+
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton.rotary import triton_rotary_
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short
+
+# (tokens, num_heads, head_size) — tokens = batch * seq_len
+_SHAPES = [
+    (4096, 32, 128),  # 7B/13B, 4K context
+    (8192, 32, 128),  # 7B/13B, 8K context
+    (4096, 64, 128),  # 70B / MoE, 4K context
+    (4096, 8, 128),  # GQA key-value heads, 4K context
+]
+
+
+@dataclasses.dataclass
+class RotaryCase(Case):
+    tokens: int
+    num_heads: int
+    head_size: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.tokens}, {self.num_heads}, {self.head_size}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        # frequencies are float32, hence the extra 4 bytes per token×head_size.
+        return (
+            2 * self.tokens * self.num_heads * self.head_size * self.dtype.itemsize + self.tokens * self.head_size * 4
+        )
+
+    @property
+    def expected_flops(self) -> int:
+        # 6 FLOPs per (re, im) element pair: 4 muls + 2 add/sub.
+        return 6 * self.tokens * self.num_heads * (self.head_size // 2)
+
+    def make_inputs(self, device: str) -> Inputs:
+        rotary_dim = self.head_size // 2
+        input_ = torch.randn(self.tokens, self.num_heads, self.head_size, dtype=self.dtype, device=device)
+        return {
+            "input_": input_,
+            "work": input_.clone(),
+            "frequencies": torch.randn(self.tokens, 2 * rotary_dim, dtype=torch.float32, device=device),
+        }
+
+
+def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor:
+    rotary_dim = frequencies.shape[-1] // 2
+    freq_re = frequencies[:, :rotary_dim].unsqueeze(1)
+    freq_im = frequencies[:, rotary_dim:].unsqueeze(1)
+    x_re, x_im = input_.chunk(2, dim=-1)
+    out_re = x_re * freq_re - x_im * freq_im
+    out_im = x_im * freq_re + x_re * freq_im
+    return torch.cat([out_re, out_im], dim=-1)
+
+
+_rotary_compiled_default = torch.compile(_rotary_eager, mode="default", dynamic=False)
+_rotary_compiled_max = torch.compile(_rotary_eager, mode="max-autotune-no-cudagraphs", dynamic=False)
+
+
+def _rotary_variants() -> list[Variant]:
+    variants = [
+        Variant(
+            name="fp32_reference",
+            fwd=lambda inputs: {"output": _rotary_eager(inputs["input_"].float(), inputs["frequencies"])},
+            is_reference=True,
+        ),
+        Variant(
+            name="pytorch_eager",
+            fwd=lambda inputs: {"output": _rotary_eager(inputs["input_"], inputs["frequencies"])},
+        ),
+        Variant(
+            name="pytorch_compiled",
+            fwd=lambda inputs: {"output": _rotary_compiled_default(inputs["input_"], inputs["frequencies"])},
+        ),
+        Variant(
+            name="pytorch_compiled_max",
+            fwd=lambda inputs: {"output": _rotary_compiled_max(inputs["input_"], inputs["frequencies"])},
+        ),
+    ]
+    if TritonConfig.enabled():
+        variants.append(
+            Variant(
+                name="fast_llm_triton",
+                fwd=lambda inputs: {"output": triton_rotary_(inputs["work"], inputs["frequencies"])},
+                reset_inputs=lambda inputs: inputs["work"].copy_(inputs["input_"]),
+            )
+        )
+    return variants
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    return [
+        (
+            "rotary",
+            [RotaryCase(tokens=t, num_heads=h, head_size=hs, dtype=d) for d in dtypes for (t, h, hs) in shapes],
+            _rotary_variants(),
+        )
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_sparse_copy.py b/tools/benchmark/bench_sparse_copy.py
new file mode 100644
index 000000000..62c165e3f
--- /dev/null
+++ b/tools/benchmark/bench_sparse_copy.py
@@ -0,0 +1,202 @@
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton.sparse_copy import (
+    SparseMap,
+    copy_dense_to_sparse_autograd,
+    copy_sparse_to_dense_autograd,
+    get_sparse_map,
+)
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+# (tokens, top_k, num_experts, hidden_size)
+_SHAPES = [
+    (4096, 2, 8, 4096),  # Mixtral-8x7B-like
+    (4096, 2, 64, 4096),  # fine-grained MoE
+    (4096, 2, 8, 8192),  # wide hidden
+]
+
+
+def _make_phantom_mask(sparse_map: SparseMap, device: str) -> torch.Tensor:
+    # True for within-expert padding rows and the static tail past expert_ends[-1];
+    # used only in output_postprocess, never in the timed path.
+    mask = torch.zeros(sparse_map.num_rows, 1, dtype=torch.bool, device=device)
+    for expert in range(sparse_map.num_experts):
+        pad_begin = int(sparse_map.expert_pad_begins[expert])
+        pad_end = int(sparse_map.expert_ends[expert])
+        if pad_end > pad_begin:
+            mask[pad_begin:pad_end] = True
+    tail_begin = int(sparse_map.expert_ends[-1])
+    if sparse_map.num_rows > tail_begin:
+        mask[tail_begin:] = True
+    return mask
+
+
+@dataclasses.dataclass
+class _SparseCopyCase(Case):
+    tokens: int
+    top_k: int
+    num_experts: int
+    hidden: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return f"({self.tokens}, {self.top_k}, {self.num_experts}, {self.hidden}) {dtype_short(self.dtype)}"
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        # 2× (sparse + dense) hidden traffic; combine adds scores read/write.
+        return 2 * (1 + self.top_k) * self.tokens * self.hidden * self.dtype.itemsize
+
+
+class DispatchCase(_SparseCopyCase):
+    def make_inputs(self, device: str) -> Inputs:
+        top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device)
+        sparse_map = get_sparse_map(top_experts, self.num_experts)
+        return {
+            "dense": torch.randn(self.tokens, self.hidden, dtype=self.dtype, device=device, requires_grad=True),
+            "sparse_map": sparse_map,
+            "phantom_mask": _make_phantom_mask(sparse_map, device),
+            "backward_grad": torch.ones(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device),
+        }
+
+
+class CombineCase(_SparseCopyCase):
+    @property
+    def expected_bytes(self) -> int:
+        # Adds scores read/write on top of the dispatch traffic.
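+        # (The 4× scores term is an estimate, roughly: fwd read, bwd read, grad write
+        # and reduction scratch; it is negligible next to the hidden-dim traffic.)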
+        return super().expected_bytes + 4 * self.tokens * self.top_k * self.dtype.itemsize
+
+    def make_inputs(self, device: str) -> Inputs:
+        top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device)
+        sparse_map = get_sparse_map(top_experts, self.num_experts)
+        return {
+            "sparse": torch.randn(
+                sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device, requires_grad=True
+            ),
+            "scores": torch.softmax(
+                torch.randn(self.tokens, self.top_k, dtype=self.dtype, device=device), dim=-1
+            ).requires_grad_(True),
+            "sparse_map": sparse_map,
+            "phantom_mask": _make_phantom_mask(sparse_map, device),
+            "backward_grad": torch.ones(self.tokens, self.hidden, dtype=self.dtype, device=device),
+        }
+
+
+def _dispatch_pytorch(dense_input: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor:
+    out = dense_input.new_zeros(sparse_map.num_rows, dense_input.shape[1])
+    sparse_rows = sparse_map.sparse_rows.long()
+    for k in range(sparse_map.num_experts_per_token):
+        out[sparse_rows[:, k]] = dense_input
+    return out
+
+
+def _dispatch_triton_fwd(inputs: dict) -> dict:
+    return {"output": copy_dense_to_sparse_autograd(inputs["dense"], inputs["sparse_map"])}
+
+
+def _dispatch_triton_fwd_bwd(inputs: dict) -> dict:
+    output = copy_dense_to_sparse_autograd(inputs["dense"], inputs["sparse_map"])
+    output.backward(inputs["backward_grad"])
+    return {"output": output.detach(), "grad_dense": inputs["dense"].grad}
+
+
+def _dispatch_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]:
+    output["output"].masked_fill_(inputs["phantom_mask"], 0)
+    return output
+
+
+def _combine_pytorch(sparse_input: torch.Tensor, scores: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor:
+    out = sparse_input.new_zeros(sparse_map.num_rows_dense, sparse_input.shape[1])
+    sparse_rows = sparse_map.sparse_rows.long()
+    for k in range(sparse_map.num_experts_per_token):
+        out = out + sparse_input[sparse_rows[:, k]] * scores[:, k : k + 1]
+    return out
+
+
+def _combine_triton_fwd(inputs: dict) -> dict:
+    return {"output": copy_sparse_to_dense_autograd(inputs["sparse"], inputs["scores"], inputs["sparse_map"])}
+
+
+def _combine_triton_fwd_bwd(inputs: dict) -> dict:
+    output = copy_sparse_to_dense_autograd(inputs["sparse"], inputs["scores"], inputs["sparse_map"])
+    output.backward(inputs["backward_grad"])
+    return {
+        "output": output.detach(),
+        "grad_sparse": inputs["sparse"].grad,
+        "grad_scores": inputs["scores"].grad,
+    }
+
+
+def _combine_postprocess(output: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]:
+    if "grad_sparse" in output:
+        output["grad_sparse"].masked_fill_(inputs["phantom_mask"], 0)
+    return output
+
+
+def benchmarks(
+    dtypes: tuple[torch.dtype, ...],
+    shapes: list[tuple[int, int, int, int]] | None = None,
+) -> list[tuple[str, list, list]]:
+    shapes = shapes if shapes is not None else _SHAPES
+    dispatch_variants = standard_fwd_bwd_pytorch_variants(
+        _dispatch_pytorch,
+        input_keys=("dense", "sparse_map"),
+        grad_input_keys=("dense",),
+        grad_output_key="backward_grad",
+    )
+    if TritonConfig.enabled():
+        dispatch_variants.append(
+            Variant(
+                name="fast_llm_triton",
+                fwd=_dispatch_triton_fwd,
+                fwd_bwd=_dispatch_triton_fwd_bwd,
+                output_postprocess=_dispatch_postprocess,
+            )
+        )
+    combine_variants = standard_fwd_bwd_pytorch_variants(
+        _combine_pytorch,
+        input_keys=("sparse", "scores", "sparse_map"),
+        grad_input_keys=("sparse", "scores"),
+        grad_output_key="backward_grad",
+    )
+    if TritonConfig.enabled():
+        combine_variants.append(
+            Variant(
+                name="fast_llm_triton",
+                fwd=_combine_triton_fwd,
+                fwd_bwd=_combine_triton_fwd_bwd,
+                output_postprocess=_combine_postprocess,
+            )
+        )
+    return [
+        (
+            "sparse_copy: dispatch",
+            [
+                DispatchCase(tokens=t, top_k=k, num_experts=e, hidden=h, dtype=d)
+                for d in dtypes
+                for (t, k, e, h) in shapes
+            ],
+            dispatch_variants,
+        ),
+        (
+            "sparse_copy: combine",
+            [
+                CombineCase(tokens=t, top_k=k, num_experts=e, hidden=h, dtype=d)
+                for d in dtypes
+                for (t, k, e, h) in shapes
+            ],
+            combine_variants,
+        ),
+    ]
+
+
+run = bench_main(benchmarks)
diff --git a/tools/benchmark/bench_sparse_linear.py b/tools/benchmark/bench_sparse_linear.py
new file mode 100644
index 000000000..1e05eb038
--- /dev/null
+++ b/tools/benchmark/bench_sparse_linear.py
@@ -0,0 +1,235 @@
+import dataclasses
+
+import torch
+
+from fast_llm.functional.config import TritonConfig
+from fast_llm.functional.triton.sparse_copy import SparseMap, get_sparse_map
+from fast_llm.functional.triton.sparse_linear import InputSparseLinear, OutputSparseLinear
+from tools.benchmark.runner import Case, Inputs, Variant
+from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_bwd_pytorch_variants
+
+# (tokens, top_k, num_experts, hidden, ffn_per_expert)
+_SHAPES = [
+    (4096, 2, 8, 4096, 14336),  # Mixtral-8x7B: 8 experts, ffn=14336
+    (4096, 2, 64, 4096, 1792),  # fine-grained MoE: 64 experts, same total capacity
+    (4096, 2, 8, 8192, 28672),  # large hidden / wide FFN
+]
+
+# Triton autotuning warmup needs to run only once per shape; make_inputs is
+# called many times per case (per variant, per fwd/fwd_bwd/memory pass).
+_output_sparse_warmed_up: set[tuple] = set()
+_input_inner_sparse_warmed_up: set[tuple] = set()
+
+
+def _mask_padded_rows(candidate: dict[str, torch.Tensor], inputs: dict) -> dict[str, torch.Tensor]:
+    # Two regions in the kernel's forward output and grad_lhs are by-design garbage:
+    # per-expert padding [pad_begin, expert_end) and phantom rows past expert_ends[-1].
+    # The loop reference produces zeros there; mask the kernel output to match so
+    # rel_rms reflects only the real rows. grad_rhs already excludes padding.
+    sparse_map = inputs["sparse_map"]
+    pad_begins = sparse_map.expert_pad_begins.tolist()
+    pad_ends = sparse_map.expert_ends.tolist()
+    last_expert_end = pad_ends[-1]
+    masked = dict(candidate)
+    for key in ("output", "grad_lhs"):
+        if key not in masked:
+            continue
+        clone = masked[key].clone()
+        for begin, end in zip(pad_begins, pad_ends, strict=True):
+            if end > begin:
+                clone[begin:end] = 0
+        if clone.shape[0] > last_expert_end:
+            clone[last_expert_end:] = 0
+        masked[key] = clone
+    return masked
+
+
+@dataclasses.dataclass
+class _SparseLinearCase(Case):
+    tokens: int
+    top_k: int
+    num_experts: int
+    hidden: int
+    ffn_per_expert: int
+    dtype: torch.dtype
+
+    @property
+    def name(self) -> str:
+        return (
+            f"({self.tokens}, {self.top_k}, {self.num_experts}, "
+            f"{self.hidden}, {self.ffn_per_expert}) {dtype_short(self.dtype)}"
+        )
+
+    @property
+    def compute_dtype(self) -> torch.dtype:
+        return self.dtype
+
+    @property
+    def expected_bytes(self) -> int:
+        # Approximation: 3× lhs + 3× rhs + 2× output traffic across fwd+bwd.
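+        # (Reading of the approximation, an estimate rather than an exact model: fwd
+        # reads lhs and rhs and writes the output; each backward matmul re-reads its two
+        # operands and writes one grad. Cache reuse makes real traffic lower, so the
+        # derived %-of-peak is rough.)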
+ sparse_tokens = self.tokens * self.top_k + lhs_bytes = sparse_tokens * self.hidden * self.dtype.itemsize + rhs_bytes = self.hidden * self.ffn_per_expert * self.num_experts * self.dtype.itemsize + output_bytes = sparse_tokens * self.ffn_per_expert * self.dtype.itemsize + return 3 * lhs_bytes + 3 * rhs_bytes + 2 * output_bytes + + @property + def expected_flops(self) -> int: + # 3 matmuls (fwd: lhs@rhs, bwd_lhs: grad@rhs.T, bwd_rhs: lhs.T@grad). + return 3 * 2 * self.tokens * self.top_k * self.hidden * self.ffn_per_expert + + def _make_sparse_map(self, device: str) -> SparseMap: + top_experts = torch.randint(0, self.num_experts, (self.tokens, self.top_k), device=device) + return get_sparse_map(top_experts, self.num_experts) + + +class OutputSparseCase(_SparseLinearCase): + def make_inputs(self, device: str) -> Inputs: + sparse_map = self._make_sparse_map(device) + lhs_data = torch.randn(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device) + rhs_data = torch.randn(self.hidden, self.ffn_per_expert * self.num_experts, dtype=self.dtype, device=device) + backward_grad = torch.ones(sparse_map.num_rows, self.ffn_per_expert, dtype=self.dtype, device=device) + warmup_key = (self.tokens, self.top_k, self.num_experts, self.hidden, self.ffn_per_expert, self.dtype) + if TritonConfig.enabled() and warmup_key not in _output_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = OutputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _output_sparse_warmed_up.add(warmup_key) + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": sparse_map, + "backward_grad": backward_grad, + } + + +class InputInnerSparseCase(_SparseLinearCase): + def make_inputs(self, device: str) -> Inputs: + sparse_map = self._make_sparse_map(device) + lhs_data = torch.randn(sparse_map.num_rows, self.ffn_per_expert, dtype=self.dtype, device=device) + rhs_data = torch.randn(self.ffn_per_expert * self.num_experts, self.hidden, dtype=self.dtype, device=device) + backward_grad = torch.ones(sparse_map.num_rows, self.hidden, dtype=self.dtype, device=device) + warmup_key = (self.tokens, self.top_k, self.num_experts, self.hidden, self.ffn_per_expert, self.dtype) + if TritonConfig.enabled() and warmup_key not in _input_inner_sparse_warmed_up: + warmup_lhs = lhs_data.detach().requires_grad_(True) + warmup_rhs = rhs_data.detach().requires_grad_(True) + warmup_out = InputSparseLinear.apply(warmup_lhs, warmup_rhs, sparse_map) + warmup_out.backward(backward_grad) + del warmup_lhs, warmup_rhs, warmup_out + _input_inner_sparse_warmed_up.add(warmup_key) + return { + "lhs": lhs_data.requires_grad_(True), + "rhs": rhs_data.requires_grad_(True), + "sparse_map": sparse_map, + "backward_grad": backward_grad, + } + + +def _output_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + ffn_per_expert = rhs.shape[1] // sparse_map.num_experts + out = lhs.new_zeros(sparse_map.num_rows, ffn_per_expert) + for expert in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[expert - 1]) if expert > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[expert]) + if row_end > row_begin: + col_begin = expert * ffn_per_expert + out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[:, col_begin : col_begin + ffn_per_expert] + return out + + +def _output_sparse_triton_fwd(inputs: dict) 
-> dict: + return {"output": OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} + + +def _output_sparse_triton_fwd_bwd(inputs: dict) -> dict: + output = OutputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} + + +def _input_inner_sparse_loop(lhs: torch.Tensor, rhs: torch.Tensor, sparse_map: SparseMap) -> torch.Tensor: + ffn_per_expert = rhs.shape[0] // sparse_map.num_experts + out = lhs.new_zeros(sparse_map.num_rows, rhs.shape[1]) + for expert in range(sparse_map.num_experts): + row_begin = int(sparse_map.expert_ends[expert - 1]) if expert > 0 else 0 + row_end = int(sparse_map.expert_pad_begins[expert]) + if row_end > row_begin: + inner_begin = expert * ffn_per_expert + out[row_begin:row_end] = lhs[row_begin:row_end] @ rhs[inner_begin : inner_begin + ffn_per_expert] + return out + + +def _input_inner_sparse_triton_fwd(inputs: dict) -> dict: + return {"output": InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"])} + + +def _input_inner_sparse_triton_fwd_bwd(inputs: dict) -> dict: + output = InputSparseLinear.apply(inputs["lhs"], inputs["rhs"], inputs["sparse_map"]) + output.backward(inputs["backward_grad"]) + return {"output": output.detach(), "grad_lhs": inputs["lhs"].grad, "grad_rhs": inputs["rhs"].grad} + + +def benchmarks( + dtypes: tuple[torch.dtype, ...], + shapes: list[tuple[int, int, int, int, int]] | None = None, +) -> list[tuple[str, list, list]]: + shapes = shapes if shapes is not None else _SHAPES + output_sparse_variants = standard_fwd_bwd_pytorch_variants( + _output_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) + if TritonConfig.enabled(): + output_sparse_variants.append( + Variant( + name="fast_llm_triton", + fwd=_output_sparse_triton_fwd, + fwd_bwd=_output_sparse_triton_fwd_bwd, + output_postprocess=_mask_padded_rows, + ) + ) + input_inner_sparse_variants = standard_fwd_bwd_pytorch_variants( + _input_inner_sparse_loop, + input_keys=("lhs", "rhs", "sparse_map"), + grad_input_keys=("lhs", "rhs"), + grad_output_key="backward_grad", + eager_name="pytorch_loop", + enable_max_autotune=False, + ) + if TritonConfig.enabled(): + input_inner_sparse_variants.append( + Variant( + name="fast_llm_triton", + fwd=_input_inner_sparse_triton_fwd, + fwd_bwd=_input_inner_sparse_triton_fwd_bwd, + output_postprocess=_mask_padded_rows, + ) + ) + return [ + ( + "sparse_linear: output_sparse (layer 1 / up-proj)", + [ + OutputSparseCase(tokens=t, top_k=k, num_experts=e, hidden=h, ffn_per_expert=f, dtype=d) + for d in dtypes + for (t, k, e, h, f) in shapes + ], + output_sparse_variants, + ), + ( + "sparse_linear: input_inner_sparse (layer 2 / down-proj)", + [ + InputInnerSparseCase(tokens=t, top_k=k, num_experts=e, hidden=h, ffn_per_expert=f, dtype=d) + for d in dtypes + for (t, k, e, h, f) in shapes + ], + input_inner_sparse_variants, + ), + ] + + +run = bench_main(benchmarks) diff --git a/tools/benchmark/gpu_specs.py b/tools/benchmark/gpu_specs.py new file mode 100644 index 000000000..563348dcd --- /dev/null +++ b/tools/benchmark/gpu_specs.py @@ -0,0 +1,99 @@ +""" +Peak memory bandwidth and compute for high-end datacenter GPUs, used to report +%-of-peak. 
Values are vendor-published nominal dense peaks (no 2:4 sparsity); +real-world achievable is typically ~80-90% of nominal for bandwidth and +~70-85% for compute. +""" + +import dataclasses + +import torch + + +@dataclasses.dataclass(frozen=True) +class GpuSpec: + name: str + peak_bandwidth_gbps: float + peak_tflops_fp32: float + peak_tflops_tf32: float + peak_tflops_fp16: float + peak_tflops_bf16: float + + def peak_tflops(self, dtype: torch.dtype) -> float | None: + if dtype == torch.float32: + return self.peak_tflops_fp32 + if dtype == torch.float16: + return self.peak_tflops_fp16 + if dtype == torch.bfloat16: + return self.peak_tflops_bf16 + return None + + +# Name-substring match → spec. First match wins. Dense (non-sparse) peaks. +_GPU_SPECS: list[tuple[str, GpuSpec]] = [ + ( + "B200", + GpuSpec( + name="NVIDIA B200 SXM", + peak_bandwidth_gbps=8000.0, + peak_tflops_fp32=80.0, + peak_tflops_tf32=1100.0, + peak_tflops_fp16=2250.0, + peak_tflops_bf16=2250.0, + ), + ), + ( + "B100", + GpuSpec( + name="NVIDIA B100 SXM", + peak_bandwidth_gbps=8000.0, + peak_tflops_fp32=60.0, + peak_tflops_tf32=880.0, + peak_tflops_fp16=1800.0, + peak_tflops_bf16=1800.0, + ), + ), + ( + "H200", + GpuSpec( + name="NVIDIA H200 SXM", + peak_bandwidth_gbps=4800.0, + peak_tflops_fp32=67.0, + peak_tflops_tf32=494.0, + peak_tflops_fp16=989.0, + peak_tflops_bf16=989.0, + ), + ), + ( + "H100", + GpuSpec( + name="NVIDIA H100 SXM", + peak_bandwidth_gbps=3350.0, + peak_tflops_fp32=67.0, + peak_tflops_tf32=494.0, + peak_tflops_fp16=989.0, + peak_tflops_bf16=989.0, + ), + ), + ( + "A100", + GpuSpec( + name="NVIDIA A100 80GB SXM", + peak_bandwidth_gbps=2039.0, + peak_tflops_fp32=19.5, + peak_tflops_tf32=156.0, + peak_tflops_fp16=312.0, + peak_tflops_bf16=312.0, + ), + ), +] + + +def detect_gpu_spec() -> GpuSpec | None: + if not torch.cuda.is_available(): + return None + name = torch.cuda.get_device_name() + for needle, spec in _GPU_SPECS: + if needle in name: + return spec + return None diff --git a/tools/benchmark/runner.py b/tools/benchmark/runner.py new file mode 100644 index 000000000..de4d94a24 --- /dev/null +++ b/tools/benchmark/runner.py @@ -0,0 +1,688 @@ +""" +Core benchmarking infrastructure for Fast-LLM Triton kernels. + +Each benchmark file defines a list of `Case` objects (input shape/dtype +sweep) and a list of `Variant` objects (implementations to compare — e.g. +pytorch eager, pytorch compiled, Triton). The runner invokes each variant +on each case, measures timing (median, mean, and min/max/std via CUDA events), +measures peak/final memory, and compares outputs against an fp32 reference +using RMS error. Results are printed as a table per case. +""" + +import dataclasses +import math +import statistics +import time +from collections.abc import Callable +from typing import Any + +import torch + +from fast_llm.utils import header +from tools.benchmark.gpu_specs import GpuSpec, detect_gpu_spec + +# Before each timed CUDA-graph-backed call we must mark a new step so the graph +# system knows the previous rep's output buffers are no longer live. Without +# this, max-autotune compiled functions raise "tensor output of CUDAGraphs has +# been overwritten by a subsequent run" on the second call. +_cudagraph_mark_step_begin: Callable[[], None] | None = getattr( + getattr(torch, "compiler", None), "cudagraph_mark_step_begin", None +) + + +def _guarded(fn: Callable[[], Any]) -> Callable[[], Any]: + """Wrap fn so cudagraph_mark_step_begin() is called before each invocation.
+ This tells the CUDA graph system that previous outputs are no longer live + and can be overwritten, preventing 'overwritten by subsequent run' errors + when a max-autotune compiled function is called more than once. + When CUDA graphs are not in use the wrapper is a no-op pass-through.""" + if _cudagraph_mark_step_begin is None: + return fn + + def _wrapped() -> Any: + _cudagraph_mark_step_begin() + return fn() + + return _wrapped + + +Inputs = dict[str, Any] +VariantFn = Callable[[Inputs], Any] + + +@dataclasses.dataclass +class Variant: + """A single implementation being compared. Provide `fwd` for forward-only + timing. Provide `fwd_bwd` for forward+backward timing; when both are set, + backward-only time is reported as `fwd_bwd - fwd`.""" + + name: str + fwd: VariantFn | None = None + fwd_bwd: VariantFn | None = None + # The fp32 reference variant. Exactly one per benchmark; its outputs are + # the ground truth for RMS-error comparison. + is_reference: bool = False + # Applied to the output dict during the correctness check only — never during + # timing. Receives {name: tensor} and the full inputs dict; returns the + # (possibly modified) dict. Use this to mask out don't-care regions so they + # don't inflate RMS errors (e.g. uninitialized phantom rows in sparse buffers). + output_postprocess: Callable[[dict[str, torch.Tensor], Inputs], dict[str, torch.Tensor]] | None = None + # Called between timing reps (outside the timed region) to restore any + # input tensors the variant mutates in-place. Use this instead of cloning + # inside the timed callable so the mutation cost is not measured. + reset_inputs: Callable[[Inputs], Any] | None = None + + +class Case: + """Base for a single input configuration. Subclasses are dataclasses holding + the kernel's shape parameters (e.g. rows, cols, dtype) and override `name` + and `make_inputs`; the throughput properties are optional.""" + + # Subclasses must provide these. + name: str + # Subclasses may override these (defaults skip the corresponding columns). + expected_bytes: int | None = None # bytes read+written; enables GB/s + %BW. + expected_flops: int | None = None # FLOPs performed; enables TFLOP/s + %FLOPs. + compute_dtype: torch.dtype | None = None # dtype of hot inputs, picks peak column. + + def make_inputs(self, device: str) -> Inputs: + """Return a fresh dict of input tensors on `device`. 
Called once per + variant per mode, after a global seed reset, so every variant sees + identical inputs.""" + raise NotImplementedError + + +def _device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def _seeded_inputs(case: Case, seed: int = 0) -> Inputs: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + return case.make_inputs(_device()) + + +@dataclasses.dataclass +class TimingStats: + median_ms: float + mean_ms: float + min_ms: float + max_ms: float + std_ms: float + num_reps: int + + +@dataclasses.dataclass +class MemoryStats: + peak_mib: float + final_mib: float + delta_peak_mib: float + + +@dataclasses.dataclass +class VariantResult: + variant_name: str + fwd_timing: TimingStats | None = None + fwd_bwd_timing: TimingStats | None = None + memory: MemoryStats | None = None + rms_errors: dict[str, float] | None = None # output name → RMS rel error vs reference + error: str | None = None # If the variant failed, the error message + + +# --------------------------------------------------------------------------- timing + + +def _make_cache_flusher(size_bytes: int = 256 * 1024 * 1024) -> Callable[[], None]: + """Allocate a scratch buffer larger than any GPU L2 and zero it between reps + to invalidate cached values (avoids over-optimistic timings).""" + if not torch.cuda.is_available(): + return lambda: None + buffer = torch.empty(size_bytes // 4, dtype=torch.int32, device="cuda") + + def flush() -> None: + buffer.zero_() + + return flush + + +def bench_fn( + fn: Callable[[], Any], + reset: Callable[[], None] | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, + max_reps: int = 10_000, +) -> TimingStats: + """Benchmark `fn` — it should be a no-arg callable that invokes the kernel + being timed (close over inputs). Returns timing statistics in ms. + + Mirrors `triton.testing.do_bench` logic but keeps the raw per-rep times so we + can compute {median, mean, min, max, std} from one set of runs. + """ + if not torch.cuda.is_available(): + # CPU / Triton interpret: single timed run with wall clock. + if reset is not None: + reset() + fn() # warmup + if reset is not None: + reset() + start = time.perf_counter() + fn() + elapsed_ms = (time.perf_counter() - start) * 1000 + return TimingStats(elapsed_ms, elapsed_ms, elapsed_ms, elapsed_ms, 0.0, 1) + + flush = _make_cache_flusher() + + # Warmup to JIT-compile, autotune, etc. + torch.cuda.synchronize() + warmup_start = torch.cuda.Event(enable_timing=True) + warmup_end = torch.cuda.Event(enable_timing=True) + if reset is not None: + reset() + warmup_start.record() + fn() + warmup_end.record() + torch.cuda.synchronize() + one_rep_ms = warmup_start.elapsed_time(warmup_end) + + # Additional warmup to stabilize (covers autotune misses on first call) + num_warmup = max(1, int(warmup_ms / max(one_rep_ms, 0.01))) + for _ in range(num_warmup): + if reset is not None: + reset() + fn() + torch.cuda.synchronize() + + # Re-estimate after warmup (autotune usually settles to a faster config).
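+ # The re-measured estimate below sets the rep count: num_reps targets ~rep_ms of total timed work, clamped to [min_reps, max_reps].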
+ post_start = torch.cuda.Event(enable_timing=True) + post_end = torch.cuda.Event(enable_timing=True) + if reset is not None: + reset() + post_start.record() + fn() + post_end.record() + torch.cuda.synchronize() + one_rep_ms = max(post_start.elapsed_time(post_end), 0.001) + + num_reps = max(min_reps, min(max_reps, int(rep_ms / one_rep_ms))) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_reps)] + for i in range(num_reps): + if reset is not None: + reset() + flush() + start_events[i].record() + fn() + end_events[i].record() + torch.cuda.synchronize() + + times = [start_events[i].elapsed_time(end_events[i]) for i in range(num_reps)] + return TimingStats( + median_ms=statistics.median(times), + mean_ms=statistics.fmean(times), + min_ms=min(times), + max_ms=max(times), + std_ms=statistics.pstdev(times) if len(times) > 1 else 0.0, + num_reps=num_reps, + ) + + +# --------------------------------------------------------------------------- memory + + +def measure_memory(fn: Callable[[], Any]) -> MemoryStats: + """Run `fn` once and capture peak and final device memory. Must be called + on a fresh GPU state (the caller resets stats before constructing inputs).""" + if not torch.cuda.is_available(): + return MemoryStats(0.0, 0.0, 0.0) + torch.cuda.synchronize() + baseline = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + result = fn() + torch.cuda.synchronize() + peak = torch.cuda.max_memory_allocated() + final = torch.cuda.memory_allocated() + # Hold onto the result until after the measurement so it stays in `final`. + del result + return MemoryStats( + peak_mib=peak / 1024 / 1024, + final_mib=final / 1024 / 1024, + delta_peak_mib=(peak - baseline) / 1024 / 1024, + ) + + +# --------------------------------------------------------------------------- correctness + + +def rms_relative_error(candidate: torch.Tensor, reference: torch.Tensor) -> float: + """Root-mean-squared error of `candidate - reference`, normalized by the + RMS of `reference`. Both are cast to fp32 before comparison.""" + cand = candidate.detach().float() + ref = reference.detach().float() + diff_rms = (cand - ref).pow(2).mean().sqrt().item() + ref_rms = ref.pow(2).mean().sqrt().item() + return diff_rms / max(ref_rms, 1e-30) + + +def _as_output_dict(output: Any) -> dict[str, torch.Tensor]: + """Normalize a variant's output into a {name: tensor} dict for comparison.""" + if isinstance(output, torch.Tensor): + return {"out": output} + if isinstance(output, dict): + return {k: v for k, v in output.items() if isinstance(v, torch.Tensor)} + if isinstance(output, (tuple, list)): + return {f"out{i}": v for i, v in enumerate(output) if isinstance(v, torch.Tensor)} + raise TypeError(f"Cannot extract tensors from variant output of type {type(output).__name__}") + + +# --------------------------------------------------------------------------- runner + + +def _run_one_variant( + variant: Variant, + case: Case, + reference_outputs: dict[str, dict[str, torch.Tensor]] | None, + warmup_ms: float, + rep_ms: float, + min_reps: int = 5, +) -> VariantResult: + result = VariantResult(variant_name=variant.name) + try: + # --- correctness + memory: one fresh invocation per mode + # fwd mode + if variant.fwd is not None: + inputs = _seeded_inputs(case) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _fwd_once() -> Any: + return variant.fwd(inputs) + + _guarded_fwd = _guarded(_fwd_once) + + # First: correctness. 
Run once, capture output for comparison. + fwd_output = _guarded_fwd() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if reference_outputs is not None and not variant.is_reference: + ref_fwd = reference_outputs.get("fwd") + if ref_fwd is not None: + cand = _as_output_dict(fwd_output) + if variant.output_postprocess is not None: + cand = variant.output_postprocess(cand, inputs) + rms = {name: rms_relative_error(cand[name], ref_fwd[name]) for name in ref_fwd if name in cand} + result.rms_errors = (result.rms_errors or {}) | {f"fwd.{k}": v for k, v in rms.items()} + del fwd_output + + # Timing: reuse the same input tensors, fn closes over them. + _reset_fwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None + result.fwd_timing = bench_fn( + _guarded_fwd, reset=_reset_fwd, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps + ) + del inputs + + # fwd+bwd mode + if variant.fwd_bwd is not None: + inputs = _seeded_inputs(case) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _fwd_bwd_once() -> Any: + return variant.fwd_bwd(inputs) + + _guarded_fwd_bwd = _guarded(_fwd_bwd_once) + + fwd_bwd_output = _guarded_fwd_bwd() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if reference_outputs is not None and not variant.is_reference: + ref_fb = reference_outputs.get("fwd_bwd") + if ref_fb is not None: + cand = _as_output_dict(fwd_bwd_output) + if variant.output_postprocess is not None: + cand = variant.output_postprocess(cand, inputs) + rms = {name: rms_relative_error(cand[name], ref_fb[name]) for name in ref_fb if name in cand} + result.rms_errors = (result.rms_errors or {}) | {f"fb.{k}": v for k, v in rms.items()} + del fwd_bwd_output + + # Memory measurement: one fresh call on fresh inputs. + fresh_inputs = _seeded_inputs(case) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + result.memory = measure_memory(_guarded(lambda: variant.fwd_bwd(fresh_inputs))) + del fresh_inputs + + # Timing. + _reset_fwd_bwd = (lambda: variant.reset_inputs(inputs)) if variant.reset_inputs else None + result.fwd_bwd_timing = bench_fn( + _guarded_fwd_bwd, reset=_reset_fwd_bwd, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps + ) + del inputs + elif variant.fwd is not None and result.memory is None: + # No backward — measure fwd-mode memory. + fresh_inputs = _seeded_inputs(case) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + result.memory = measure_memory(_guarded(lambda: variant.fwd(fresh_inputs))) + del fresh_inputs + except Exception as exc: # noqa: BLE001 + result.error = f"{type(exc).__name__}: {exc}" + return result + + +def _collect_reference_outputs( + variant: Variant, + case: Case, +) -> dict[str, dict[str, torch.Tensor]]: + out: dict[str, dict[str, torch.Tensor]] = {} + if variant.fwd is not None: + out["fwd"] = _as_output_dict(variant.fwd(_seeded_inputs(case))) + if variant.fwd_bwd is not None: + out["fwd_bwd"] = _as_output_dict(variant.fwd_bwd(_seeded_inputs(case))) + if torch.cuda.is_available(): + torch.cuda.synchronize() + # Detach+clone to guard against in-place mutation by later variants + return {mode: {k: v.detach().clone() for k, v in tensors.items()} for mode, tensors in out.items()} + + +# --------------------------------------------------------------------------- table + + +def _column_decimals(values: list[float | None]) -> int: + """Number of decimal places to give the smallest non-zero value in a column + at least 4 significant digits. 
Capped at 6 so one tiny value doesn't bloat + the whole column (e.g. 0.00001 alongside 100 would otherwise force 8 decimals).""" + nonzero = [abs(v) for v in values if v is not None and v != 0] + if not nonzero: + return 0 + min_magnitude = math.floor(math.log10(min(nonzero))) + return min(6, max(0, 3 - min_magnitude)) + + +def _format_aligned(values: list[float | None]) -> list[str]: + """Format a column with the same number of decimals for every entry, so + decimal points line up. Zeros get the trailing zeros too (e.g. '0.0000').""" + decimals = _column_decimals(values) + out: list[str] = [] + for value in values: + if value is None: + out.append("—") + else: + out.append(f"{value:.{decimals}f}") + return out + + +def _pick_unit(max_value: float, table: list[tuple[str, float]]) -> tuple[str, float]: + """Given a magnitude-ordered list of (unit_label, scale_to_unit) pairs and + the column's max absolute value, return the unit where max_value*scale is + in [1, 1000) when possible. `table` must be ordered by ascending magnitude + (largest unit last).""" + chosen_label, chosen_scale = table[0] + for label, scale in table: + if max_value * scale >= 1: + chosen_label, chosen_scale = label, scale + else: + break + return chosen_label, chosen_scale + + +# Each table is ordered ascending (small unit → large unit). `scale` converts +# from the canonical storage unit (ms / bytes-per-second / flops-per-second / +# MiB) into the display unit. +_TIME_UNITS = [("ns", 1e6), ("us", 1e3), ("ms", 1.0), ("s", 1e-3)] +_BANDWIDTH_UNITS = [("B/s", 1.0), ("KB/s", 1e-3), ("MB/s", 1e-6), ("GB/s", 1e-9), ("TB/s", 1e-12)] +_FLOPS_UNITS = [ + ("FLOP/s", 1.0), + ("KFLOP/s", 1e-3), + ("MFLOP/s", 1e-6), + ("GFLOP/s", 1e-9), + ("TFLOP/s", 1e-12), + ("PFLOP/s", 1e-15), +] +_MEMORY_UNITS = [("KiB", 1024.0), ("MiB", 1.0), ("GiB", 1 / 1024), ("TiB", 1 / 1024 / 1024)] + + +def _unit_column( + prefix: str, canonical_values: list[float | None], units: list[tuple[str, float]] +) -> tuple[str, list[str]]: + """Pick the best display unit for a column's magnitude and format with + aligned decimals. Header is '<prefix> <unit>'.""" + non_none = [abs(v) for v in canonical_values if v is not None] + max_value = max(non_none, default=0.0) + if max_value > 0: + label, scale = _pick_unit(max_value, units) + else: + # All values are zero / None. Fall back to the canonical unit (scale=1.0) + # so e.g. memory defaults to MiB rather than the middle of the table. + label, scale = next( + ((unit_label, unit_scale) for (unit_label, unit_scale) in units if unit_scale == 1.0), units[0] + ) + scaled = [v * scale if v is not None else None for v in canonical_values] + header = f"{prefix} {label}" if prefix else label + return header, _format_aligned(scaled) + + +def _percent_column(values: list[float | None]) -> list[str]: + """Format a column of ratios as aligned percentages.""" + scaled = [v * 100 if v is not None else None for v in values] + formatted = _format_aligned(scaled) + return [f if f == "—" else f"{f}%" for f in formatted] + + +def _rms_column(values: list[float | None]) -> list[str]: + """Format RMS errors in scientific notation with a fixed mantissa precision so entries align.""" + decimals = 3 # 4 sig figs + out: list[str] = [] + for value in values: + if value is None: + out.append("—") + elif value == 0.0: + out.append(f"{0.0:.{decimals}e}") + else: + out.append(f"{value:.{decimals}e}") + return out + + +def _simplify_rms_key(key: str, all_keys: list[str]) -> str: + """Turn internal keys like 'fwd.out' / 'fb.loss' into concise display labels.
+ + Rules: + - strip the mode prefix ('fwd.'/'fb.') when all keys share the same mode + - rename 'fb' → 'bwd' for display when it survives + - drop the trailing '.out' / standalone 'out' (the placeholder key used + when a variant returns a single unnamed tensor) + """ + mode, _, tensor = key.partition(".") + all_modes = {k.partition(".")[0] for k in all_keys} + if len(all_modes) <= 1: + remainder = tensor + else: + pretty = "bwd" if mode == "fb" else mode + remainder = f"{pretty}.{tensor}" if tensor else pretty + if remainder == "out": + return "" + return remainder.removesuffix(".out") + + +def _rms_header(key: str, all_keys: list[str]) -> str: + simplified = _simplify_rms_key(key, all_keys) + return f"rel_rms({simplified})" if simplified else "rel_rms" + + +def _render_table( + case: Case, + results: list[VariantResult], + gpu_spec: GpuSpec | None, + has_fwd: bool, + has_fwd_bwd: bool, + rms_keys: list[str], + verbose: bool, +) -> str: + # First column header carries the case name so the per-case label and the + # variant-name column are merged into one (avoids a redundant title line). + columns: list[tuple[str, list[str]]] = [(case.name, [r.variant_name for r in results])] + + def _add(header: str, values: list[str]) -> None: + columns.append((header, values)) + + if has_fwd: + _add(*_unit_column("fwd", [r.fwd_timing.median_ms if r.fwd_timing else None for r in results], _TIME_UNITS)) + if verbose: + _add( + *_unit_column( + "fwd mean", [r.fwd_timing.mean_ms if r.fwd_timing else None for r in results], _TIME_UNITS + ) + ) + _add( + *_unit_column("fwd min", [r.fwd_timing.min_ms if r.fwd_timing else None for r in results], _TIME_UNITS) + ) + _add( + *_unit_column("fwd max", [r.fwd_timing.max_ms if r.fwd_timing else None for r in results], _TIME_UNITS) + ) + + if has_fwd_bwd: + # Backward-only derived: fwd+bwd − fwd. 
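+ # This derivation assumes the forward pass inside the fwd_bwd run matches the standalone fwd median; noise in either measurement can skew the derived bwd column slightly.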
+ bwd_values: list[float | None] = [] + total_values: list[float | None] = [] + for r in results: + if r.fwd_bwd_timing is None: + bwd_values.append(None) + total_values.append(None) + continue + total = r.fwd_bwd_timing.median_ms + bwd_values.append(total - r.fwd_timing.median_ms if r.fwd_timing else None) + total_values.append(total) + _add(*_unit_column("bwd", bwd_values, _TIME_UNITS)) + _add(*_unit_column("total", total_values, _TIME_UNITS)) + + def _time_for_throughput(r: VariantResult) -> float | None: + if r.fwd_bwd_timing is not None: + return r.fwd_bwd_timing.median_ms + if r.fwd_timing is not None: + return r.fwd_timing.median_ms + return None + + if case.expected_bytes is not None: + bandwidths: list[float | None] = [] + for r in results: + t_ms = _time_for_throughput(r) + bandwidths.append(case.expected_bytes / (t_ms / 1000) if t_ms is not None else None) + header, values = _unit_column("", bandwidths, _BANDWIDTH_UNITS) + _add(header, values) + if gpu_spec is not None: + peak_bytes_per_s = gpu_spec.peak_bandwidth_gbps * 1e9 + pct = [bw / peak_bytes_per_s if bw is not None else None for bw in bandwidths] + _add("%BW", _percent_column(pct)) + + if case.expected_flops is not None: + flop_rates: list[float | None] = [] + for r in results: + t_ms = _time_for_throughput(r) + flop_rates.append(case.expected_flops / (t_ms / 1000) if t_ms is not None else None) + header, values = _unit_column("", flop_rates, _FLOPS_UNITS) + _add(header, values) + peak_tflops = gpu_spec.peak_tflops(case.compute_dtype) if gpu_spec and case.compute_dtype else None + if peak_tflops is not None: + peak_flops_per_s = peak_tflops * 1e12 + pct = [fr / peak_flops_per_s if fr is not None else None for fr in flop_rates] + _add("%FLOPs", _percent_column(pct)) + + peak_mib = [r.memory.peak_mib if r.memory else None for r in results] + delta_mib = [r.memory.delta_peak_mib if r.memory else None for r in results] + _add(*_unit_column("peak", peak_mib, _MEMORY_UNITS)) + _add(*_unit_column("Δpeak", delta_mib, _MEMORY_UNITS)) + + for key in rms_keys: + _add(_rms_header(key, rms_keys), _rms_column([(r.rms_errors or {}).get(key) for r in results])) + + _add("error", [r.error or "" for r in results]) + # Drop the error column if nothing failed + if not any(r.error for r in results): + columns.pop() + + widths = [max(len(header), *(len(v) for v in values)) for header, values in columns] + separator = " " + + # First column (case name + variant names) is text — left-justify. All other + # columns are numeric — right-justify so decimal points line up across rows. 
+ def _justify(text: str, width: int, column_index: int) -> str: + return text.ljust(width) if column_index == 0 else text.rjust(width) + + header_line = separator.join(_justify(h, w, i) for i, ((h, _), w) in enumerate(zip(columns, widths))) + divider = separator.join("-" * w for w in widths) + body_lines = [] + for row in range(len(results)): + body_lines.append( + separator.join(_justify(values[row], w, i) for i, ((_, values), w) in enumerate(zip(columns, widths))) + ) + return "\n".join([header_line, divider, *body_lines]) + + +# --------------------------------------------------------------------------- orchestration + + +def run_benchmark( + benchmark_name: str, + cases: list[Case], + variants: list[Variant], + *, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, + verbose: bool = False, + print_fn: Callable[[str], None] = print, +) -> list[tuple[Case, list[VariantResult]]]: + """Run all (case, variant) combinations and print one table per case. + + Exactly one variant should have `is_reference=True` — its outputs are the + ground truth for RMS-error comparisons. That variant should compute in + fp32, eager, using the most straightforward reference implementation.""" + reference = [v for v in variants if v.is_reference] + if len(reference) != 1: + raise ValueError( + f"Expected exactly one reference variant (is_reference=True), got {len(reference)}. " + f"Variants: {[v.name for v in variants]}" + ) + gpu_spec = detect_gpu_spec() + print_fn(header(benchmark_name)) + if gpu_spec is not None: + print_fn(f"gpu: {gpu_spec.name} (peak BW {gpu_spec.peak_bandwidth_gbps:.0f} GB/s)") + else: + print_fn("gpu: unknown (no %-of-peak columns)") + print_fn("") + + all_results: list[tuple[Case, list[VariantResult]]] = [] + for case in cases: + ref_outputs = _collect_reference_outputs(reference[0], case) + + results = [] + has_fwd = False + has_fwd_bwd = False + rms_keys_seen: list[str] = [] + for variant in variants: + r = _run_one_variant(variant, case, ref_outputs, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps) + results.append(r) + has_fwd = has_fwd or r.fwd_timing is not None + has_fwd_bwd = has_fwd_bwd or r.fwd_bwd_timing is not None + for k in r.rms_errors or {}: + if k not in rms_keys_seen: + rms_keys_seen.append(k) + + print_fn( + _render_table( + case, + results, + gpu_spec=gpu_spec, + has_fwd=has_fwd, + has_fwd_bwd=has_fwd_bwd, + rms_keys=rms_keys_seen, + verbose=verbose, + ) + ) + print_fn("") + all_results.append((case, results)) + return all_results diff --git a/tools/benchmark/utils.py b/tools/benchmark/utils.py new file mode 100644 index 000000000..f18ee870a --- /dev/null +++ b/tools/benchmark/utils.py @@ -0,0 +1,193 @@ +import dataclasses +from collections.abc import Callable +from typing import Any + +import torch + +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.functional.config import TritonConfig +from tools.benchmark.runner import Inputs, Variant, run_benchmark + +DEFAULT_DTYPES: tuple[torch.dtype, ...] = (torch.bfloat16,) + + +def dtype_short(dtype: torch.dtype) -> str: + return DataType.from_torch(dtype).short + + +def bench_main(benchmarks_fn: Callable) -> Callable: + def run( + verbose: bool = False, + dtypes: tuple[torch.dtype, ...] 
| None = None, + shapes: list | None = None, + warmup_ms: float = 25.0, + rep_ms: float = 100.0, + min_reps: int = 5, + ) -> None: + for name, cases, variants in benchmarks_fn(dtypes or DEFAULT_DTYPES, shapes): + run_benchmark( + name, cases, variants, verbose=verbose, warmup_ms=warmup_ms, rep_ms=rep_ms, min_reps=min_reps + ) + + return run + + +@dataclasses.dataclass(kw_only=True) +class PytorchVariant(Variant): + """Variant that calls a pytorch function on inputs picked by key. Used for + eager, torch.compile, and apex variants — each instance differs in `function` + while sharing the dispatch logic.""" + + function: Callable + input_keys: tuple[str, ...] + grad_input_keys: tuple[str, ...] = () + grad_output_key: str | None = None + output_key: str = "output" + + def __post_init__(self) -> None: + # Wire the inherited `fwd`/`fwd_bwd` callable fields to bound methods + # so subclasses can override the methods without touching the fields. + self.fwd = self._fwd + self.fwd_bwd = self._fwd_bwd + + def _fwd(self, inputs: Inputs) -> dict: + return {self.output_key: self.function(*(inputs[k] for k in self.input_keys))} + + def _fwd_bwd(self, inputs: Inputs) -> dict: + output = self.function(*(inputs[k] for k in self.input_keys)) + if self.grad_output_key is None: + output.backward() + else: + output.backward(inputs[self.grad_output_key]) + result = {self.output_key: output.detach()} + for key in self.grad_input_keys: + result[f"grad_{key}"] = inputs[key].grad + return result + + +@dataclasses.dataclass(kw_only=True) +class Fp32ReferenceVariant(PytorchVariant): + """Reference variant: upcasts every floating-point input to fp32 before + running the eager pytorch function. Re-attaches `requires_grad=True` on + `grad_input_keys` so backward sees a leaf tensor.""" + + name: str = "fp32_reference" + is_reference: bool = True + + def _fwd(self, inputs: Inputs) -> dict: + return super()._fwd(self._to_fp32(inputs)) + + def _fwd_bwd(self, inputs: Inputs) -> dict: + return super()._fwd_bwd(self._to_fp32(inputs)) + + def _to_fp32(self, inputs: Inputs) -> Inputs: + result = dict(inputs) + for key, value in inputs.items(): + if isinstance(value, torch.Tensor) and value.is_floating_point(): + float_value = value.float().detach() + result[key] = float_value.requires_grad_(True) if key in self.grad_input_keys else float_value + return result + + +@dataclasses.dataclass(kw_only=True) +class FwdOnlyPytorchVariant(Variant): + """Forward-only variant: calls a pytorch function with positional args + extracted via `unpack`. 
Used by bench_pointwise where there's no backward.""" + + function: Callable + unpack: Callable[[Inputs], tuple] + + def __post_init__(self) -> None: + self.fwd = self._fwd + + def _fwd(self, inputs: Inputs) -> Any: + return self.function(*self.unpack(inputs)) + + +@dataclasses.dataclass(kw_only=True) +class Fp32FwdOnlyReferenceVariant(FwdOnlyPytorchVariant): + name: str = "fp32_reference" + is_reference: bool = True + + def _fwd(self, inputs: Inputs) -> Any: + args = tuple( + arg.float() if isinstance(arg, torch.Tensor) and arg.is_floating_point() else arg + for arg in self.unpack(inputs) + ) + return self.function(*args) + + +def standard_fwd_variants( + eager_function: Callable, + triton_function: Callable | None, + unpack: Callable[[Inputs], tuple], +) -> list[Variant]: + """fp32_reference, pytorch_eager, pytorch_compiled, pytorch_compiled_max, + and (if `TritonConfig.enabled()`) fast_llm_triton.""" + variants: list[Variant] = [ + Fp32FwdOnlyReferenceVariant(function=eager_function, unpack=unpack), + FwdOnlyPytorchVariant(name="pytorch_eager", function=eager_function, unpack=unpack), + FwdOnlyPytorchVariant( + name="pytorch_compiled", + function=torch.compile(eager_function, mode="default", dynamic=False), + unpack=unpack, + ), + FwdOnlyPytorchVariant( + name="pytorch_compiled_max", + function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False), + unpack=unpack, + ), + ] + if triton_function is not None and TritonConfig.enabled(): + variants.append( + FwdOnlyPytorchVariant( + name="fast_llm_triton", + function=lambda *args: triton_function(*args, use_triton=True), + unpack=unpack, + ) + ) + return variants + + +def standard_fwd_bwd_pytorch_variants( + eager_function: Callable, + input_keys: tuple[str, ...], + grad_input_keys: tuple[str, ...], + *, + grad_output_key: str | None = None, + output_key: str = "output", + reset_inputs: Callable[[Inputs], None] | None = None, + extra_functions: dict[str, Callable] | None = None, + eager_name: str = "pytorch_eager", + enable_max_autotune: bool = True, +) -> list[Variant]: + """fp32_reference + eager (named via `eager_name`) + + pytorch_compiled + [pytorch_compiled_max] + + `extra_functions`. Triton variants are appended by the caller (so each + bench file owns its triton wiring explicitly).""" + common = { + "input_keys": input_keys, + "grad_input_keys": grad_input_keys, + "grad_output_key": grad_output_key, + "output_key": output_key, + "reset_inputs": reset_inputs, + } + variants: list[Variant] = [ + Fp32ReferenceVariant(function=eager_function, **common), + PytorchVariant(name=eager_name, function=eager_function, **common), + PytorchVariant( + name="pytorch_compiled", + function=torch.compile(eager_function, mode="default", dynamic=False), + **common, + ), + ] + if enable_max_autotune: + variants.append( + PytorchVariant( + name="pytorch_compiled_max", + function=torch.compile(eager_function, mode="max-autotune-no-cudagraphs", dynamic=False), + **common, + ) + ) + for name, function in (extra_functions or {}).items(): + variants.append(PytorchVariant(name=name, function=function, **common)) + return variants diff --git a/tools/inspect_rotary_compile.py b/tools/inspect_rotary_compile.py new file mode 100644 index 000000000..eac1624a2 --- /dev/null +++ b/tools/inspect_rotary_compile.py @@ -0,0 +1,60 @@ +""" +Dump the Triton kernel that torch.compile generates for the rotary embedding, +so we can compare it to our hand-written fast_llm kernel.
+ +Run on a GPU node: + python tools/inspect_rotary_compile.py +Output lands in /tmp/torchinductor_*/ (one subdir per compiled function). +This script also prints the path and first 300 lines of each .py file found. +""" + +import os +from pathlib import Path + +import torch +import torch._inductor.config as inductor_config + +# Route torch.compile output to a known directory. +_OUT = Path("/tmp/torchinductor_rotary_inspect") +_OUT.mkdir(parents=True, exist_ok=True) +os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(_OUT) + + +inductor_config.debug = True # writes generated Triton .py files alongside the cache + +tokens, num_heads, head_size = 4096, 32, 128 +rotary_dim = head_size // 2 +dtype = torch.bfloat16 +device = "cuda" + +input_ = torch.randn(tokens, num_heads, head_size, dtype=dtype, device=device) +frequencies = torch.randn(tokens, 2 * rotary_dim, dtype=torch.float32, device=device) + + +def _rotary_eager(input_: torch.Tensor, frequencies: torch.Tensor) -> torch.Tensor: + rotary_dim = frequencies.shape[-1] // 2 + freq_re = frequencies[:, :rotary_dim].unsqueeze(1) + freq_im = frequencies[:, rotary_dim:].unsqueeze(1) + x_re, x_im = input_.chunk(2, dim=-1) + out_re = x_re * freq_re - x_im * freq_im + out_im = x_im * freq_re + x_re * freq_im + return torch.cat([out_re, out_im], dim=-1) + + +compiled = torch.compile(_rotary_eager, mode="default", dynamic=False) + +# Trigger compilation. +out = compiled(input_, frequencies) +torch.cuda.synchronize() +print(f"Output shape: {out.shape}, dtype: {out.dtype}") +print(f"\nInductor cache / debug output dir: {_OUT}") + +# Find and print the generated Triton kernel files. +for path in sorted(_OUT.rglob("*.py")): + print(f"\n{'='*80}") + print(f"FILE: {path}") + print("=" * 80) + lines = path.read_text().splitlines(keepends=True) + print("".join(lines[:300])) + if len(lines) > 300: + print(f"... ({len(lines) - 300} more lines)")
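For context, here is a minimal sketch of how a new benchmark would plug into this runner. The `AddCase` case, the `_eager_add` function, and the shapes are hypothetical, invented for illustration; only `Case`, `Inputs`, `bench_main`, `dtype_short`, and `standard_fwd_variants` come from the files above, and a real benchmark would follow the sparse_linear file as the template:

import dataclasses

import torch

from tools.benchmark.runner import Case, Inputs
from tools.benchmark.utils import bench_main, dtype_short, standard_fwd_variants


@dataclasses.dataclass(kw_only=True)
class AddCase(Case):
    # Hypothetical elementwise-add case: bandwidth-bound, forward-only.
    rows: int
    cols: int
    dtype: torch.dtype

    @property
    def name(self) -> str:
        return f"add rows={self.rows} cols={self.cols} {dtype_short(self.dtype)}"

    @property
    def expected_bytes(self) -> int:
        # Two input reads plus one output write; enables the GB/s and %BW columns.
        return 3 * self.rows * self.cols * self.dtype.itemsize

    @property
    def compute_dtype(self) -> torch.dtype:
        return self.dtype

    def make_inputs(self, device: str) -> Inputs:
        return {
            "a": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device),
            "b": torch.randn(self.rows, self.cols, dtype=self.dtype, device=device),
        }


def _eager_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


def benchmarks(dtypes: tuple[torch.dtype, ...], shapes: list[tuple[int, int]] | None = None) -> list:
    shapes = shapes if shapes is not None else [(4096, 4096)]
    # fp32_reference + eager + compiled variants; no Triton kernel for a plain add.
    variants = standard_fwd_variants(_eager_add, None, unpack=lambda inputs: (inputs["a"], inputs["b"]))
    cases = [AddCase(rows=r, cols=c, dtype=d) for d in dtypes for (r, c) in shapes]
    return [("add (illustrative)", cases, variants)]


run = bench_main(benchmarks)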