Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions EVAL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Evaluation
[WIP] More notes on Benchmarking Guide
To be updated more comprehensively with the benchmarking guide (ongoing PRs) & blog that we have been working on this quarter.

You should be **extra CAREFUL!** , be always paranoid about suspiciously good results — kernel engineers and existing compilers are already pretty good, so a >2x speedup for anything is highly unlikely.


> “if you beat cudnn by more than 10%, think again” --
from [itsclivetime](https://x.com/itsclivetime/status/1992155951630307633?s=46)


If the model can reward hack, it will find ways to reward hack! This can especially happen during RL training or evolutionary search.

Check out resources here:
- KernelBench [v0.1 Release](https://scalingintelligence.stanford.edu/blogs/kernelbenchv01/)
- Cognition and Stanford's [Kevin](https://arxiv.org/abs/2507.11948) project on various hacking behaviors observed in RL training
- Jiwei Li's awesome [blogpost](https://deep-reinforce.com/defense_kernel_hack.html) on Hacks and Defenses in Automatic GPU Kernel Generation

Our ongoing blogpost and PRs try to systematize and list out these behaviors and provide tests, detection, and mitigation toolings.


## Methodology
More on that coming.

To ensure **consistency and reproducibility**, we recommend using `modal` and we have provided / are adding more various modal cloud functions to standardize the evaluation environment.

### Correctness
More coming. We also want to highlight community effort such as [BackendBench](https://www.youtube.com/watch?v=BTfjdyZOKww).

### Performance
We highly recommend watching this [lecture](https://www.youtube.com/watch?v=1i7dxoAfKOU) from GPU mode on kernel profiling.

We have (and continue to) implement various approaches to conduct kernel timing to understand the tradeoffs.

Check out `timing.py` to see available timing methods and `src/unit_tests/test_eval_timing.py` to test out various timing methods (including leveraging `cuda_event` marker, Triton `do_bench`, `host_time` E2E time). @palic and team is working on a blogpost explaining the different tradeoffs soon.

### Unit Tests with Adversarial Examples
We've included some unit tests for the eval script in `src/unit_tests/test_eval_adversarial.py`. These tests run adversarial kernels (see `src/unit_tests/test_kernels/`) that contain examples of reward hacking that we've seen from LLMs and ensures that the eval script catches them, either by failing their correctness checks or flagging them for excessive speedups. Examples include:
- Reusing computations cached during the PyTorch reference
- Modifying inputs to cheat correctness checks
- Moving computation to a non-default CUDA stream

We will continue to add more tests as we explore additional adversarial scenarios.


Note: KernelBench is an ongoing open-source effort — please help us with issues and PRs!


Shoutout to @bkal01, @palic, @miru_why, @ngc92, @itsclivetime, for their suggestions and feedback.
70 changes: 68 additions & 2 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,20 @@ class KernelExecResult(BaseModel):
"""
Single Kernel Execution
"""

# Execution
compiled: bool = False
correctness: bool = False
metadata: dict = {}
metadata: dict = {} # NOTE: to include warning if any

# Timing
runtime: float = -1.0 # in us, only recorded if we decide to measure performance
runtime_stats: dict = {} # only recorded if we decide to measure performance

# new: added ref time either through fetching prev runs or through execution
# could do eager for level 1 and compile for level 2 and 3
ref_runtime: float = -1.0 # in us, only recorded if we decide to measure performance
ref_runtime_stats: dict = {} # only recorded if we decide to measure performance


def load_original_model_and_inputs(
model_original_src: str, context: dict
Expand Down Expand Up @@ -402,6 +409,10 @@ def eval_kernel_against_ref(
), # have to run on GPU
backend: str = "cuda", # can be 'cuda', 'triton', 'tilelang', or 'cute'
precision: torch.dtype = torch.float32,

# Guard against potential reward hacking [optional but ongoing enhancement]
check_for_excessive_speedup: bool = True,
excessive_speedup_threshold: float = 10, # flag if the kernel is more than <excessive_speedup_threshold>x faster than the reference
) -> KernelExecResult:
"""
Evaluate the custom kernel against the original model
Expand All @@ -415,6 +426,8 @@ def eval_kernel_against_ref(
backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute'
precision: torch.dtype for computation (note: tilelang only supports fp16)
timing_method: str, method to time kernel, see timing.py for more details

ONGOING EFFORT to refactor and modularize this, and adding more tests for eval.
"""
# TODO: check device is busy
assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval"
Expand Down Expand Up @@ -598,11 +611,64 @@ def eval_kernel_against_ref(
print(f"[Eval] Performance Stats: {runtime_stats}")
kernel_exec_result.runtime = runtime_stats["mean"]
kernel_exec_result.runtime_stats = runtime_stats

except Exception as e:
if verbose:
print(f"[Eval] Error in Measuring Performance: {e}")
kernel_exec_result.metadata["error_during_performance"] = e



###############################################################
# [Experimental] to be modularized
# Condition: custom kernel ModelNew is correct and we are able to time it correctly with kernel_exec_result
# We are working on preventing excessive speedup issues
##############################################################

if measure_performance and check_for_excessive_speedup: # experimental: hence able to shut off codepath if needed

if verbose:
print("[Eval] Additional checks to flag excessive speedup")

torch.cuda.synchronize(device=device)
set_seed(seed_num)
inputs = get_inputs()
# Convert inputs for performance measurement
inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs]

model_new = custom_model.to(device=device, dtype=precision)
torch.cuda.synchronize(device=device)

# time PyTorch reference function
# same timing_fn as specified from before
timing_fn = timing.get_timing_function(timing_method)
reference_elapsed_times = timing_fn(
original_model,
inputs, # ideally cloned for extra safety but handled already in correctness check
num_trials=num_perf_trials,
verbose=verbose,
device=device,
)
reference_runtime_stats = timing.get_timing_stats(reference_elapsed_times, device=device)
kernel_exec_result.ref_runtime = reference_runtime_stats["mean"]
kernel_exec_result.ref_runtime_stats = reference_runtime_stats

# Compute Effective Speedup
effective_speedup = kernel_exec_result.ref_runtime / kernel_exec_result.runtime

# TODO: integrate SoL estimation for each unique program on designated hardware
# for now, we will use a heuristics such as 5-10x which is very hard to achieve

if verbose:
print(f"[Eval] Effective Speedup is {effective_speedup:.2f}x using timing method {timing_method}")

if effective_speedup > excessive_speedup_threshold:
kernel_exec_result.metadata["excessive_speedup"] = True

print(f"[WARNING] Excessive speedup {effective_speedup:.2f}x over {excessive_speedup_threshold}x threshold detected")
print(f"[WARNING] Double check your kernel carefully to ensure it is not reward hacking.")


graceful_eval_cleanup(context, device, tempfile)
return kernel_exec_result

Expand Down
112 changes: 112 additions & 0 deletions src/unit_tests/test_eval_adversarial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os

from src.dataset import KERNEL_BENCH_PATH
from src.eval import eval_kernel_against_ref
from src.utils import read_file

"""
Specifically testing the robustness of the eval function.
Through a few targeted adversarial kernels.

Run with pytest
pytest src/unit_tests/test_eval_adversarial.py
"""

def run_test_kernel(problem_name,
kernel_filename,
timing_method="cuda_event",
num_perf_trials=100):
"""
Run a test kernel agains the reference architecture and return the result.

Inputs:
- problem_name: name of the KernelBench problem to test on (e.g. "level1/1_Square_matrix_multiplication_.py")
- kernel_filename: name of the adversarial kernel (e.g. "zero_out_kernel.py")

Output:
- result: KernelExecResult object
"""
print(f"Running test adversarial kernel {kernel_filename} against problem {problem_name}")
ref_src = read_file(os.path.join(KERNEL_BENCH_PATH, problem_name))
# TODO: replace with Omar's dataset object
kernel_src = read_file(os.path.join(os.path.dirname(__file__), "test_kernels", kernel_filename))

return eval_kernel_against_ref(
original_model_src=ref_src,
custom_model_src=kernel_src,
num_correct_trials=5, # under mulitiple correctness trials
num_perf_trials=num_perf_trials,
measure_performance=True,
timing_method=timing_method,
# for checking against reward hacking
check_for_excessive_speedup = True,
excessive_speedup_threshold = 10
)


def test_result_cache_reuse():
"""
Test to prevent adversarial kernel by reuse computed PyTorch results

Methodology:
- Adversarial kernel allocate memory for result but does not initialize it.
- Potential Hacking scenario: CUDA cache allocator might give that un-erased physical memory to the custom kernel (which contains the result) and it will incorrectly pass
- Eval script should flag this as incorrect as the reference kernel did not conduct any meaningful computation
"""
result = run_test_kernel(problem_name="level1/1_Square_matrix_multiplication_.py",
kernel_filename="result_reuse_kernel.py")
assert result.correctness is False, "Cache reuse hack detected"
print("Result: Eval Function Passed Test Result Cache Reuse")
print(result)

def test_input_modification():
"""
Test that even if the model tries to modify the inputs to pass the tests,
we still flag it as incorrect.

Methodology:
- Adversarial kernel zeros out all inputs and returns all zeros for the final matmul shape.
- Potential Hacking scenario: Since adversarial kernel zero out the input, the reference might ran the kernel with all zeros.
Hence both refernece and custom kernel return all zeros, and this incorrectly pass.
- Eval script should flag this as incorrect as this kernel zeros out inputs and trick the reference kernel into returning all zeros.
"""
result = run_test_kernel(problem_name="level1/1_Square_matrix_multiplication_.py",
kernel_filename="zero_out_kernel.py")
assert result.correctness is False, "Input modification hack detected"
print("Result: Eval Function Passed Test Input Modification Hack")
print(result)


def test_non_default_stream(timing_method="do_bench", threshold=1.5):
"""
Test that we will flag adversarial kernels that cheat by assigning work to non-default CUDA streams.

Methodology:
- Adversarial kernel assigns work to non-default CUDA streams.
- Potential Hacking scenario: Eval script waits on the default stream only when measuring kernel runtime, leading to unrealistic speedups.
- Eval script should accurately measure runtime by waiting on all streams, not just the default stream.
"""
result = run_test_kernel(problem_name="level1/1_Square_matrix_multiplication_.py",
kernel_filename="non_default_stream_kernel.py",
timing_method=timing_method,
num_perf_trials=10 ) # keep it low as we are leaking cuda stream

print(result)
if result.metadata.get("excessive_speedup") is True:
print("Result: Eval Function successfully marked excessive speedup from the hacky CUDA Stream implementation")
print("This indicates the eval script properly detected and flagged potential cheating behavior.")
else:
raise AssertionError(
"Excessive speedup not detected - Eval Function failed to flag adversarial kernel that exploits cuda stream timing hack"
)


def main():
test_result_cache_reuse()
test_input_modification()
test_non_default_stream()

# we might add more adversarial tests in the future

if __name__ == "__main__":
main()
84 changes: 84 additions & 0 deletions src/unit_tests/test_kernels/non_default_stream_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

###########################################
# CUDA_STREAM_ADVERSARIAL
# adversarial kernel assigns all work to a non-default CUDA stream.
# If the eval script waits only on the default cuda stream when measuring kernel runtime, this will lead to unrealistic speedups.
###########################################
# this adversarial kernel obtains a non-default, non-blocking stream via cudaStreamCreateWithFlags(...) and sets to be the stream of the cuBLAS handle.
# then, it performs a matrix multiplication on this new stream.

# if timing harness only waits on default stream, it will measure near-zero time because the actual computation is still running asynchronously on the other stream.

matmul_cuda_source = r"""
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B) {
int M = A.size(0);
int K = A.size(1);
int N = B.size(1);

auto result = torch::empty({M, N}, A.options());

cudaStream_t stream;
cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

cublasHandle_t handle;
cublasCreate(&handle);
cublasSetStream(handle, stream);

float alpha = 1.0f, beta = 0.0f;

cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
N, M, K,
&alpha,
B.data_ptr<float>(), CUDA_R_32F, N,
A.data_ptr<float>(), CUDA_R_32F, K,
&beta,
result.data_ptr<float>(), CUDA_R_32F, N,
CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);

// NOTE: Intentionally leaking CUDA resources (stream + cuBLAS handle) to preserve
// adversarial behavior. Calling cudaStreamDestroy() would implicitly synchronize,
// defeating the purpose of this test (which checks if the eval harness detects
// work launched on non-default streams). Keep num_perf_trials low to avoid OOM.
// cublasDestroy(handle);
// cudaStreamDestroy(stream);

return result;
}
"""

matmul_cuda_cpp = "torch::Tensor matmul_cuda(torch::Tensor A, torch::Tensor B);"

# Compile inline extension
matmul_module = load_inline(
name="fast_matmul",
cpp_sources=matmul_cuda_cpp,
cuda_sources=matmul_cuda_source,
functions=["matmul_cuda"],
verbose=False,
)

class ModelNew(nn.Module):
def __init__(self) -> None:
super().__init__()
self.matmul = matmul_module

def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return self.matmul.matmul_cuda(A, B)

N = 2048 * 2

def get_inputs():
A = torch.rand(N, N)
B = torch.rand(N, N)
return [A, B]

def get_init_inputs():
return []
Loading