15 changes: 10 additions & 5 deletions src/kernels_bench/bench.py
@@ -25,6 +25,15 @@ def _resolve_workload(value: WorkloadSize | None, params: dict[str, int]) -> int
return value(params) if callable(value) else value


def param_combinations(params: dict[str, list[int]]) -> list[dict[str, int]]:
"""Cross-product of param value lists into a list of concrete dicts."""
if not params:
return [{}]
keys = sorted(params.keys())
values = [params[k] for k in keys]
return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*values)]
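A quick sketch of the new helper's contract (illustrative only; it assumes kernels_bench.bench is importable, as the tests at the bottom of this diff do):

    from kernels_bench.bench import param_combinations

    # Keys are sorted, so combos come out in a deterministic order.
    combos = param_combinations({"M": [512, 1024], "N": [64, 128]})
    assert len(combos) == 4
    assert {"M": 512, "N": 64} in combos
    assert {"M": 1024, "N": 128} in combos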


def auto_bytes(specs: list[TensorSpec]) -> int:
"""Sum tensor bytes across specs.

@@ -105,11 +114,7 @@ def fn(self, func: Callable[..., Any]) -> Callable[..., Any]:

def _param_combinations(self) -> list[dict[str, int]]:
"""Generate all combinations of param values."""
if not self.params:
return [{}]
keys = sorted(self.params.keys())
values = [self.params[k] for k in keys]
return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*values)]
return param_combinations(self.params)

def run(
self,
134 changes: 98 additions & 36 deletions src/kernels_bench/cli.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import csv
import importlib.util
import json
import os
@@ -45,7 +46,11 @@ def _parse_arg(arg: str) -> TensorSpec:
)

name = parts[0]
shape = tuple(int(d) for d in parts[1].split(","))
# Each dim is either a literal int or a symbolic identifier (e.g. "M") that
# gets resolved per param combo when --sweep is used.
shape: tuple[int | str, ...] = tuple(
int(d) if d.lstrip("-").isdigit() else d for d in parts[1].split(",")
)
dtype = DTYPE_MAP.get(parts[2], torch.float16) if len(parts) > 2 else torch.float16
role = parts[3] if len(parts) > 3 else "input"

@@ -59,6 +64,20 @@ def _parse_arg(arg: str) -> TensorSpec:
return TensorSpec(name, shape=shape, dtype=dtype, role=role)


def _parse_sweep(s: str) -> tuple[str, list[int]]:
"""Parse a --sweep spec like 'M=512,1024,2048' into ('M', [512, 1024, 2048])."""
if "=" not in s:
raise click.ClickException(f"invalid --sweep {s!r}, expected KEY=v1,v2,...")
key, values = s.split("=", 1)
key = key.strip()
if not key.isidentifier():
raise click.ClickException(f"sweep key {key!r} must be a valid identifier")
try:
return key, [int(v) for v in values.split(",")]
except ValueError as e:
raise click.ClickException(f"sweep values for {key!r} must be ints: {e}") from e


def _load_bench_from_file(path: str) -> Bench:
"""Import a Python file and find the Bench instance in it."""
filepath = Path(path).resolve()
@@ -80,11 +99,23 @@ def _load_bench_from_file(path: str) -> Bench:
return benches[0]


def _write_csv(result: BenchResult, path: Path) -> None:
header, rows = result.to_csv_rows()
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=header)
writer.writeheader()
writer.writerows(rows)


def _handle_output(result: BenchResult, output: str | None) -> None:
"""Print results to terminal and optionally write JSON to file."""
"""Print results to terminal and optionally write to file (JSON or CSV)."""
print_results(result)
if output:
Path(output).write_text(json.dumps(result.to_dict(), indent=2))
path = Path(output)
if path.suffix.lower() == ".csv":
_write_csv(result, path)
else:
path.write_text(json.dumps(result.to_dict(), indent=2))
click.echo(f"\nResults saved to {output}")


@@ -110,7 +141,7 @@ def main() -> None:
"--output",
"-o",
default=None,
help="Write results to a JSON file.",
help="Write results to a file. JSON by default, CSV when path ends with .csv.",
)
@click.option("--validate", is_flag=True, help="Validate output correctness across kernels.")
@click.option("--atol", default=1e-3, show_default=True, help="Absolute tolerance for validation.")
@@ -188,7 +219,7 @@ def run(
"--output",
"-o",
default=None,
help="Write results to a JSON file.",
help="Write results to a file. JSON by default, CSV when path ends with .csv.",
)
@click.option("--validate", is_flag=True, help="Validate output correctness across kernels.")
@click.option("--atol", default=1e-3, show_default=True, help="Absolute tolerance for validation.")
@@ -217,6 +248,16 @@ def run(
default=None,
help="Bytes moved per kernel call. Defaults to the sum of input+output tensor sizes.",
)
@click.option(
"--sweep",
"sweeps",
multiple=True,
help=(
"Sweep a symbolic dim across values, e.g. --sweep M=512,1024,2048. "
"Repeat for a multi-dim grid. Use symbolic names in --arg shapes "
"(e.g. x:M,M:float16:input) to bind them."
),
)
def quick(
kernels: str,
fn: str,
@@ -231,6 +272,7 @@ def quick(
profile: bool,
flops: int | None,
bytes_per_iter: int | None,
sweeps: tuple[str, ...],
) -> None:
"""Benchmark a kernel function directly — no bench file needed.

@@ -243,15 +285,28 @@
"""
from kernels import get_kernel

from kernels_bench.bench import auto_bytes
from kernels_bench.bench import auto_bytes, param_combinations
from kernels_bench.progress import benchmark_progress, make_on_step
from kernels_bench.runner import KernelResult, run_benchmark_quick
from kernels_bench.runner import KernelResult, _resolve_specs, run_benchmark_quick
from kernels_bench.runtime import detect_runtime
from kernels_bench.validate import validate_quick

specs = [_parse_arg(a) for a in args]
if bytes_per_iter is None:
bytes_per_iter = auto_bytes(specs)
sweep_dict: dict[str, list[int]] = dict(_parse_sweep(s) for s in sweeps)

# Sanity-check sweep keys vs. symbolic dims used in --arg.
required = set().union(*(s.symbolic_dims for s in specs)) if specs else set()
provided = set(sweep_dict.keys())
missing = required - provided
if missing:
raise click.ClickException(
f"symbolic dims {sorted(missing)} appear in --arg but no --sweep was given for them"
)
unused = provided - required
if unused:
raise click.ClickException(f"--sweep keys {sorted(unused)} are not used in any --arg shape")

combos = param_combinations(sweep_dict)
kernel_list = [k.strip() for k in kernels.split(",")]
runtime = detect_runtime()

@@ -263,13 +318,13 @@
except Exception as e:
raise click.ClickException(f"failed to load kernel {kernel_id!r}: {e}") from e

# Validation
# Validation: resolve specs against the first combo (consistent with Bench file mode).
validation = None
if validate and len(loaded_kernels) > 1:
validation = validate_quick(
kernels=loaded_kernels,
fn_name=fn,
specs=specs,
specs=_resolve_specs(specs, combos[0]),
runtime=runtime,
atol=atol,
rtol=rtol,
@@ -278,32 +333,39 @@
all_results: list[KernelResult] = []
with benchmark_progress() as progress:
for kernel_id, kernel in loaded_kernels.items():
warmup_tid = progress.add_task(f"{kernel_id} warmup", total=warmup)
bench_tid = progress.add_task(f"{kernel_id} bench", total=iterations)
on_step = make_on_step(progress, warmup_tid, bench_tid)
times, metrics, compile_ms = run_benchmark_quick(
kernel=kernel,
fn_name=fn,
specs=specs,
warmup=warmup,
iterations=iterations,
runtime=runtime,
on_step=on_step,
collect_metrics=not no_metrics,
profile=profile,
profile_label=kernel_id,
)
all_results.append(
KernelResult(
kernel_id=kernel_id,
params={},
times_ms=times,
metrics=metrics,
compile_ms=compile_ms,
flops=flops,
bytes_per_iter=bytes_per_iter,
for params in combos:
resolved = _resolve_specs(specs, params)
params_str = ", ".join(f"{k}={v}" for k, v in sorted(params.items()))
label = f"{kernel_id}" + (f" ({params_str})" if params_str else "")

warmup_tid = progress.add_task(f"{label} warmup", total=warmup)
bench_tid = progress.add_task(f"{label} bench", total=iterations)
on_step = make_on_step(progress, warmup_tid, bench_tid)
times, metrics, compile_ms = run_benchmark_quick(
kernel=kernel,
fn_name=fn,
specs=resolved,
warmup=warmup,
iterations=iterations,
runtime=runtime,
on_step=on_step,
collect_metrics=not no_metrics,
profile=profile,
profile_label=label,
)
all_results.append(
KernelResult(
kernel_id=kernel_id,
params=params,
times_ms=times,
metrics=metrics,
compile_ms=compile_ms,
flops=flops,
bytes_per_iter=(
bytes_per_iter if bytes_per_iter is not None else auto_bytes(resolved)
),
)
)
)

result = BenchResult(
bench_name=fn,
53 changes: 53 additions & 0 deletions src/kernels_bench/display.py
@@ -179,6 +179,54 @@ def _format_metrics(m: RunMetrics) -> str | None:
return " ".join(parts)


SWEEP_SUMMARY_THRESHOLD = 4
"""Switch to compact per-kernel summary when there are this many param combos or more."""


def _print_sweep_summary(
result: BenchResult,
total_width: int,
) -> None:
"""One row per param combo per kernel — readable at a glance for sweeps.

Grouped visually by kernel so scaling across the swept dim is easy to spot.
Bar width is normalized within each kernel to highlight relative cost.
"""
by_kernel: dict[str, list[KernelResult]] = {}
for kr in result.kernel_results:
by_kernel.setdefault(kr.kernel_id, []).append(kr)

# Label column width sized to the longest params string so the bars line up.
max_params_len = max(len(_format_params(kr.params)) for kr in result.kernel_results)
label_width = max(max_params_len + 1, 12)
inner = total_width - 2
data_col_width = inner - label_width - 3
bar_width = max(data_col_width - 24, 16) # leave room for "0.000 ms " + GB/s suffix

kernels = list(by_kernel.items())
for ki, (kernel_id, rows) in enumerate(kernels):
_print_divider(total_width, "section")
_print_centered(kernel_id, total_width)
_print_row_divider(total_width, label_width, "top")

slowest = max(r.median_ms for r in rows)
for i, kr in enumerate(rows):
params_str = _format_params(kr.params) or "(no params)"
bar = _make_bar(kr.median_ms, slowest, bar_width)
median_text = f"{kr.median_ms:.3f} ms"
value = f"{median_text} {bar}"
if kr.gb_per_s is not None:
value += f" {DIM}{kr.gb_per_s:.1f} GB/s{RESET}{COLOR}"
_print_row(params_str, value, total_width, label_width)
if i < len(rows) - 1:
_print_row_divider(total_width, label_width)

# Last kernel uses "bottom" to close the outer box; intermediate
# kernels use a section-style divider so the next block flows cleanly.
is_last = ki == len(kernels) - 1
_print_row_divider(total_width, label_width, "bottom" if is_last else "mid")


def print_results(result: BenchResult) -> None:
"""Print benchmark results in hf-mem box-drawing style."""
kernel_results = result.kernel_results
@@ -254,6 +302,11 @@ def print_results(result: BenchResult) -> None:
)
_print_row_divider(total_width, label_width, "bottom")

# Compact view for large sweeps — full per-combo blocks become unreadable.
if len(param_groups) >= SWEEP_SUMMARY_THRESHOLD:
_print_sweep_summary(result, total_width)
return

for group_key, group_results in param_groups.items():
_print_divider(total_width, "section")
if group_key:
41 changes: 41 additions & 0 deletions src/kernels_bench/runner.py
@@ -114,6 +114,47 @@ def fastest(self, params: dict[str, int] | None = None) -> KernelResult:
candidates = [r for r in candidates if r.params == params]
return min(candidates, key=lambda r: r.median_ms)

def to_csv_rows(self) -> tuple[list[str], list[dict[str, Any]]]:
"""Flatten results into (header, rows) for CSV export.

Each row is one (kernel_id, params) pair. Param keys vary by run, so we
union them across all results to keep the column set stable.
"""
param_keys = sorted({k for kr in self.kernel_results for k in kr.params})
header = [
"kernel_id",
*param_keys,
"median_ms",
"p10_ms",
"p90_ms",
"iqr_ms",
"has_warnings",
"compile_ms",
"gflops_per_s",
"gb_per_s",
"peak_memory_mb",
"util_mean",
"util_peak",
]
rows: list[dict[str, Any]] = []
for kr in self.kernel_results:
row: dict[str, Any] = {"kernel_id": kr.kernel_id}
for k in param_keys:
row[k] = kr.params.get(k, "")
row["median_ms"] = kr.median_ms
row["p10_ms"] = kr.p10_ms
row["p90_ms"] = kr.p90_ms
row["iqr_ms"] = kr.iqr_ms
row["has_warnings"] = kr.has_warnings
row["compile_ms"] = kr.compile_ms
row["gflops_per_s"] = kr.gflops_per_s
row["gb_per_s"] = kr.gb_per_s
row["peak_memory_mb"] = kr.metrics.peak_memory_mb
row["util_mean"] = kr.metrics.util_mean
row["util_peak"] = kr.metrics.util_peak
rows.append(row)
return header, rows
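The column-union step is what keeps mixed sweeps writable: every row shares one header, and params a result did not sweep become empty cells. A self-contained sketch with plain dicts (not the real KernelResult objects):

    results = [
        {"kernel_id": "kernel_a", "params": {"M": 512}, "median_ms": 0.42},
        {"kernel_id": "kernel_b", "params": {"M": 1024, "N": 64}, "median_ms": 1.7},
    ]
    param_keys = sorted({k for r in results for k in r["params"]})  # ["M", "N"]
    header = ["kernel_id", *param_keys, "median_ms"]
    rows = [
        {"kernel_id": r["kernel_id"],
         **{k: r["params"].get(k, "") for k in param_keys},
         "median_ms": r["median_ms"]}
        for r in results
    ]
    # kernel_a's row gets N="", which csv.DictWriter writes out as an empty cell.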

def to_dict(self) -> dict[str, Any]:
"""Serialize results to a dict suitable for JSON export."""
return {
13 changes: 12 additions & 1 deletion tests/test_bench.py
@@ -3,7 +3,7 @@
import pytest
import torch

from kernels_bench.bench import Bench, _resolve_workload, auto_bytes
from kernels_bench.bench import Bench, _resolve_workload, auto_bytes, param_combinations
from kernels_bench.spec import TensorSpec


@@ -119,3 +119,14 @@ def test_auto_bytes_sums_specs():

def test_auto_bytes_empty():
assert auto_bytes([]) == 0


def test_param_combinations_empty():
assert param_combinations({}) == [{}]


def test_param_combinations_grid():
combos = param_combinations({"M": [1, 2], "N": [10, 20]})
assert {"M": 1, "N": 10} in combos
assert {"M": 2, "N": 20} in combos
assert len(combos) == 4