36 commits
c76ee4d  Add Triton kernel benchmark suite (jlamypoirier, Apr 25, 2026)
13678d5  Add sparse_linear benchmark and fix sparse matmul kernel correctness (jlamypoirier, Apr 28, 2026)
3a84431  Add GRPO loss and sparse copy benchmarks (jlamypoirier, Apr 28, 2026)
2c619aa  Fix sparse_copy benchmark: zero phantom rows before correctness compa… (jlamypoirier, Apr 28, 2026)
1949764  Fix sparse_copy benchmark correctness without polluting timing (jlamypoirier, Apr 28, 2026)
e3b4337  Fix benchmark timing loops: precompute backward gradients (jlamypoirier, Apr 28, 2026)
d521cf7  Fix sparse_linear backward_grad: zero phantom rows before precomputing (jlamypoirier, Apr 28, 2026)
d431d66  Restructure rotary kernel: full-head contiguous load + autotune head_… (jlamypoirier, Apr 29, 2026)
ae0ae12  Fix Triton rotary kernel: remove autotune, simplify to two-half load (jlamypoirier, Apr 29, 2026)
cc56b92  Revert "Fix Triton rotary kernel: remove autotune, simplify to two-ha… (jlamypoirier, Apr 29, 2026)
77e6b07  Revert "Restructure rotary kernel: full-head contiguous load + autotu… (jlamypoirier, Apr 29, 2026)
4e2b996  Revert "Revert "Restructure rotary kernel: full-head contiguous load … (jlamypoirier, Apr 29, 2026)
091ae5f  Revert "Revert "Fix Triton rotary kernel: remove autotune, simplify t… (jlamypoirier, Apr 29, 2026)
3b35693  Restore autotuned rotary kernel, add benchmark warmup (jlamypoirier, Apr 29, 2026)
7b6ce6f  Drop autotune from rotary kernel, add torch.compile inspection script (jlamypoirier, Apr 29, 2026)
cb3c161  Fix benchmark fairness: pre-allocate work buffer, reset between reps (jlamypoirier, Apr 29, 2026)
057efa0  Fix fwd_bwd benchmark bias: zero logits.grad between reps for PyTorch… (jlamypoirier, Apr 30, 2026)
c854b8b  Add fp32 references for sparse_copy and sparse_linear benchmarks (jlamypoirier, Apr 30, 2026)
c1b3cc3  Fix sparse warmup duplication; expose Apex availability flags publicly (jlamypoirier, Apr 30, 2026)
fdbf78b  Fix E741 ambiguous variable name; document dead GRPO masks (jlamypoirier, Apr 30, 2026)
2505fb5  Rename abbreviations and use functools.partial in benchmark suite (jlamypoirier, Apr 30, 2026)
031980c  Convert inspect_rotary_compile.py to pathlib (jlamypoirier, Apr 30, 2026)
71f1184  Add sparse linear autograd tests and benchmark smoke test (jlamypoirier, Apr 30, 2026)
a6c6e51  Remove benchmark smoke test (jlamypoirier, Apr 30, 2026)
e4eaab7  Fix two nits from rename pass (jlamypoirier, Apr 30, 2026)
2fbe934  Fix bench_fn: call reset() during warmup and calibration phases (jlamypoirier, Apr 30, 2026)
8441615  Add shapes parameter to bench run() functions and smoke tests (jlamypoirier, Apr 30, 2026)
8063175  Refactor benchmark tests: per-kernel parametrization, single timed rep (jlamypoirier, Apr 30, 2026)
df1730f  Skip compiled variants in benchmark smoke tests (jlamypoirier, Apr 30, 2026)
6ccfca3  Skip sparse benchmarks in Triton interpreter, revert broken histogram… (jlamypoirier, Apr 30, 2026)
7a668e8  Speed up benchmark smoke tests from 27s to ~4s (jlamypoirier, May 1, 2026)
e7f8f38  Merge remote-tracking branch 'origin/main' into worktree-triton_bench… (jlamypoirier, May 1, 2026)
d7d96d8  Mask padded entries in comparison instead of zeroing kernel inputs (jlamypoirier, May 1, 2026)
5da9509  Factor common variant/case scaffolding into utils helpers (jlamypoirier, May 1, 2026)
16eecb6  Trim per-file boilerplate from benchmark suite (jlamypoirier, May 1, 2026)
46d7277  Trim docstrings, inline normalization variant wrappers (jlamypoirier, May 1, 2026)
6 changes: 6 additions & 0 deletions fast_llm/engine/config_utils/data_type.py
@@ -83,6 +83,11 @@ def triton(self) -> "tl.dtype":
_set_triton_dtype_map()
return _TRITON_DTYPE_MAP[self]

@property
def short(self) -> str:
"""Abbreviated name (bf16, fp32, ...) when one is defined, else the full name."""
return _DTYPE_SHORT_NAME_MAP.get(self, self.value)


_KNOWN_DATA_TYPE_PREFIXES = {"DataType", "numpy", "np", "torch", "triton.language", "tl"}

@@ -92,6 +97,7 @@ def triton(self) -> "tl.dtype":
"fp16": DataType.float16,
"bf16": DataType.bfloat16,
}
_DTYPE_SHORT_NAME_MAP = {v: k for k, v in _DTYPE_ALT_NAME_MAP_INV.items()}

_TORCH_DTYPE_MAP: dict[DataType, "torch.dtype"] = {}
_TORCH_DTYPE_MAP_INV: dict["torch.dtype", DataType] = {}
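For reference, a minimal sketch of the new `short` property in use (my example; only the `fp16` and `bf16` aliases are visible in this hunk, so any other short name would be an assumption):

```python
from fast_llm.engine.config_utils.data_type import DataType

# Aliases come from _DTYPE_ALT_NAME_MAP_INV; dtypes without an entry fall
# back to the full enum value.
print(DataType.bfloat16.short)  # "bf16"
print(DataType.float16.short)   # "fp16"
```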
12 changes: 8 additions & 4 deletions fast_llm/functional/triton/sparse_linear.py
@@ -232,6 +232,7 @@ def output_sparse_matmul_kernel(
expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim)
sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa
if sparse_index == sparse_dim:
# Phantom block past the last expert; the caller is expected to ignore these rows.
return
col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim

@@ -248,7 +249,7 @@ def output_sparse_matmul_kernel(
for k in range(1, inner_dim // block_size_inner):
lhs_ptr += block_size_inner * lhs_stride_inner
rhs_ptr += block_size_inner * rhs_stride_inner
-out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr))
+out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32)

if accumulate:
out += tl.load(out_ptr)
@@ -355,6 +356,7 @@ def input_inner_sparse_matmul_kernel(
expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim)
sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa
if sparse_index == sparse_dim:
# Phantom block past the last expert; the caller is expected to ignore these rows.
return
inner_dense_offset = sparse_index * inner_sparse_dim
col_offset = pid_col * block_size_col
@@ -374,7 +376,7 @@ def input_inner_sparse_matmul_kernel(
for k in range(1, inner_sparse_dim // block_size_inner):
lhs_ptr += block_size_inner * lhs_stride_inner
rhs_ptr += block_size_inner * rhs_stride_inner
-out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr))
+out += tl.dot(tl.load(lhs_ptr), tl.load(rhs_ptr), out_dtype=tl.float32)

if accumulate:
out += tl.load(out_ptr)
@@ -497,13 +499,15 @@ def input_row_sparse_matmul_kernel(
out = tl.dot(
tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0),
tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0),
out_dtype=tl.float32,
)
for i in range(1, tl.cdiv(inner_end - inner_offset, block_size_inner)):
inner_range += block_size_inner
mask = (inner_begin <= inner_range) & (inner_range < inner_end)
out += tl.dot(
tl.load(lhs_ptr + inner_range[None, :] * lhs_stride_inner, mask=mask[None, :], other=0),
tl.load(rhs_ptr + inner_range[:, None] * rhs_stride_inner, mask=mask[:, None], other=0),
out_dtype=tl.float32,
)

if accumulate:
@@ -578,7 +582,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N
grad_out = grad_out.contiguous()
lhs, rhs = ctx.saved_tensors
grad_lhs = input_inner_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map)
-grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map).t()
+grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map).t()
return grad_lhs, grad_rhs, None


@@ -597,7 +601,7 @@ def backward(ctx, grad_out: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, N
grad_out = grad_out.contiguous()
lhs, rhs = ctx.saved_tensors
grad_lhs = output_sparse_matmul(grad_out, rhs.t(), ctx.sparse_map)
-grad_rhs = input_row_sparse_matmul(grad_out.t(), lhs, ctx.sparse_map)
+grad_rhs = input_row_sparse_matmul(lhs.t(), grad_out, ctx.sparse_map)
return grad_lhs, grad_rhs, None


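Both `grad_rhs` fixes implement the same dense identity: for `out = lhs @ rhs` the weight gradient is `lhs.T @ grad_out`; the swaps change which operand carries the row-sparse token dimension into `input_row_sparse_matmul`. A plain-PyTorch sanity sketch of that algebra (my example; no sparse map involved):

```python
import torch

tokens, in_features, out_features = 8, 4, 5
lhs = torch.randn(tokens, in_features, requires_grad=True)
rhs = torch.randn(in_features, out_features, requires_grad=True)
grad_out = torch.randn(tokens, out_features)

(lhs @ rhs).backward(grad_out)

# grad_rhs = lhs.T @ grad_out, equivalently (grad_out.T @ lhs).T — the operand
# order the first fixed backward uses before its trailing .t().
torch.testing.assert_close(rhs.grad, lhs.t() @ grad_out)
torch.testing.assert_close(rhs.grad, (grad_out.t() @ lhs).t())
```

The added `out_dtype=tl.float32` arguments make each `tl.dot` accumulate in fp32 explicitly rather than relying on Triton's defaults for low-precision inputs.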
20 changes: 10 additions & 10 deletions fast_llm/layers/common/normalization/normalization.py
@@ -22,16 +22,16 @@
try:
import fused_layer_norm_cuda # noqa

-_fused_normalization_available = torch.cuda.is_available()
+fused_normalization_available = torch.cuda.is_available()
except ImportError:
-_fused_normalization_available = False
+fused_normalization_available = False

try:
import fast_layer_norm # noqa

-_fast_normalization_available = torch.cuda.is_available()
+fast_normalization_available = torch.cuda.is_available()
except ImportError:
-_fast_normalization_available = False
+fast_normalization_available = False


try:
@@ -79,7 +79,7 @@ class FastLayerNorm(torch.autograd.Function):
def forward(
ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor: # noqa
-assert _fast_normalization_available
+assert fast_normalization_available
Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES)
output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps)
ctx.save_for_backward(output, weight, bias, inv_var)
@@ -110,7 +110,7 @@ class FusedLayerNorm(torch.autograd.Function):
def forward(
ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor: # noqa
-assert _fused_normalization_available
+assert fused_normalization_available
ctx.eps = eps
ctx.normalized_shape = normalized_shape
output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps)
@@ -136,7 +136,7 @@ class FusedRMSNorm(torch.autograd.Function):
def forward(
ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float
) -> torch.Tensor: # noqa
-assert _fused_normalization_available
+assert fused_normalization_available
ctx.eps = eps
ctx.normalized_shape = normalized_shape
output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps)
@@ -185,15 +185,15 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
implementation = self._config.implementation
if implementation == NormalizationImplementation.auto:
if (
-_fast_normalization_available
+fast_normalization_available
and hidden_dim.size in _PERSIST_LN_SIZES
and not self._config.zero_centered
):
implementation = NormalizationImplementation.fast
elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered:
log_main_rank("Fast layer norm unavailable, using backup triton implementation.")
implementation = NormalizationImplementation.triton
-elif _fused_normalization_available:
+elif fused_normalization_available:
log_main_rank("Fast layer norm unavailable, using backup fused implementation.")
implementation = NormalizationImplementation.fused
else:
@@ -261,7 +261,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
if implementation == NormalizationImplementation.auto:
if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered:
implementation = NormalizationImplementation.triton
-elif _fused_normalization_available:
+elif fused_normalization_available:
log_main_rank("Triton RMS norm unavailable, using fused implementation.")
implementation = NormalizationImplementation.fused
else:
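Dropping the leading underscore makes these flags part of the module's public surface, so code outside this module — the new benchmark suite, per the "expose Apex availability flags publicly" commit — can import them. A hypothetical consumer sketch (everything except the two flag names is illustrative):

```python
from fast_llm.layers.common.normalization.normalization import (
    fast_normalization_available,
    fused_normalization_available,
)

# Hypothetical: register Apex-backed benchmark variants only when the
# corresponding extension imported successfully and CUDA is available.
variants = ["triton", "pytorch"]
if fused_normalization_available:
    variants.append("apex_fused")
if fast_normalization_available:
    variants.append("apex_fast")
```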
82 changes: 81 additions & 1 deletion tests/functional/test_sparse_matmul.py
@@ -6,6 +6,8 @@

from fast_llm.functional.triton.sparse_copy import SparseMap
from fast_llm.functional.triton.sparse_linear import (
InputSparseLinear,
OutputSparseLinear,
dense_matmul,
input_inner_sparse_matmul,
input_row_sparse_matmul,
@@ -74,7 +76,24 @@ def normal(self, dim_0: int, dim_1: int, device: torch.device) -> torch.Tensor:
dense_dim=256,
sparse_dim=512,
expert_ends=(128, 256, 256, 384),
-tokens_per_expert=(52, 125, 0, 97),
+tokens_per_expert=(52, 125, 0, 97),  # expert 2 has zero real tokens
),
# Single expert — the simplest non-trivial case; also exercises the no-padding path.
_SparseTestData(
dense_dim=512,
sparse_dim=256,
expert_ends=(256,),
tokens_per_expert=(200,),
),
# Four experts, fully packed (no padding rows) — exercises the pad_begin == expert_end path.
# Expert sizes must be multiples of the largest autotune block_size_row (128); otherwise
# blocks straddle expert boundaries and the kernel's "sparse_index constant within a block"
# assumption breaks.
_SparseTestData(
dense_dim=384,
sparse_dim=128,
expert_ends=(128, 256, 384, 512),
tokens_per_expert=(128, 128, 128, 128),
),
)

@@ -151,3 +170,64 @@ def test_input_row_sparse_matmul(sparse_test_data, testing_device):
)

Assert.rms_close(output, output_ref, 1e-3)


# --------------------------------------------------------------------------- autograd wrappers


def _sparse_linear_ref(lhs: torch.Tensor, rhs: torch.Tensor, data: _SparseTestData, expert_axis: int) -> torch.Tensor:
"""Per-expert matmul reference; rows past `expert_pad_begins` are zero in the output."""
rhs_per_expert = rhs.chunk(data.num_experts, dim=expert_axis)
out = lhs.new_zeros(data.token_dim, rhs_per_expert[0].shape[1])
for i, (begin, end) in enumerate(zip(data.expert_begins, data.expert_pad_begins, strict=True)):
out[begin:end] = lhs[begin:end] @ rhs_per_expert[i]
return out


def _zero_padded_rows(tensor: torch.Tensor, data: _SparseTestData) -> torch.Tensor:
# The autograd kernels treat padded tokens as regular ones; forward output and grad_lhs
# contain matmul-of-random garbage in [pad_begin, expert_end). Zero those rows so the
# comparison vs the reference (which produces zeros there) reflects only real-row error.
masked = tensor.clone()
for begin, end in zip(data.expert_pad_begins, data.expert_ends, strict=True):
if end > begin:
masked[begin:end] = 0
return masked


@requires_triton
@pytest.mark.slow
@pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS)
@pytest.mark.parametrize(
"autograd_class,expert_axis",
[(OutputSparseLinear, 1), (InputSparseLinear, 0)],
ids=["output_sparse", "input_sparse"],
)
def test_sparse_linear_autograd(sparse_test_data, testing_device, autograd_class, expert_axis):
# `expert_axis` is the rhs axis split per expert. Matmul contracts rhs's axis 0, so
# OutputSparseLinear (expert_axis=1) splits the output dim, InputSparseLinear (expert_axis=0)
# splits the contracting dim.
if expert_axis == 1:
lhs_features, out_features = sparse_test_data.dense_dim, sparse_test_data.sparse_dim
rhs_shape = (sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded)
else:
lhs_features, out_features = sparse_test_data.sparse_dim, sparse_test_data.dense_dim
rhs_shape = (sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim)

lhs = sparse_test_data.normal(sparse_test_data.token_dim, lhs_features, testing_device)
rhs = sparse_test_data.normal(*rhs_shape, testing_device)
grad_output = sparse_test_data.normal(sparse_test_data.token_dim, out_features, testing_device)

lhs_ref = lhs.detach().requires_grad_(True)
rhs_ref = rhs.detach().requires_grad_(True)
out_ref = _sparse_linear_ref(lhs_ref, rhs_ref, sparse_test_data, expert_axis)
out_ref.backward(grad_output)

lhs_t = lhs.detach().requires_grad_(True)
rhs_t = rhs.detach().requires_grad_(True)
out_t = autograd_class.apply(lhs_t, rhs_t, sparse_test_data.get_sparse_map(testing_device))
out_t.backward(grad_output)

Assert.rms_close(_zero_padded_rows(out_t, sparse_test_data), out_ref, 1e-3)
Assert.rms_close(_zero_padded_rows(lhs_t.grad, sparse_test_data), lhs_ref.grad, 1e-3)
Assert.rms_close(rhs_t.grad, rhs_ref.grad, 1e-3)
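To make the padding bookkeeping concrete: expert `i` owns rows `[expert_begins[i], expert_ends[i])`, of which only the first `tokens_per_expert[i]` are real. A small worked sketch for the first test datum (the attribute arithmetic is inferred from `_sparse_linear_ref` and `_zero_padded_rows`, not copied from the fixture):

```python
expert_ends = (128, 256, 256, 384)
tokens_per_expert = (52, 125, 0, 97)

expert_begins = (0,) + expert_ends[:-1]  # (0, 128, 256, 256)
expert_pad_begins = tuple(
    begin + tokens for begin, tokens in zip(expert_begins, tokens_per_expert)
)  # (52, 253, 256, 353)

# Rows [52, 128), [253, 256), and [353, 384) are padding: _sparse_linear_ref
# leaves them zero, and _zero_padded_rows zeroes them in the kernel outputs
# before comparison, so Assert.rms_close only measures real-row error.
```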
94 changes: 94 additions & 0 deletions tests/tools/test_triton_benchmark.py
@@ -0,0 +1,94 @@
"""
Smoke tests for all benchmark modules.

One test per sub-benchmark (kernel): inputs are tiny so the runner code path is
exercised quickly without requiring a full benchmark run.

Patches applied to keep each test under ~100 ms:
- torch.compile disabled (avoids JIT cold-start).
- fast_llm_triton variants replaced with fp32 reference (no Triton compilation;
kernel correctness is covered by the main test suite).
- TritonConfig.enabled → False (prevents make_inputs warmup in sparse_linear).
- _cudagraph_mark_step_begin → None and synchronize → no-op (both trigger C-level
  CUDA syncs on every fn() call, which would otherwise dominate the wall time).
"""

import dataclasses

import pytest
import torch

import tools.benchmark.runner as _bench_runner
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.triton import triton_interpret
from tools.benchmark import (
bench_entropy_loss,
bench_grpo_loss,
bench_mlp_activation,
bench_normalization,
bench_pointwise,
bench_rotary,
bench_sparse_copy,
bench_sparse_linear,
)
from tools.benchmark.runner import run_benchmark

_DTYPES = (torch.float32,)

# sparse_copy and sparse_linear use tl.histogram, which has unfixed bugs in the
# Triton interpreter. Skip them in interpreter mode; they're covered on GPU.
_INTERPRETER_SKIP = {
"sparse_copy: dispatch",
"sparse_copy: combine",
"sparse_linear: output_sparse (layer 1 / up-proj)",
"sparse_linear: input_inner_sparse (layer 2 / down-proj)",
}

_SKIP_VARIANTS = {"pytorch_compiled", "pytorch_compiled_max"}


def _build_params() -> list:
modules_and_shapes = [
(bench_entropy_loss, {"shapes": [(64, 256)]}),
(bench_grpo_loss, {"shapes": [(64, 256)]}),
(bench_mlp_activation, {"shapes": [(64, 128)]}),
(bench_normalization, {"shapes": [(64, 128)]}),
(bench_pointwise, {"shapes": [1024]}),
(bench_rotary, {"shapes": [(64, 4, 64)]}),
(bench_sparse_copy, {"shapes": [(64, 2, 4, 128)]}),
(bench_sparse_linear, {"shapes": [(64, 2, 4, 256, 256)]}),
]
params = []
for module, kwargs in modules_and_shapes:
for name, cases, variants in module.benchmarks(dtypes=_DTYPES, **kwargs):
params.append(pytest.param(name, cases, variants, id=name))
return params


_PARAMS = _build_params()


@pytest.fixture(autouse=True)
def _patch_benchmark_env(monkeypatch):
import torch._dynamo

monkeypatch.setattr(torch._dynamo.config, "disable", True)
monkeypatch.setattr(TritonConfig, "enabled", lambda *a, **kw: False)
monkeypatch.setattr(_bench_runner, "_cudagraph_mark_step_begin", None)
monkeypatch.setattr(torch.cuda, "synchronize", lambda: None)


@pytest.mark.parametrize("name,cases,variants", _PARAMS)
def test_triton_benchmark(name, cases, variants):
if triton_interpret and name in _INTERPRETER_SKIP:
pytest.skip("tl.histogram is broken in the Triton interpreter")

# Replace fast_llm_triton fwd/fwd_bwd with the fp32 reference so no Triton
# kernels are compiled. The runner still exercises the full variant code path.
variants = [v for v in variants if v.name not in _SKIP_VARIANTS]
ref = next(v for v in variants if v.is_reference)
variants = [
dataclasses.replace(v, fwd=ref.fwd, fwd_bwd=ref.fwd_bwd) if v.name == "fast_llm_triton" else v
for v in variants
]
run_benchmark(name, cases, variants, warmup_ms=0, rep_ms=0, min_reps=1)
Empty file added tools/__init__.py
Empty file added tools/benchmark/__init__.py