2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ To evaluate model-generated kernels, we need to check if they:

Check out `src/eval.py` for details on how we implement correctness check and timing.

We provide a convenient script `scripts/run_and_check.py` to evaluate one single sample source code against a reference source code, check correctness and compute speedup. You can use this to evaluate a model-generated kernel.
We provide a convenient script `scripts/run_and_check.py` to evaluate a single sample source file against a reference source file, check correctness, and compute speedup. You can use this to evaluate a kernel either locally or remotely by setting `eval_mode=local` or `eval_mode=modal`.
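For example, using the sample add kernels shipped in `src/prompts` (these invocations mirror the usage examples in the script's docstring):

python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=local
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=modal gpu=H100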

#### Overall Benchmark Metric

4 changes: 0 additions & 4 deletions scripts/eval_from_generations.py
@@ -68,17 +68,13 @@
"clang"
)
.pip_install(
"anthropic",
"numpy",
"openai",
"packaging",
"pydra_config",
"torch==2.5.0",
"tqdm",
"datasets",
"transformers",
"google-generativeai",
"together",
"pytest",
"ninja",
"utils",
5 changes: 1 addition & 4 deletions scripts/generate_baseline_time_modal.py
@@ -68,17 +68,14 @@
"clang" # note i skip a step
)
.pip_install( # required to build flash-attn
"anthropic",
# Let's unify these dependencies somewhere
"numpy",
"openai",
"packaging",
"pydra_config",
"torch==2.5.0",
"tqdm",
"datasets",
"transformers",
"google-generativeai",
"together",
"pytest",
"ninja",
"utils",
240 changes: 201 additions & 39 deletions scripts/run_and_check.py
@@ -4,13 +4,57 @@
from pydra import REQUIRED, Config
import os
from datasets import load_dataset

import modal

from src import eval as kernel_eval
from src import utils as kernel_utils
from scripts.generate_baseline_time import measure_program_time
from src.utils import read_file

# Modal setup
app = modal.App("run_and_check")
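# Maps Modal GPU type names to the GPU architecture families passed to set_gpu_arch when compiling kernels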
gpu_arch_mapping = {
"L40S": ["Ada"],
"H100": ["Hopper"],
"H200": ["Hopper"],
"A100": ["Ampere"],
"A100-80GB": ["Ampere"],
"L4": ["Ada"],
"T4": ["Turing"],
"A10G": ["Ampere"]
}

REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")

cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

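# Modal image: NVIDIA CUDA devel base (resolves to 12.4.0-devel-ubuntu22.04) plus the Python packages needed to build and run kernels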
image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
.apt_install("git", "gcc-10", "g++-10", "clang")
.pip_install(
"numpy",
"packaging",
"pydra_config",
"torch==2.5.0",
"tqdm",
"datasets",
"transformers",
"pytest",
"ninja",
"utils",
"einops",
"python-dotenv",
"litellm[proxy]",
)
.add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
.add_local_python_source("src")
.add_local_python_source("scripts")
)

"""
Run a KernelBench-format (problem, solution) pair to check whether the solution is correct and to compute its speedup

@@ -25,11 +69,17 @@

====================================================
Usage:
1. PyTorch reference is a local file
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py
1. PyTorch reference is a local file (local eval)
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=local

2. PyTorch reference is a kernelbench problem (local eval)
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel> eval_mode=local

2. PyTorch reference is a kernelbench problem
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel>
3. PyTorch reference is a local file (modal eval on cloud GPU)
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=modal gpu=H100

4. PyTorch reference is a kernelbench problem (modal eval on cloud GPU)
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel> eval_mode=modal gpu=L40S
====================================================

"""
@@ -51,6 +101,9 @@ def __init__(self):
# Solution src definition
self.kernel_src_path = ""

# Evaluation mode
self.eval_mode = "local" # either "local" or "modal"
self.gpu = "L40S" # GPU type for modal (L40S, H100, H200, A100, etc.)

# KernelBench Eval specific
# number of trials to run for correctness
@@ -66,7 +119,7 @@ def __init__(self):
self.clear_cache = False # TODO

# Replace with your NVIDIA GPU architecture, e.g. ["Hopper"]
self.gpu_arch = ["Ada"]
self.gpu_arch = ["Ada"]
self.precision = "fp32"
self.backend = "cuda"

@@ -119,11 +172,70 @@ def evaluate_single_sample_src(ref_arch_src: str, kernel_src: str, configs: dict
"hardware": torch.cuda.get_device_name(device=device),
"device": str(device)
}
eval_result = kernel_eval.KernelExecResult(compiled=False, correctness=False,
eval_result = kernel_eval.KernelExecResult(compiled=False, correctness=False,
metadata=metadata)
return eval_result


# Modal evaluation class
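# Methods below run remotely inside a Modal container; the GPU type is selected with EvalFunc.with_options(gpu=...) at call time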
@app.cls(image=image, scaledown_window=5)
class EvalFunc:

@modal.method()
def evaluate_single_sample_src_modal(self, ref_arch_src: str, kernel_src: str, configs: dict, gpu_arch: list):
"""Evaluate a single sample source code against a reference source code on Modal"""
from src.utils import set_gpu_arch
from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string

set_gpu_arch(gpu_arch)
device = torch.device("cuda:0")

num_correct_trials = configs["num_correct_trials"]
num_perf_trials = configs["num_perf_trials"]
verbose = configs["verbose"]
measure_performance = configs["measure_performance"]

eval_result = eval_kernel_against_ref(
original_model_src=ref_arch_src,
custom_model_src=kernel_src,
measure_performance=measure_performance,
verbose=verbose,
num_correct_trials=num_correct_trials,
num_perf_trials=num_perf_trials,
device=device,
backend=configs["backend"],
precision=get_torch_dtype_from_string(configs["precision"])
)
return eval_result

@modal.method()
def measure_program_time_modal(
self,
ref_arch_src: str,
num_trials: int,
use_torch_compile: bool,
torch_compile_backend: str,
torch_compile_options: str,
gpu_arch: list
):
"""Measure the execution time of a reference program on Modal"""
from scripts.generate_baseline_time import measure_program_time
from src.utils import set_gpu_arch

set_gpu_arch(gpu_arch)
device = torch.device("cuda:0")

return measure_program_time(
ref_arch_name="Reference Program",
ref_arch_src=ref_arch_src,
num_trials=num_trials,
use_torch_compile=use_torch_compile,
torch_compile_backend=torch_compile_backend,
torch_compile_options=torch_compile_options,
device=device
)


@pydra.main(base=ScriptConfig)
def main(config: ScriptConfig):

@@ -162,38 +274,88 @@ def main(config: ScriptConfig):
kernel_src = read_file(config.kernel_src_path)

# Start Evaluation
device = torch.device("cuda:0") # default device
kernel_utils.set_gpu_arch(config.gpu_arch)

print("[INFO] Evaluating kernel against reference code")
# Evaluate kernel against reference code
kernel_eval_result = evaluate_single_sample_src(
ref_arch_src=ref_arch_src,
kernel_src=kernel_src,
configs=config.to_dict(),
device=device
)
kernel_exec_time = kernel_eval_result.runtime

# Measure baseline time
print("[INFO] Measuring reference program time")
# Default using PyTorch Eager here
ref_time_eager_result = measure_program_time(ref_arch_name="Reference Program",
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=False,
device=device)
ref_exec_eager_time = ref_time_eager_result.get("mean", None)

# Measure Torch Compile time
ref_time_compile_result = measure_program_time(ref_arch_name="Reference Program",
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=True,
torch_compile_backend="inductor",
torch_compile_options="default",
device=device)
ref_exec_compile_time = ref_time_compile_result.get("mean", None)
assert config.eval_mode in ["local", "modal"], "eval_mode must be either 'local' or 'modal'"

if config.eval_mode == "local":
# Local evaluation (existing code path)
device = torch.device("cuda:0")
kernel_utils.set_gpu_arch(config.gpu_arch)

print("[INFO] Evaluating kernel against reference code (LOCAL)")
# Evaluate kernel against reference code
kernel_eval_result = evaluate_single_sample_src(
ref_arch_src=ref_arch_src,
kernel_src=kernel_src,
configs=config.to_dict(),
device=device
)
kernel_exec_time = kernel_eval_result.runtime

# Measure baseline time
print("[INFO] Measuring reference program time")
# Default using PyTorch Eager here
ref_time_eager_result = measure_program_time(ref_arch_name="Reference Program",
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=False,
device=device)
ref_exec_eager_time = ref_time_eager_result.get("mean", None)

# Measure Torch Compile time
ref_time_compile_result = measure_program_time(ref_arch_name="Reference Program",
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=True,
torch_compile_backend="inductor",
torch_compile_options="default",
device=device)
ref_exec_compile_time = ref_time_compile_result.get("mean", None)

elif config.eval_mode == "modal":
# Modal evaluation (remote execution)
gpu_arch = gpu_arch_mapping.get(config.gpu, config.gpu_arch)
print(f"[INFO] Using GPU: {config.gpu} with architecture: {gpu_arch}")

with app.run():
print("[INFO] Evaluating kernel against reference code (MODAL)")
# Evaluate kernel against reference code
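# with_options(gpu=...) pins the remote Modal container to the requested GPU type before invoking the method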
kernel_eval_result = EvalFunc.with_options(
gpu=config.gpu
)().evaluate_single_sample_src_modal.remote(
ref_arch_src=ref_arch_src,
kernel_src=kernel_src,
configs=config.to_dict(),
gpu_arch=gpu_arch
)
kernel_exec_time = kernel_eval_result.runtime

# Measure baseline time
print("[INFO] Measuring reference program time (PyTorch Eager)")
ref_time_eager_result = EvalFunc.with_options(
gpu=config.gpu
)().measure_program_time_modal.remote(
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=False,
torch_compile_backend=None,
torch_compile_options=None,
gpu_arch=gpu_arch
)
ref_exec_eager_time = ref_time_eager_result.get("mean", None)

# Measure Torch Compile time
print("[INFO] Measuring reference program time (torch.compile)")
ref_time_compile_result = EvalFunc.with_options(
gpu=config.gpu
)().measure_program_time_modal.remote(
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=True,
torch_compile_backend="inductor",
torch_compile_options="default",
gpu_arch=gpu_arch
)
ref_exec_compile_time = ref_time_compile_result.get("mean", None)

print("="*40)
print(f"[Eval] Kernel eval result: {kernel_eval_result}")