Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/kernelbench/timing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import torch
import json
import numpy as np
Expand Down Expand Up @@ -193,11 +194,31 @@ def get_timing_function(
case _:
raise ValueError(f"Unsupported timing method: {method}")

def exclude_outliers(outlier_ratio=0.1):
    """Build a decorator that trims outliers from a timing function's results.

    The decorated function is called with ``num_trials`` inflated by
    ``floor(num_trials * outlier_ratio)`` extra trials. That many samples,
    the ones farthest from the median, are then dropped, so the caller gets
    back roughly the number of samples it asked for.

    Args:
        outlier_ratio: Fraction of the requested trials to run additionally
            and then discard as outliers. Must be non-negative.

    Returns:
        A decorator for timing functions that return a sequence of numeric
        samples. When at least one outlier is dropped, the wrapper returns a
        list of floats in their original trial order; otherwise it returns
        the wrapped function's result unchanged.

    Raises:
        ValueError: If ``outlier_ratio`` is negative.
    """
    # Imported locally so this block does not depend on file-level imports
    # beyond the `math` and `np` names it already used.
    import functools
    import inspect

    if outlier_ratio < 0:
        raise ValueError(
            f"outlier_ratio must be non-negative, got {outlier_ratio}"
        )

    def exclude_k_outliers(values, k=1):
        """Drop the ``k`` samples farthest from the median, keeping order."""
        arr = np.asarray(values, dtype=float)
        if k <= 0 or arr.size == 0:
            return arr.tolist()
        deviations = np.abs(arr - np.median(arr))
        # Stable sort keeps tie-breaking deterministic across runs.
        idx_sorted = np.argsort(deviations, kind="stable")
        # Re-sort the surviving indices so samples stay in trial order.
        keep_idx = np.sort(idx_sorted[:-k])
        return arr[keep_idx].tolist()

    def decorator(timing_fn):
        try:
            signature = inspect.signature(timing_fn)
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature.
            signature = None
        has_trials_param = (
            signature is not None and "num_trials" in signature.parameters
        )

        @functools.wraps(timing_fn)
        def wrapper(*args, **kwargs):
            if has_trials_param:
                # Bind the call so `num_trials` is found whether it was given
                # positionally, by keyword, or left at the function's default.
                bound = signature.bind(*args, **kwargs)
                bound.apply_defaults()
                num_trials = bound.arguments["num_trials"]
                k_outliers = math.floor(num_trials * outlier_ratio)
                bound.arguments["num_trials"] = num_trials + k_outliers
                results = timing_fn(*bound.args, **bound.kwargs)
            else:
                # No explicit `num_trials` parameter: fall back to the
                # keyword-only protocol with a default of 10 trials.
                num_trials = kwargs.get('num_trials', 10)
                k_outliers = math.floor(num_trials * outlier_ratio)
                kwargs['num_trials'] = num_trials + k_outliers
                results = timing_fn(*args, **kwargs)
            if k_outliers > 0:
                return exclude_k_outliers(results, k=k_outliers)
            return results

        return wrapper

    return decorator

"""
Kernel Timing Functions
NOTE: we have a WIP blogpost on this topic covering the various timing approaches
"""

@exclude_outliers()
def time_execution_with_cuda_event(
kernel_fn: callable,
args: list[Any],
Expand Down Expand Up @@ -281,6 +302,7 @@ def time_execution_with_cuda_event(
return elapsed_times


@exclude_outliers()
def time_execution_with_do_bench_interface(
kernel_fn: callable,
args: list[Any],
Expand Down Expand Up @@ -330,6 +352,7 @@ def time_execution_with_do_bench_interface(
return_mode="all")


@exclude_outliers()
def time_execution_with_do_bench_impl(
kernel_fn: callable,
args: list[Any],
Expand Down Expand Up @@ -430,6 +453,7 @@ def time_execution_with_do_bench_impl(
return times


@exclude_outliers()
def time_execution_with_host_time(
kernel_fn: callable,
args: list[Any],
Expand Down Expand Up @@ -499,6 +523,7 @@ def time_execution_with_host_time(

return elapsed_times

@exclude_outliers()
def time_execution_with_nsight_python(
kernel_fn: callable,
args: list[Any],
Expand Down