diff --git a/src/kernels_bench/bench.py b/src/kernels_bench/bench.py index bb12947..f8264c2 100644 --- a/src/kernels_bench/bench.py +++ b/src/kernels_bench/bench.py @@ -25,6 +25,15 @@ def _resolve_workload(value: WorkloadSize | None, params: dict[str, int]) -> int return value(params) if callable(value) else value +def param_combinations(params: dict[str, list[int]]) -> list[dict[str, int]]: + """Cross-product of param value lists into a list of concrete dicts.""" + if not params: + return [{}] + keys = sorted(params.keys()) + values = [params[k] for k in keys] + return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*values)] + + def auto_bytes(specs: list[TensorSpec]) -> int: """Sum tensor bytes across specs. @@ -105,11 +114,7 @@ def fn(self, func: Callable[..., Any]) -> Callable[..., Any]: def _param_combinations(self) -> list[dict[str, int]]: """Generate all combinations of param values.""" - if not self.params: - return [{}] - keys = sorted(self.params.keys()) - values = [self.params[k] for k in keys] - return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*values)] + return param_combinations(self.params) def run( self, diff --git a/src/kernels_bench/cli.py b/src/kernels_bench/cli.py index a6faf18..19cb75a 100644 --- a/src/kernels_bench/cli.py +++ b/src/kernels_bench/cli.py @@ -2,6 +2,7 @@ from __future__ import annotations +import csv import importlib.util import json import os @@ -45,7 +46,11 @@ def _parse_arg(arg: str) -> TensorSpec: ) name = parts[0] - shape = tuple(int(d) for d in parts[1].split(",")) + # Each dim is either a literal int or a symbolic identifier (e.g. "M") that + # gets resolved per param combo when --sweep is used. + shape: tuple[int | str, ...] = tuple( + int(d) if d.lstrip("-").isdigit() else d for d in parts[1].split(",") + ) dtype = DTYPE_MAP.get(parts[2], torch.float16) if len(parts) > 2 else torch.float16 role = parts[3] if len(parts) > 3 else "input" @@ -59,6 +64,20 @@ def _parse_arg(arg: str) -> TensorSpec: return TensorSpec(name, shape=shape, dtype=dtype, role=role) +def _parse_sweep(s: str) -> tuple[str, list[int]]: + """Parse a --sweep spec like 'M=512,1024,2048' into ('M', [512, 1024, 2048]).""" + if "=" not in s: + raise click.ClickException(f"invalid --sweep {s!r}, expected KEY=v1,v2,...") + key, values = s.split("=", 1) + key = key.strip() + if not key.isidentifier(): + raise click.ClickException(f"sweep key {key!r} must be a valid identifier") + try: + return key, [int(v) for v in values.split(",")] + except ValueError as e: + raise click.ClickException(f"sweep values for {key!r} must be ints: {e}") from e + + def _load_bench_from_file(path: str) -> Bench: """Import a Python file and find the Bench instance in it.""" filepath = Path(path).resolve() @@ -80,11 +99,23 @@ def _load_bench_from_file(path: str) -> Bench: return benches[0] +def _write_csv(result: BenchResult, path: Path) -> None: + header, rows = result.to_csv_rows() + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + writer.writerows(rows) + + def _handle_output(result: BenchResult, output: str | None) -> None: - """Print results to terminal and optionally write JSON to file.""" + """Print results to terminal and optionally write to file (JSON or CSV).""" print_results(result) if output: - Path(output).write_text(json.dumps(result.to_dict(), indent=2)) + path = Path(output) + if path.suffix.lower() == ".csv": + _write_csv(result, path) + else: + 
path.write_text(json.dumps(result.to_dict(), indent=2)) click.echo(f"\nResults saved to {output}") @@ -110,7 +141,7 @@ def main() -> None: "--output", "-o", default=None, - help="Write results to a JSON file.", + help="Write results to a file. JSON by default, CSV when path ends with .csv.", ) @click.option("--validate", is_flag=True, help="Validate output correctness across kernels.") @click.option("--atol", default=1e-3, show_default=True, help="Absolute tolerance for validation.") @@ -188,7 +219,7 @@ def run( "--output", "-o", default=None, - help="Write results to a JSON file.", + help="Write results to a file. JSON by default, CSV when path ends with .csv.", ) @click.option("--validate", is_flag=True, help="Validate output correctness across kernels.") @click.option("--atol", default=1e-3, show_default=True, help="Absolute tolerance for validation.") @@ -217,6 +248,16 @@ def run( default=None, help="Bytes moved per kernel call. Defaults to the sum of input+output tensor sizes.", ) +@click.option( + "--sweep", + "sweeps", + multiple=True, + help=( + "Sweep a symbolic dim across values, e.g. --sweep M=512,1024,2048. " + "Repeat for a multi-dim grid. Use symbolic names in --arg shapes " + "(e.g. x:M,M:float16:input) to bind them." + ), +) def quick( kernels: str, fn: str, @@ -231,6 +272,7 @@ def quick( profile: bool, flops: int | None, bytes_per_iter: int | None, + sweeps: tuple[str, ...], ) -> None: """Benchmark a kernel function directly — no bench file needed. @@ -243,15 +285,28 @@ def quick( """ from kernels import get_kernel - from kernels_bench.bench import auto_bytes + from kernels_bench.bench import auto_bytes, param_combinations from kernels_bench.progress import benchmark_progress, make_on_step - from kernels_bench.runner import KernelResult, run_benchmark_quick + from kernels_bench.runner import KernelResult, _resolve_specs, run_benchmark_quick from kernels_bench.runtime import detect_runtime from kernels_bench.validate import validate_quick specs = [_parse_arg(a) for a in args] - if bytes_per_iter is None: - bytes_per_iter = auto_bytes(specs) + sweep_dict: dict[str, list[int]] = dict(_parse_sweep(s) for s in sweeps) + + # Sanity-check sweep keys vs. symbolic dims used in --arg. + required = set().union(*(s.symbolic_dims for s in specs)) if specs else set() + provided = set(sweep_dict.keys()) + missing = required - provided + if missing: + raise click.ClickException( + f"symbolic dims {sorted(missing)} appear in --arg but no --sweep was given for them" + ) + unused = provided - required + if unused: + raise click.ClickException(f"--sweep keys {sorted(unused)} are not used in any --arg shape") + + combos = param_combinations(sweep_dict) kernel_list = [k.strip() for k in kernels.split(",")] runtime = detect_runtime() @@ -263,13 +318,13 @@ def quick( except Exception as e: raise click.ClickException(f"failed to load kernel {kernel_id!r}: {e}") from e - # Validation + # Validation: resolve specs against the first combo (consistent with Bench file mode). 
validation = None if validate and len(loaded_kernels) > 1: validation = validate_quick( kernels=loaded_kernels, fn_name=fn, - specs=specs, + specs=_resolve_specs(specs, combos[0]), runtime=runtime, atol=atol, rtol=rtol, @@ -278,32 +333,39 @@ def quick( all_results: list[KernelResult] = [] with benchmark_progress() as progress: for kernel_id, kernel in loaded_kernels.items(): - warmup_tid = progress.add_task(f"{kernel_id} warmup", total=warmup) - bench_tid = progress.add_task(f"{kernel_id} bench", total=iterations) - on_step = make_on_step(progress, warmup_tid, bench_tid) - times, metrics, compile_ms = run_benchmark_quick( - kernel=kernel, - fn_name=fn, - specs=specs, - warmup=warmup, - iterations=iterations, - runtime=runtime, - on_step=on_step, - collect_metrics=not no_metrics, - profile=profile, - profile_label=kernel_id, - ) - all_results.append( - KernelResult( - kernel_id=kernel_id, - params={}, - times_ms=times, - metrics=metrics, - compile_ms=compile_ms, - flops=flops, - bytes_per_iter=bytes_per_iter, + for params in combos: + resolved = _resolve_specs(specs, params) + params_str = ", ".join(f"{k}={v}" for k, v in sorted(params.items())) + label = f"{kernel_id}" + (f" ({params_str})" if params_str else "") + + warmup_tid = progress.add_task(f"{label} warmup", total=warmup) + bench_tid = progress.add_task(f"{label} bench", total=iterations) + on_step = make_on_step(progress, warmup_tid, bench_tid) + times, metrics, compile_ms = run_benchmark_quick( + kernel=kernel, + fn_name=fn, + specs=resolved, + warmup=warmup, + iterations=iterations, + runtime=runtime, + on_step=on_step, + collect_metrics=not no_metrics, + profile=profile, + profile_label=label, + ) + all_results.append( + KernelResult( + kernel_id=kernel_id, + params=params, + times_ms=times, + metrics=metrics, + compile_ms=compile_ms, + flops=flops, + bytes_per_iter=( + bytes_per_iter if bytes_per_iter is not None else auto_bytes(resolved) + ), + ) ) - ) result = BenchResult( bench_name=fn, diff --git a/src/kernels_bench/display.py b/src/kernels_bench/display.py index 8083f91..b89b08d 100644 --- a/src/kernels_bench/display.py +++ b/src/kernels_bench/display.py @@ -179,6 +179,54 @@ def _format_metrics(m: RunMetrics) -> str | None: return " ".join(parts) +SWEEP_SUMMARY_THRESHOLD = 4 +"""Switch to compact per-kernel summary when there are this many param combos or more.""" + + +def _print_sweep_summary( + result: BenchResult, + total_width: int, +) -> None: + """One row per param combo per kernel — readable at a glance for sweeps. + + Grouped visually by kernel so scaling across the swept dim is easy to spot. + Bar width is normalized within each kernel to highlight relative cost. + """ + by_kernel: dict[str, list[KernelResult]] = {} + for kr in result.kernel_results: + by_kernel.setdefault(kr.kernel_id, []).append(kr) + + # Label column width sized to the longest params string so the bars line up. 
+ max_params_len = max(len(_format_params(kr.params)) for kr in result.kernel_results) + label_width = max(max_params_len + 1, 12) + inner = total_width - 2 + data_col_width = inner - label_width - 3 + bar_width = max(data_col_width - 24, 16) # leave room for "0.000 ms " + GB/s suffix + + kernels = list(by_kernel.items()) + for ki, (kernel_id, rows) in enumerate(kernels): + _print_divider(total_width, "section") + _print_centered(kernel_id, total_width) + _print_row_divider(total_width, label_width, "top") + + slowest = max(r.median_ms for r in rows) + for i, kr in enumerate(rows): + params_str = _format_params(kr.params) or "(no params)" + bar = _make_bar(kr.median_ms, slowest, bar_width) + median_text = f"{kr.median_ms:.3f} ms" + value = f"{median_text} {bar}" + if kr.gb_per_s is not None: + value += f" {DIM}{kr.gb_per_s:.1f} GB/s{RESET}{COLOR}" + _print_row(params_str, value, total_width, label_width) + if i < len(rows) - 1: + _print_row_divider(total_width, label_width) + + # Last kernel uses "bottom" to close the outer box; intermediate + # kernels use a section-style divider so the next block flows cleanly. + is_last = ki == len(kernels) - 1 + _print_row_divider(total_width, label_width, "bottom" if is_last else "mid") + + def print_results(result: BenchResult) -> None: """Print benchmark results in hf-mem box-drawing style.""" kernel_results = result.kernel_results @@ -254,6 +302,11 @@ def print_results(result: BenchResult) -> None: ) _print_row_divider(total_width, label_width, "bottom") + # Compact view for large sweeps — full per-combo blocks become unreadable. + if len(param_groups) >= SWEEP_SUMMARY_THRESHOLD: + _print_sweep_summary(result, total_width) + return + for group_key, group_results in param_groups.items(): _print_divider(total_width, "section") if group_key: diff --git a/src/kernels_bench/runner.py b/src/kernels_bench/runner.py index da0784d..5421032 100644 --- a/src/kernels_bench/runner.py +++ b/src/kernels_bench/runner.py @@ -114,6 +114,47 @@ def fastest(self, params: dict[str, int] | None = None) -> KernelResult: candidates = [r for r in candidates if r.params == params] return min(candidates, key=lambda r: r.median_ms) + def to_csv_rows(self) -> tuple[list[str], list[dict[str, Any]]]: + """Flatten results into (header, rows) for CSV export. + + Each row is one (kernel_id, params) pair. Param keys vary by run, so we + union them across all results to keep the column set stable. 
+ """ + param_keys = sorted({k for kr in self.kernel_results for k in kr.params}) + header = [ + "kernel_id", + *param_keys, + "median_ms", + "p10_ms", + "p90_ms", + "iqr_ms", + "has_warnings", + "compile_ms", + "gflops_per_s", + "gb_per_s", + "peak_memory_mb", + "util_mean", + "util_peak", + ] + rows: list[dict[str, Any]] = [] + for kr in self.kernel_results: + row: dict[str, Any] = {"kernel_id": kr.kernel_id} + for k in param_keys: + row[k] = kr.params.get(k, "") + row["median_ms"] = kr.median_ms + row["p10_ms"] = kr.p10_ms + row["p90_ms"] = kr.p90_ms + row["iqr_ms"] = kr.iqr_ms + row["has_warnings"] = kr.has_warnings + row["compile_ms"] = kr.compile_ms + row["gflops_per_s"] = kr.gflops_per_s + row["gb_per_s"] = kr.gb_per_s + row["peak_memory_mb"] = kr.metrics.peak_memory_mb + row["util_mean"] = kr.metrics.util_mean + row["util_peak"] = kr.metrics.util_peak + rows.append(row) + return header, rows + def to_dict(self) -> dict[str, Any]: """Serialize results to a dict suitable for JSON export.""" return { diff --git a/tests/test_bench.py b/tests/test_bench.py index 0637b0e..557e6c2 100644 --- a/tests/test_bench.py +++ b/tests/test_bench.py @@ -3,7 +3,7 @@ import pytest import torch -from kernels_bench.bench import Bench, _resolve_workload, auto_bytes +from kernels_bench.bench import Bench, _resolve_workload, auto_bytes, param_combinations from kernels_bench.spec import TensorSpec @@ -119,3 +119,14 @@ def test_auto_bytes_sums_specs(): def test_auto_bytes_empty(): assert auto_bytes([]) == 0 + + +def test_param_combinations_empty(): + assert param_combinations({}) == [{}] + + +def test_param_combinations_grid(): + combos = param_combinations({"M": [1, 2], "N": [10, 20]}) + assert {"M": 1, "N": 10} in combos + assert {"M": 2, "N": 20} in combos + assert len(combos) == 4 diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..84a2a72 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,52 @@ +"""Tests for CLI argument parsers.""" + +import click +import pytest +import torch + +from kernels_bench.cli import _parse_arg, _parse_sweep + + +def test_parse_arg_concrete_dims(): + spec = _parse_arg("x:1024,512:float16:input") + assert spec.name == "x" + assert spec.shape == (1024, 512) + assert spec.dtype is torch.float16 + assert spec.role == "input" + + +def test_parse_arg_symbolic_dims(): + spec = _parse_arg("x:M,N:float16:input") + assert spec.shape == ("M", "N") + assert spec.symbolic_dims == {"M", "N"} + + +def test_parse_arg_mixed_dims(): + spec = _parse_arg("y:M,128:float16:output") + assert spec.shape == ("M", 128) + + +def test_parse_sweep_basic(): + key, values = _parse_sweep("M=512,1024,2048") + assert key == "M" + assert values == [512, 1024, 2048] + + +def test_parse_sweep_strips_key(): + key, _ = _parse_sweep(" N =64,128") + assert key == "N" + + +def test_parse_sweep_missing_equals(): + with pytest.raises(click.ClickException, match="expected KEY="): + _parse_sweep("M:512,1024") + + +def test_parse_sweep_non_int_values(): + with pytest.raises(click.ClickException, match="must be ints"): + _parse_sweep("M=512,foo") + + +def test_parse_sweep_invalid_key(): + with pytest.raises(click.ClickException, match="valid identifier"): + _parse_sweep("1bad=1,2") diff --git a/tests/test_display.py b/tests/test_display.py index 23a157e..3b7c9fc 100644 --- a/tests/test_display.py +++ b/tests/test_display.py @@ -175,3 +175,37 @@ def test_format_throughput_combined(): bytes_per_iter=8 * 10**8, ) assert _format_throughput(kr) == "2.00 TFLOP/s 800.0 GB/s" + + +def 
test_print_results_uses_summary_view_for_large_sweeps(capsys): + """At/above the threshold, the per-combo block stats lines disappear.""" + from kernels_bench.display import SWEEP_SUMMARY_THRESHOLD, print_results + from kernels_bench.runner import BenchResult + + rows = [ + KernelResult(kernel_id="a", params={"M": 64 * (i + 1)}, times_ms=[1.0 + i]) + for i in range(SWEEP_SUMMARY_THRESHOLD) + ] + print_results(BenchResult(bench_name="b", kernel_results=rows)) + out = capsys.readouterr().out + + # Compact view drops the verbose stats / "PARAMS:" headers. + assert "PARAMS:" not in out + assert "p10=" not in out + # Each combo still shown as a row. + for kr in rows: + assert f"M={kr.params['M']}" in out + + +def test_print_results_keeps_full_blocks_below_threshold(capsys): + from kernels_bench.display import SWEEP_SUMMARY_THRESHOLD, print_results + from kernels_bench.runner import BenchResult + + rows = [ + KernelResult(kernel_id="a", params={"M": 64 * (i + 1)}, times_ms=[1.0 + i]) + for i in range(SWEEP_SUMMARY_THRESHOLD - 1) + ] + print_results(BenchResult(bench_name="b", kernel_results=rows)) + out = capsys.readouterr().out + assert "PARAMS:" in out + assert "p10=" in out diff --git a/tests/test_runner.py b/tests/test_runner.py index 69f4dce..b02336c 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -267,3 +267,38 @@ def test_bench_result_to_dict_compile_ms_none_when_missing(): kr = KernelResult(kernel_id="k", params={}, times_ms=[1.0]) d = BenchResult(bench_name="b", kernel_results=[kr]).to_dict() assert d["results"][0]["compile_ms"] is None + + +def test_to_csv_rows_columns_and_param_union(): + """Param keys are unioned across results so the column set is stable.""" + a = KernelResult(kernel_id="a", params={"M": 1, "N": 2}, times_ms=[1.0]) + b = KernelResult(kernel_id="b", params={"M": 3}, times_ms=[2.0]) # missing N + header, rows = BenchResult(bench_name="x", kernel_results=[a, b]).to_csv_rows() + + assert header[0] == "kernel_id" + # M and N appear after kernel_id, sorted, before metrics columns + assert header[1:3] == ["M", "N"] + assert "median_ms" in header and "gb_per_s" in header + assert rows[0]["M"] == 1 + assert rows[0]["N"] == 2 + assert rows[1]["M"] == 3 + assert rows[1]["N"] == "" # absent param renders as empty + + +def test_to_csv_rows_includes_throughput_and_metrics(): + metrics = RunMetrics(peak_memory_mb=8.0, util_mean=50.0, util_peak=99.0, util_samples=10) + kr = KernelResult( + kernel_id="k", + params={}, + times_ms=[1.0], + metrics=metrics, + flops=10**9, + bytes_per_iter=10**9, + ) + _, rows = BenchResult(bench_name="x", kernel_results=[kr]).to_csv_rows() + row = rows[0] + assert row["gflops_per_s"] == 1000.0 + assert row["gb_per_s"] == 1000.0 + assert row["peak_memory_mb"] == 8.0 + assert row["util_mean"] == 50.0 + assert row["util_peak"] == 99.0
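
Usage note (illustrative, not part of the patch): the new pieces also compose outside the CLI. A minimal sketch, assuming only the constructors exercised in the tests above; the kernel id, timings, and output path are made up:

    import csv
    from pathlib import Path

    from kernels_bench.bench import param_combinations
    from kernels_bench.runner import BenchResult, KernelResult

    # Expand a sweep grid the same way `quick --sweep M=... --sweep N=...` does.
    combos = param_combinations({"M": [512, 1024], "N": [64, 128]})
    # Keys are sorted, so this yields:
    # [{"M": 512, "N": 64}, {"M": 512, "N": 128}, {"M": 1024, "N": 64}, {"M": 1024, "N": 128}]

    # Fake per-combo timings (illustrative values only).
    results = [
        KernelResult(kernel_id="demo_kernel", params=p, times_ms=[1.0 + i])
        for i, p in enumerate(combos)
    ]
    bench = BenchResult(bench_name="demo", kernel_results=results)

    # Same layout the CLI writes when --output ends in .csv: kernel_id first,
    # then the unioned param columns (M, N), then timing/throughput columns.
    header, rows = bench.to_csv_rows()
    with Path("demo.csv").open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=header)
        writer.writeheader()
        writer.writerows(rows)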