diff --git a/README.md b/README.md
index 9c3232c8..b3db083a 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,19 @@
# KernelBench: Can LLMs Write Efficient GPU Kernels? [ICML '25]
-[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) |
+A benchmark for evaluating LLMs' ability to generate efficient GPU kernels
+
+[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench)
+
+
## Versions
-The huggingface dataset is updated to v0.1.
-- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - Latest version (also main branch)
+The latest stable version lives on the `main` branch; we continue to update and improve the repo.
+- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - Latest release; see the [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbenchv01/)
- [v0](https://github.com/ScalingIntelligence/KernelBench/tree/v0) - Original Release
-A benchmark for evaluating LLMs' ability to generate efficient GPU kernels
-
+The Huggingface [dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) is updated to v0.1.
-
+This repo provides the core KernelBench functionality and an easy-to-use set of evaluation scripts. It is not intended to provide complex agentic scaffolds for solving this task; we recommend cloning and modifying this repo for your experiments, or using it as a git submodule.
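+
+If you want to pin KernelBench inside your own project, one option (a minimal sketch; pick whatever submodule path fits your layout) is:
+
+```
+git submodule add https://github.com/ScalingIntelligence/KernelBench.git KernelBench
+```
+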
## 👋 Task Description
We structure the problem for LLM to transpile operators described in PyTorch to CUDA kernels, at whatever level of granularity it desires to.
@@ -26,7 +29,7 @@ We construct KernelBench to have 4 Levels of categories:
- **Level 4 🤗**: Level Hugging Face
Optimize whole model architectures from HuggingFace
-We are actively extending KernelBench to other DSLs beyond `cuda` as well.
+We are actively extending KernelBench to other DSLs beyond `cuda` as well (see below).
## ⚖️ Evaluation
#### Methodology
@@ -98,7 +101,12 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
# add .verbose_logging for more visbility
```
-We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support (`cuda`, `triton`, `cute`).
+**What you might need to modify**
+* **`gpu_arch`** - Depending on your GPU, you may need to adjust the `gpu_arch` argument to match your hardware.
+* **`precision`** - You can set the tensor precision with `precision=fp32`. All of our reported results currently use `fp32`, but we have added support for `fp16` and `bf16`.
+* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify, e.g., `backend=triton`. For now we support the DSLs `cuda`, `triton`, `cute`, and `tilelang`.
+
+Check the config fields for the comprehensive set of options; an example invocation is shown below.
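+
+For example, a single-sample run that overrides these defaults could look like the following (add your usual model/server settings; the level, problem, and flag values here are illustrative):
+
+```
+python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=1 problem_id=1 backend=triton precision=fp16
+```
+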
### Run on all problems
diff --git a/requirements.txt b/requirements.txt
index a253f422..d7f31a49 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,8 @@ modal
# DSLs
nvidia-cutlass-dsl
+tilelang
+apache-tvm
# helper
tqdm
diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py
index 787aca2b..c7f6fe61 100644
--- a/scripts/eval_from_generations.py
+++ b/scripts/eval_from_generations.py
@@ -145,6 +145,10 @@ def __init__(self):
# Backend to use for kernel implementation (cuda or triton)
self.backend = "cuda"
+
+ # Precision for computation: "fp32", "fp16", "bf16"
+ self.precision = "fp32"
+
# Number of samples per problem to evaluate for pass@k analysis
self.num_samples_per_problem = 1 # Default to 1 sample per problem
@@ -188,11 +192,13 @@ def evaluate_single_sample_modal(
num_perf_trials: int = 100,
measure_performance: bool = True,
verbose: bool = False,
+ backend: str = "cuda",
+ precision: str = "fp32",
):
"""
Evaluate a single sample on Modal GPU with automatic retries for GPU attachment failures
"""
- from src.eval import eval_kernel_against_ref
+ from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string
from src.utils import set_gpu_arch
import torch
import time
@@ -225,6 +231,8 @@ def evaluate_single_sample_modal(
num_perf_trials=num_perf_trials,
build_dir=None, # Modal doesn't need persistent build dir
device=torch.device("cuda:0"), # Modal has one GPU per container
+ backend=backend,
+ precision=get_torch_dtype_from_string(precision),
)
# Force cleanup and exit to prevent container reuse and memory leaks
@@ -321,6 +329,7 @@ def evaluate_single_sample(
build_dir=build_dir,
device=device,
backend=configs.backend,
+ precision=eval.get_torch_dtype_from_string(configs.precision),
)
return eval_result
except Exception as e:
@@ -491,6 +500,8 @@ def batch_eval_modal(
num_perf_trials=config.num_perf_trials,
measure_performance=config.measure_performance,
verbose=config.verbose,
+ backend=config.backend,
+ precision=config.precision,
)
futures.append(future)
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index d596b557..18fb3c55 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -18,7 +18,7 @@
read_file,
set_gpu_arch,
)
-
+from src.eval import get_torch_dtype_from_string
"""
Generate and evaluate a single sample
Easiest way to get started, to test a single problem for experimentation or debugging
@@ -48,6 +48,7 @@ def __init__(self):
# Construct this from mapping from architecture name to torch cuda arch list in the future
# you can either specify SM version or just use the name
self.gpu_arch = ["Ada"]
+ self.precision = "fp32" # options ["fp32", "fp16", "bf16"]
# Inference config
self.server_type = None
@@ -171,11 +172,11 @@ def main(config: EvalConfig):
# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
- elif config.backend in ["triton", "cute"]: # removed "tilelang"
+ elif config.backend in ["triton", "tilelang", "cute"]:
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(
- f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
+ f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'."
)
if config.log_prompt:
@@ -219,6 +220,7 @@ def main(config: EvalConfig):
num_correct_trials=5,
num_perf_trials=100,
backend=config.backend,
+ precision=get_torch_dtype_from_string(config.precision),
)
print(
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index da449de1..ae18ba8f 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -53,7 +53,7 @@ def __init__(self):
# you can either specify SM version or just use the name
self.gpu = "L40S"
self.gpu_arch = ['Ada']
-
+ self.precision = "fp32" # options ["fp32", "fp16", "bf16"]
# Inference config
self.server_type = None
@@ -110,8 +110,8 @@ def __repr__(self):
"pytest",
"ninja",
"utils",
- # "tilelang", # commented out - not working currently
- #"apache-tvm",
+ "tilelang",
+ "apache-tvm",
"python-dotenv",
"nvidia-cutlass-dsl",
"litellm[proxy]", # Unified LLM interface
@@ -125,17 +125,18 @@ def __repr__(self):
class EvalFunc:
@modal.method()
- def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend):
+ def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend, precision):
# 3. Evaluate Kernel
# NOTE: no need to wrap around process here as only a single sample
# see batch eval for examples of process isolation
from src.eval import eval_kernel_against_ref
+ from src.eval import get_torch_dtype_from_string
# Use utility function to set the GPU architecture in the modal environment
from src.utils import set_gpu_arch as modal_set_gpu_arch
modal_set_gpu_arch(gpu_arch)
return eval_kernel_against_ref(
ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True,
- num_correct_trials=5, num_perf_trials=100, backend=backend
+ num_correct_trials=5, num_perf_trials=100, backend=backend, precision=get_torch_dtype_from_string(precision)
)
@pydra.main(base=EvalConfig)
@@ -216,10 +217,10 @@ def main(config: EvalConfig):
# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
- elif config.backend in ["triton", "cute"]: # removed "tilelang"
+ elif config.backend in ["triton", "tilelang", "cute"]:
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
- raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.")
+ raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.")
if config.log_prompt:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
@@ -238,7 +239,7 @@ def main(config: EvalConfig):
with app.run():
kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(
- ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend
+ ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend, config.precision
)
print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}")
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 7915fc30..5b476445 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -78,6 +78,8 @@ def __init__(self):
self.log_prompt = False
self.backend = "cuda"
+
+ self.precision = "fp32"
def greedy(self):
# For greedy decoding, epsecially baseline eval
@@ -129,11 +131,11 @@ def generate_sample_single(
custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
ref_arch_src
)
- elif config.backend in ["triton", "cute"]: # removed "tilelang"
+ elif config.backend in ["triton", "cute", "tilelang"]:
custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(
- f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
+ f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'."
)
if config.log_prompt:
prompt_path = os.path.join(
diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py
index 79c00a7e..a863afcd 100644
--- a/scripts/run_and_check.py
+++ b/scripts/run_and_check.py
@@ -9,7 +9,6 @@
from src import eval as kernel_eval
from src import utils as kernel_utils
from scripts.generate_baseline_time import measure_program_time
-
from src.utils import read_file
"""
@@ -68,6 +67,8 @@ def __init__(self):
# Replace with your NVIDIA GPU architecture, e.g. ["Hopper"]
self.gpu_arch = ["Ada"]
+ self.precision = "fp32"
+ self.backend = "cuda"
def __repr__(self):
return f"ScriptConfig({self.to_dict()})"
@@ -97,7 +98,9 @@ def evaluate_single_sample_src(ref_arch_src: str, kernel_src: str, configs: dict
num_correct_trials=num_correct_trials,
num_perf_trials=num_perf_trials,
build_dir=build_dir,
- device=device
+ device=device,
+ backend=configs["backend"],
+ precision=kernel_eval.get_torch_dtype_from_string(configs["precision"])
)
return eval_result
except Exception as e:
diff --git a/src/eval.py b/src/eval.py
index e2411f3c..4a072c89 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -13,7 +13,7 @@
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
-from typing import Union
+from typing import Union, Optional
import numpy as np
import requests
@@ -33,25 +33,10 @@
def get_error_name(e: Exception) -> str:
-
- return f"{e.__class__.__module__}.{e.__class__.__name__}"
-
-
-def fetch_kernel_from_database(
- run_name: str, problem_id: int, sample_id: int, server_url: str
-):
"""
- Intenral to us with our django database
- Return a dict with kernel hash, kernel code, problem_id
+ Get the error name, for logging purposes
"""
- response = requests.get(
- f"{server_url}/get_kernel_by_run_problem_sample/{run_name}/{problem_id}/{sample_id}",
- json={"run_name": run_name, "problem_id": problem_id, "sample_id": sample_id},
- )
- assert response.status_code == 200
- response_json = response.json()
- assert str(response_json["problem_id"]) == str(problem_id)
- return response_json
+ return f"{e.__class__.__module__}.{e.__class__.__name__}"
def fetch_ref_arch_from_problem_id(problem_id, problems, with_name=False) -> str:
@@ -85,6 +70,40 @@ def set_seed(seed: int):
# NOTE: this only sets on current cuda device
torch.cuda.manual_seed(seed)
+def get_torch_dtype_from_string(precision: str) -> torch.dtype:
+ """
+ Get the torch dtype for specific precision
+ """
+ if precision == "fp32":
+ return torch.float32
+ elif precision == "fp16":
+ return torch.float16
+ elif precision == "bf16":
+ return torch.bfloat16
+    else:  # future: FP8, FP4, etc. support
+        raise ValueError(f"Unsupported precision: {precision}")
+
+def get_tolerance_for_precision(precision: str | torch.dtype) -> float:
+ """
+    Get the numerical tolerance for a given precision (string or torch.dtype).
+ These tolerances are inspired by torchbench (PyTorch Benchmarking Suite):
+ Reference:
+ https://github.com/pytorch/benchmark/blob/cfd835c35d04513ced9a59bd074eeb21dc8187d7/torchbenchmark/util/env_check.py#L519
+ """
+ if isinstance(precision, str):
+ precision = get_torch_dtype_from_string(precision)
+
+ PRECISION_TOLERANCES = {
+ # By default for fp32, 1e-4 is used according to torchbench.
+ torch.float32: 1e-4,
+        # torchbench suggests 1e-3 for fp16/bf16, and 1e-2 if that proves too strict.
+        # TODO: let users configure their own tolerance as an option
+ torch.float16: 1e-2,
+ torch.bfloat16: 1e-2,
+ }
+    assert precision in PRECISION_TOLERANCES, f"Unsupported precision: {precision}"
+ return PRECISION_TOLERANCES[precision]
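+
+# Quick usage sketch for the two helpers above (not executed; values mirror the mappings defined here):
+#   get_torch_dtype_from_string("bf16")        -> torch.bfloat16
+#   get_tolerance_for_precision("fp32")        -> 1e-4
+#   get_tolerance_for_precision(torch.float16) -> 1e-2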
+
class KernelExecResult(BaseModel):
"""
@@ -158,41 +177,6 @@ def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"):
return ModelNew, temp_file
-# def load_tilelang_model(
-# model_custom_src: str,
-# context: dict,
-# build_directory: str | None = None
-# ):
-# """
-# Load TileLang model using linecache instead of tempfile.
-# This registers the source code in memory so inspect.getsource() works,
-# which is needed for TileLang's JIT decorator.
-# """
-# if build_directory:
-# model_custom_src = (
-# "import os\n"
-# f"os.environ['TORCH_EXTENSIONS_DIR'] = '{build_directory}'\n"
-# + model_custom_src
-# )
-#
-# # Register source so inspect.getsource works
-# fake_fname = (
-# f"/tmp/tilelang_kernel_"
-# f"{hashlib.md5(model_custom_src.encode()).hexdigest()}.py"
-# )
-# # linecache expects a list with trailing newlines
-# linecache.cache[fake_fname] = (
-# len(model_custom_src),
-# None,
-# model_custom_src.splitlines(True),
-# fake_fname,
-# )
-#
-# code_obj = compile(model_custom_src, fake_fname, "exec")
-# exec(code_obj, context)
-# return context["ModelNew"]
-
-
def load_custom_model(
model_custom_src: str, context: dict, build_directory: str = None
) -> nn.Module:
@@ -379,31 +363,28 @@ def build_compile_cache_with_capturing(
return returncode, stdout.decode("utf-8"), stderr.decode("utf-8")
-def _process_input_tensor(tensor, device, backend):
+def _process_input_tensor(input, device, backend="cuda", precision=torch.float32):
"""
Helper function to move tensors to the correct device and apply backend-specific dtype casting.
Args:
- tensor: Input tensor or non-tensor value
+ input: Input tensor or non-tensor value
device: Target CUDA device
backend: Backend type (e.g., 'cuda', 'triton', 'cute')
-
+ precision: torch.dtype
Returns:
Processed tensor on correct device with correct dtype, or original value if not a tensor
"""
- if not isinstance(tensor, torch.Tensor):
- return tensor
-
- # Preserve integer dtypes for labels/targets (e.g., classification losses)
- if tensor.dtype in [torch.int32, torch.int64, torch.long]:
- return tensor.to(device=device)
+
+    # non-tensor values (e.g., scalar init arguments) are passed through unchanged
+    if not isinstance(input, torch.Tensor):
+        return input
+
+    # preserve integer dtypes for labels / targets (e.g., classification losses)
+    if input.dtype in [torch.int32, torch.int64, torch.long]:
+        return input.to(device=device)
- # Apply backend-specific dtype casting for float tensors
- # if backend.lower() == "tilelang":
- # return tensor.to(device=device, dtype=torch.float16)
+    # cast floating-point tensors to the desired precision dtype
+    input_tensor = input.to(dtype=precision)
# Default for all other backends and float types
- return tensor.to(device=device)
+ return input_tensor.to(device=device)
def eval_kernel_against_ref(
@@ -418,7 +399,8 @@ def eval_kernel_against_ref(
device: Union[torch.device, int] = (
torch.cuda.current_device() if torch.cuda.is_available() else None
), # have to run on GPU
- backend: str = "cuda", # can be 'cuda', 'triton', or 'cute'
+ backend: str = "cuda", # can be 'cuda', 'triton', 'tilelang', or 'cute'
+ precision: torch.dtype = torch.float32,
) -> KernelExecResult:
"""
Evaluate the custom kernel against the original model
@@ -426,14 +408,14 @@ def eval_kernel_against_ref(
num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass
num_perf_trials: run the evalutation many times to take the average
device: GPU (cuda) device to run the evalutation on
- backend: str, one of 'cuda', 'triton', or 'cute'
+ backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute'
+        precision: torch.dtype for computation (note: tilelang only supports fp16 or bf16)
"""
# TODO: check device is busy
assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval"
- # SET DEFAULT DTYPE TO FLOAT16 ONLY FOR TILELANG
- # if backend.lower() == "tilelang":
- # torch.set_default_dtype(torch.float16)
+ if backend.lower() == "tilelang":
+ assert precision == torch.float16 or precision == torch.bfloat16, "TileLang only supports fp16 or bfloat16"
torch.set_printoptions(
precision=4, # Decimal places
@@ -446,7 +428,8 @@ def eval_kernel_against_ref(
torch.cuda.set_device(device)
# Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES
- uses_tempfile = backend.lower() in ["triton", "cute"] # removed "tilelang"
+ # TileLang, Triton, and CuTe all use tempfile for proper module loading
+ uses_tempfile = backend.lower() in ["triton", "tilelang", "cute"]
metadata = {} # for storing result metadata
metadata["hardware"] = torch.cuda.get_device_name(device=device)
@@ -479,7 +462,7 @@ def eval_kernel_against_ref(
init_inputs = get_init_inputs()
# Convert inputs to appropriate dtypes for GPU computation
- init_inputs = [_process_input_tensor(x, device, backend) for x in init_inputs]
+ init_inputs = [_process_input_tensor(x, device, backend, precision) for x in init_inputs]
with torch.no_grad():
set_seed(seed_num) # set seed for reproducible weights
@@ -498,11 +481,9 @@ def eval_kernel_against_ref(
# add hash for later to distinguish between multi-turn kernels
backend_lower = backend.lower()
- # if backend_lower == "tilelang":
- # # Use linecache approach for TileLang
- # ModelNew = load_tilelang_model(custom_model_src, context, build_dir)
- if backend_lower in ["triton", "cute"]:
- # Use tempfile approach for triton and cute
+ if backend_lower in ["triton", "tilelang", "cute"]:
+ # Use tempfile approach for triton, tilelang, and cute
+ # These DSLs require proper module import for JIT decorators to work
ModelNew, tempfile = load_custom_model_with_tempfile(
custom_model_src, entry_point="ModelNew"
)
@@ -538,27 +519,8 @@ def eval_kernel_against_ref(
set_seed(seed_num) # set seed for reproducible weights
custom_model = ModelNew(*init_inputs)
assert hasattr(custom_model, "forward")
- # Move models to GPU with float16 dtype (only for TileLang)
- # if backend.lower() == "tilelang":
- # try:
- # original_model = original_model.to(device=device, dtype=torch.float16)
- # except Exception as e:
- # # TileLang JIT kernels may not support .to(), already on GPU
- # if verbose:
- # print(f"[Info] Could not call .to() on original model (TileLang), using as-is: {e}")
- # print("[Traceback]:")
- # traceback.print_exc()
- # try:
- # custom_model = custom_model.to(device=device, dtype=torch.float16)
- # except Exception as e:
- # # TileLang JIT kernels may not support .to(), already on GPU
- # if verbose:
- # print(f"[Info] Could not call .to() on custom model (TileLang), using as-is: {e}")
- # print("[Traceback]:")
- # traceback.print_exc()
- # else:
- original_model = original_model.to(device=device)
- custom_model = custom_model.to(device=device)
+ original_model = original_model.to(device=device, dtype=precision)
+ custom_model = custom_model.to(device=device, dtype=precision)
torch.cuda.synchronize(device=device)
if verbose:
print("[Eval] New Model with Custom CUDA Kernel Loaded")
@@ -590,6 +552,7 @@ def eval_kernel_against_ref(
seed=seed_num,
device=device,
backend=backend,
+ precision=precision,
)
except Exception as e:
# TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ...
@@ -610,20 +573,9 @@ def eval_kernel_against_ref(
set_seed(seed_num)
inputs = get_inputs()
# Convert inputs for performance measurement
- inputs = [_process_input_tensor(x, device, backend) for x in inputs]
+ inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs]
- # if backend.lower() == "tilelang":
- # try:
- # model_new = custom_model.to(device=device, dtype=torch.float16)
- # except Exception as e:
- # # TileLang JIT kernels may not support .to(), already on GPU
- # if verbose:
- # print(f"[Info] Line 616 - Could not call .to() on custom model for perf measurement (TileLang): {e}")
- # print("[Traceback] From performance measurement - line 616:")
- # traceback.print_exc()
- # model_new = custom_model
- # else:
- model_new = custom_model.to(device=device)
+ model_new = custom_model.to(device=device, dtype=precision)
torch.cuda.synchronize(device=device)
elapsed_times = time_execution_with_cuda_event(
@@ -737,10 +689,11 @@ def run_and_check_correctness(
get_inputs_fn: callable,
metadata: dict,
num_correct_trials: int,
- verbose=False,
- seed=42,
- device=None,
- backend="cuda",
+    verbose: bool = False,
+    seed: int = 42,
+    device: Optional[torch.device] = None,
+    backend: str = "cuda",
+    precision: torch.dtype = torch.float32,
) -> KernelExecResult:
"""
run the model and check correctness,
@@ -749,6 +702,7 @@ def run_and_check_correctness(
num_correct_trials: run the evalutation multiple times with (ideally) different random inputs to ensure correctness
backend: backend type for handling dtype conversions
+        precision: torch.dtype used for casting models and float inputs
"""
pass_count = 0
@@ -765,42 +719,19 @@ def run_and_check_correctness(
trial_seed = correctness_trial_seeds[trial]
if verbose:
print(f"[Eval] Generating Random Input with seed {trial_seed}")
-
- # if backend.lower() == "tilelang":
- # torch.set_default_dtype(torch.float16)
set_seed(trial_seed)
inputs = get_inputs_fn()
# Convert inputs to appropriate dtypes for GPU computation
- inputs = [_process_input_tensor(x, device, backend) for x in inputs]
+ inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs]
set_seed(trial_seed)
- # if backend.lower() == "tilelang":
- # try:
- # model = original_model_instance.to(device=device, dtype=torch.float16)
- # except Exception as e:
- # # TileLang JIT kernels may not support .to(), already on GPU
- # if verbose:
- # print(f"[Info] Line 771 - Could not call .to() on original model (TileLang): {e}")
- # print("[Traceback] From run_and_check_correctness - line 771:")
- # traceback.print_exc()
- # model = original_model_instance
- # else:
- model = original_model_instance.to(device=device)
+
+ model = original_model_instance.to(device=device, dtype=precision)
set_seed(trial_seed)
- # if backend.lower() == "tilelang":
- # try:
- # model_new = new_model_instance.to(device=device, dtype=torch.float16)
- # except Exception as e:
- # # TileLang JIT kernels may not support .to(), already on GPU
- # if verbose:
- # print(f"[Info] Line 777 - Could not call .to() on custom model (TileLang): {e}")
- # print("[Traceback] From run_and_check_correctness - line 777:")
- # traceback.print_exc()
- # model_new = new_model_instance
- # else:
- model_new = new_model_instance.to(device=device)
+
+ model_new = new_model_instance.to(device=device, dtype=precision)
output = model(*inputs)
torch.cuda.synchronize(device=device)
@@ -824,9 +755,13 @@ def run_and_check_correctness(
compiled=True, correctness=False, metadata=metadata
)
+    # following torchbench, we use the same tolerance value for both atol and rtol
+    # KernelBench v0 and v0.1 used fp32 with atol = rtol = 1e-02
+    # now the tolerance is chosen per precision via get_tolerance_for_precision
+ tolerance = get_tolerance_for_precision(precision)
# check output value difference
if not torch.allclose(
- output, output_new, atol=1e-02, rtol=1e-02
+ output, output_new, atol=tolerance, rtol=tolerance
): # fail
max_diff = torch.max(torch.abs(output - output_new)).item()
avg_diff = torch.mean(torch.abs(output - output_new)).item()
diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py
index 39d16243..8a520d10 100644
--- a/src/prompt_constructor_multilang.py
+++ b/src/prompt_constructor_multilang.py
@@ -265,118 +265,116 @@ def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata):
################################################################################
-# TileLang Backend - COMMENTED OUT (not working currently)
+# TileLang Backend
################################################################################
-# TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n
-# You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
-# """
-#
-# TILELANG_PROBLEM_INSTRUCTION = """
-# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
-# """
-#
-# TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
-# """
-#
-# TILELANG_PROBLEM_INSTRUCTION_CLEANED = """
-# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
-# """
-#
-#
-# def prompt_generate_custom_tilelang(
-# arc_src: str, example_arch_src: str, example_new_arch_src: str
-# ) -> str:
-# prompt = TILELANG_PROBLEM_STATEMENT
-#
-# if example_arch_src != "" and example_new_arch_src != "":
-# prompt += f"""
-# Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n
-# ``` \n
-# {example_arch_src}
-# ``` \n
-# The example new arch with custom TileLang kernels looks like this: \n
-# ```
-# {example_new_arch_src}
-# ``` \n
-# """
-#
-# prompt += f"""
-# You are given the following architecture: \n
-# ```
-# {arc_src}
-# ```
-# """
-# prompt += TILELANG_PROBLEM_INSTRUCTION
-# return prompt
-#
-#
-# def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str:
-# """
-# Using prompt example for TileLang
-# Note: You'll need to create a TileLang example file similar to the Triton one
-# """
-# arch = ref_arch_src
-#
-# # TODO: Create model_new_ex_add_tilelang.py example file
-# example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py")
-# example_new_arch_path = os.path.join(
-# REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py"
-# )
-#
-# if not os.path.exists(example_arch_path):
-# raise FileNotFoundError(
-# f"Example architecture file not found: {example_arch_path}"
-# )
-# if not os.path.exists(example_new_arch_path):
-# # For now, use a basic template without examples if file doesn't exist
-# return prompt_generate_custom_tilelang(arch, "", "")
-#
-# example_arch = read_file(example_arch_path)
-# example_new_arch = read_file(example_new_arch_path)
-#
-# return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch)
-#
-#
-# def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata):
-# prompt = TILELANG_PROBLEM_STATEMENT
-# prompt += f"""
-# With the following architecture:
-# ```
-# {ref_arch_src}
-# ```
-# You generated the following solution and it failed to compile:
-# ```
-# {custom_kernel}
-# ```
-# Here's the metadata of the compilation error:
-# ```
-# {metadata}
-# ```
-#
-# Please fix the compilation error in the new model code. Please output the corrected code in codeblocks.
-# """
-# return prompt
-#
-#
-# def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata):
-# prompt = TILELANG_PROBLEM_STATEMENT
-# prompt += f"""
-# With the following architecture:
-# ```
-# {ref_arch_src}
-# ```
-# You generated the following solution and it failed correctness:
-# ```
-# {custom_kernel}
-# ```
-# Here's the metadata of the correctness error:
-# ```
-# {metadata}
-# ```
-# Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks.
-# """
-# return prompt
+TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n
+You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
+"""
+
+TILELANG_PROBLEM_INSTRUCTION = """
+Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
+"""
+
+TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n
+"""
+
+TILELANG_PROBLEM_INSTRUCTION_CLEANED = """
+Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n
+"""
+
+
+def prompt_generate_custom_tilelang(
+ arc_src: str, example_arch_src: str, example_new_arch_src: str
+) -> str:
+ prompt = TILELANG_PROBLEM_STATEMENT
+
+ if example_arch_src != "" and example_new_arch_src != "":
+ prompt += f"""
+ Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n
+ ``` \n
+ {example_arch_src}
+ ``` \n
+ The example new arch with custom TileLang kernels looks like this: \n
+ ```
+ {example_new_arch_src}
+ ``` \n
+ """
+
+ prompt += f"""
+ You are given the following architecture: \n
+ ```
+ {arc_src}
+ ```
+ """
+ prompt += TILELANG_PROBLEM_INSTRUCTION
+ return prompt
+
+
+def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str:
+ """
+ Using prompt example for TileLang
+ """
+ arch = ref_arch_src
+
+    example_arch_path = os.path.join(REPO_TOP_PATH, "src/prompts/model_ex_add.py")
+    example_new_arch_path = os.path.join(
+        REPO_TOP_PATH, "src/prompts/model_new_ex_add_tilelang.py"
+    )
+
+ if not os.path.exists(example_arch_path):
+ raise FileNotFoundError(
+ f"Example architecture file not found: {example_arch_path}"
+ )
+ if not os.path.exists(example_new_arch_path):
+ # For now, use a basic template without examples if file doesn't exist
+ return prompt_generate_custom_tilelang(arch, "", "")
+
+ example_arch = read_file(example_arch_path)
+ example_new_arch = read_file(example_new_arch_path)
+
+ return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch)
+
+
+def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata):
+ prompt = TILELANG_PROBLEM_STATEMENT
+ prompt += f"""
+ With the following architecture:
+ ```
+ {ref_arch_src}
+ ```
+ You generated the following solution and it failed to compile:
+ ```
+ {custom_kernel}
+ ```
+ Here's the metadata of the compilation error:
+ ```
+ {metadata}
+ ```
+
+ Please fix the compilation error in the new model code. Please output the corrected code in codeblocks.
+ """
+ return prompt
+
+
+def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata):
+ prompt = TILELANG_PROBLEM_STATEMENT
+ prompt += f"""
+ With the following architecture:
+ ```
+ {ref_arch_src}
+ ```
+ You generated the following solution and it failed correctness:
+ ```
+ {custom_kernel}
+ ```
+ Here's the metadata of the correctness error:
+ ```
+ {metadata}
+ ```
+ Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks.
+ """
+ return prompt
################################################################################
@@ -504,7 +502,7 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
Args:
ref_arch_src: Reference architecture source code
- backend: One of 'triton', 'cute' (tilelang removed - not working)
+ backend: One of 'triton', 'tilelang', 'cute'
Returns:
Prompt string for the specified backend
@@ -513,13 +511,13 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
if backend_lower == "triton":
return prompt_generate_custom_triton_from_prompt_template(ref_arch_src)
- # elif backend_lower == "tilelang":
- # return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src)
+ elif backend_lower == "tilelang":
+ return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src)
elif backend_lower == "cute":
return prompt_generate_custom_cute_from_prompt_template(ref_arch_src)
else:
raise ValueError(
- f"Unsupported backend: {backend}. Must be one of: 'triton', 'cute'"
+ f"Unsupported backend: {backend}. Must be one of: 'triton', 'tilelang', 'cute'"
)
diff --git a/src/prompts/model_new_ex_add_tilelang.py b/src/prompts/model_new_ex_add_tilelang.py
new file mode 100644
index 00000000..b65fb2e9
--- /dev/null
+++ b/src/prompts/model_new_ex_add_tilelang.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+import tilelang
+import tilelang.language as T
+
+
+def build_elementwise_add_kernel(M: int, N: int, block_M: int = 128, block_N: int = 256, threads: int = 128, dtype: str = "float16"):
+
+ @T.prim_func
+ def elementwise_add_kernel(
+ A: T.Tensor((M, N), dtype),
+ B: T.Tensor((M, N), dtype),
+ C: T.Tensor((M, N), dtype),
+ ):
+ with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+ start_x = bx * block_N
+ start_y = by * block_M
+
+ for local_y, local_x in T.Parallel(block_M, block_N):
+ y = start_y + local_y
+ x = start_x + local_x
+
+ C[y, x] = A[y, x] + B[y, x]
+
+ return tilelang.compile(elementwise_add_kernel, out_idx=[2], target="cuda")
+
+
+class ModelNew(nn.Module):
+ def __init__(self):
+ super(ModelNew, self).__init__()
+ self._kernel_cache = {}
+
+ def _get_kernel(self, M: int, N: int, tl_dtype: str):
+ key = (M, N, tl_dtype)
+ if key not in self._kernel_cache:
+ self._kernel_cache[key] = build_elementwise_add_kernel(M, N, dtype=tl_dtype)
+ return self._kernel_cache[key]
+
+ def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+ A_c = A.contiguous()
+ B_c = B.contiguous()
+
+ # Get original shape for reshaping output
+ original_shape = A_c.shape
+
+ A_c = A_c.view(-1, A_c.size(-1))
+ B_c = B_c.view(-1, B_c.size(-1))
+
+        M, N = A_c.shape
+        # match the kernel dtype to the input tensors (the harness supplies fp16 or bf16)
+        tl_dtype = "bfloat16" if A_c.dtype == torch.bfloat16 else "float16"
+        kernel = self._get_kernel(M, N, tl_dtype)
+        C = kernel(A_c, B_c)
+
+ return C.view(original_shape)
\ No newline at end of file