From 02e3e913e37b17f0b78f39f65b28b8fc951ce2df Mon Sep 17 00:00:00 2001 From: Simon Guo Date: Sun, 2 Nov 2025 17:01:02 -0800 Subject: [PATCH 1/5] initial implementation for various precision support on input and weights (during forward pass) use same precision for both Co-authored-by: Simon Guo Co-authored-by: Sahan Paliskara --- scripts/generate_and_eval_single_sample.py | 4 +- .../generate_and_eval_single_sample_modal.py | 9 +- src/eval.py | 117 ++++++++++++------ 3 files changed, 85 insertions(+), 45 deletions(-) diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index ff71e4bc..c2837ab8 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -18,7 +18,7 @@ read_file, set_gpu_arch, ) - +from src.eval import get_torch_dtype_from_string """ Generate and evaluate a single sample Easiest way to get started, to test a single problem for experimentation or debugging @@ -48,6 +48,7 @@ def __init__(self): # Construct this from mapping from architecture name to torch cuda arch list in the future # you can either specify SM version or just use the name self.gpu_arch = ["Ada"] + self.precision = "fp32" # options ["fp32", "fp16", "bf16"] # Inference config self.server_type = "deepseek" @@ -196,6 +197,7 @@ def main(config: EvalConfig): num_correct_trials=5, num_perf_trials=100, backend=config.backend, + precision=get_torch_dtype_from_string(config.precision), ) print( diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index e9e0866a..34c7fb22 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -53,7 +53,7 @@ def __init__(self): # you can either specify SM version or just use the name self.gpu = "L40S" self.gpu_arch = ['Ada'] - + self.precision = "fp32" # options ["fp32", "fp16", "bf16"] # Inference config self.server_type = "deepseek" @@ -121,17 +121,18 @@ def __repr__(self): class EvalFunc: @modal.method() - def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend): + def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend, precision): # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation from src.eval import eval_kernel_against_ref + from src.eval import get_torch_dtype_from_string # Use utility function to set the GPU architecture in the modal environment from src.utils import set_gpu_arch as modal_set_gpu_arch modal_set_gpu_arch(gpu_arch) return eval_kernel_against_ref( ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, - num_correct_trials=5, num_perf_trials=100, backend=backend + num_correct_trials=5, num_perf_trials=100, backend=backend, precision=get_torch_dtype_from_string(precision) ) @pydra.main(base=EvalConfig) @@ -216,7 +217,7 @@ def main(config: EvalConfig): with app.run(): kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote( - ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend + ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend, config.precision ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/src/eval.py b/src/eval.py index e2411f3c..a8bf7ac6 100644 --- a/src/eval.py +++ b/src/eval.py @@ -13,7 +13,7 @@ import traceback from contextlib import redirect_stderr, redirect_stdout from io import StringIO -from typing import Union +from typing import Union, Optional import numpy as np import requests @@ -33,25 +33,10 @@ def get_error_name(e: Exception) -> str: - - return f"{e.__class__.__module__}.{e.__class__.__name__}" - - -def fetch_kernel_from_database( - run_name: str, problem_id: int, sample_id: int, server_url: str -): """ - Intenral to us with our django database - Return a dict with kernel hash, kernel code, problem_id + Get the error name, for logging purposes """ - response = requests.get( - f"{server_url}/get_kernel_by_run_problem_sample/{run_name}/{problem_id}/{sample_id}", - json={"run_name": run_name, "problem_id": problem_id, "sample_id": sample_id}, - ) - assert response.status_code == 200 - response_json = response.json() - assert str(response_json["problem_id"]) == str(problem_id) - return response_json + return f"{e.__class__.__module__}.{e.__class__.__name__}" def fetch_ref_arch_from_problem_id(problem_id, problems, with_name=False) -> str: @@ -85,6 +70,40 @@ def set_seed(seed: int): # NOTE: this only sets on current cuda device torch.cuda.manual_seed(seed) +def get_torch_dtype_from_string(precision: str) -> torch.dtype: + """ + Get the torch dtype for specific precision + """ + if precision == "fp32": + return torch.float32 + elif precision == "fp16": + return torch.float16 + elif precision == "bf16": + return torch.bfloat16 + else: # future, FP8, FP4, etc. support? + raise ValueError(f"Invalid precision not supported: {precision}") + +def get_tolerance_for_precision(precision: str | torch.dtype) -> float: + """ + Get the tolerance from a string representing the percision. + These tolerances are inspired by torchbench (PyTorch Benchmarking Suite): + Reference: + https://github.com/pytorch/benchmark/blob/cfd835c35d04513ced9a59bd074eeb21dc8187d7/torchbenchmark/util/env_check.py#L519 + """ + if isinstance(precision, str): + precision = get_torch_dtype_from_string(precision) + + PRECISION_TOLERANCES = { + # By default for fp32, 1e-4 is used according to torchbench. + torch.float32: 1e-4, + # torchbench states for bf16 and fp16, use 1e-3 as tolerance and 1e-2 if it's too strict. + # @todo: Let user configure own tolerance as an option + torch.float16: 1e-2, + torch.bfloat16: 1e-2, + } + assert precision in PRECISION_TOLERANCES, f"Invalid precision not supported: {precision}" + return PRECISION_TOLERANCES[precision] + class KernelExecResult(BaseModel): """ @@ -158,6 +177,7 @@ def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"): return ModelNew, temp_file +# TODO: fix by @nathan # def load_tilelang_model( # model_custom_src: str, # context: dict, @@ -379,31 +399,39 @@ def build_compile_cache_with_capturing( return returncode, stdout.decode("utf-8"), stderr.decode("utf-8") -def _process_input_tensor(tensor, device, backend): +def _process_input_tensor(input, device, backend="cuda", precision=torch.float32): """ Helper function to move tensors to the correct device and apply backend-specific dtype casting. Args: - tensor: Input tensor or non-tensor value + input: Input tensor or non-tensor value device: Target CUDA device backend: Backend type (e.g., 'cuda', 'triton', 'cute') - + precision: torch.dtype Returns: Processed tensor on correct device with correct dtype, or original value if not a tensor """ - if not isinstance(tensor, torch.Tensor): - return tensor + + # sometimes things like init inputs are floats (like in the case of labels / targets, classification losses, etc.) + if not isinstance(input, torch.Tensor): + return input - # Preserve integer dtypes for labels/targets (e.g., classification losses) - if tensor.dtype in [torch.int32, torch.int64, torch.long]: - return tensor.to(device=device) + # cast to the desired percision dtype for activations + input_tensor = input.to(dtype=precision) + + # @Nathan what is going on here? + # NOTE: to come back to this + # # Preserve integer dtypes for labels/targets (e.g., classification losses) + # if tensor.dtype in [torch.int32, torch.int64, torch.long]: + # return tensor.to(device=device) # Apply backend-specific dtype casting for float tensors + # NOTE: for tilelang is it all fp16? # if backend.lower() == "tilelang": # return tensor.to(device=device, dtype=torch.float16) # Default for all other backends and float types - return tensor.to(device=device) + return input_tensor.to(device=device) def eval_kernel_against_ref( @@ -419,6 +447,7 @@ def eval_kernel_against_ref( torch.cuda.current_device() if torch.cuda.is_available() else None ), # have to run on GPU backend: str = "cuda", # can be 'cuda', 'triton', or 'cute' + precision: torch.dtype = torch.float32, ) -> KernelExecResult: """ Evaluate the custom kernel against the original model @@ -479,7 +508,7 @@ def eval_kernel_against_ref( init_inputs = get_init_inputs() # Convert inputs to appropriate dtypes for GPU computation - init_inputs = [_process_input_tensor(x, device, backend) for x in init_inputs] + init_inputs = [_process_input_tensor(x, device, backend, precision) for x in init_inputs] with torch.no_grad(): set_seed(seed_num) # set seed for reproducible weights @@ -557,8 +586,8 @@ def eval_kernel_against_ref( # print("[Traceback]:") # traceback.print_exc() # else: - original_model = original_model.to(device=device) - custom_model = custom_model.to(device=device) + original_model = original_model.to(device=device, dtype=precision) + custom_model = custom_model.to(device=device, dtype=precision) torch.cuda.synchronize(device=device) if verbose: print("[Eval] New Model with Custom CUDA Kernel Loaded") @@ -590,6 +619,7 @@ def eval_kernel_against_ref( seed=seed_num, device=device, backend=backend, + precision=precision, ) except Exception as e: # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... @@ -610,7 +640,7 @@ def eval_kernel_against_ref( set_seed(seed_num) inputs = get_inputs() # Convert inputs for performance measurement - inputs = [_process_input_tensor(x, device, backend) for x in inputs] + inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs] # if backend.lower() == "tilelang": # try: @@ -623,7 +653,7 @@ def eval_kernel_against_ref( # traceback.print_exc() # model_new = custom_model # else: - model_new = custom_model.to(device=device) + model_new = custom_model.to(device=device, dtype=precision) torch.cuda.synchronize(device=device) elapsed_times = time_execution_with_cuda_event( @@ -737,10 +767,11 @@ def run_and_check_correctness( get_inputs_fn: callable, metadata: dict, num_correct_trials: int, - verbose=False, - seed=42, - device=None, - backend="cuda", + verbose: bool =False, + seed: int =42, + device: Optional[torch.device] =None, + backend: str ="cuda", + precision: torch.dtype =torch.float32, ) -> KernelExecResult: """ run the model and check correctness, @@ -749,6 +780,7 @@ def run_and_check_correctness( num_correct_trials: run the evalutation multiple times with (ideally) different random inputs to ensure correctness backend: backend type for handling dtype conversions + precision: torch.dtype """ pass_count = 0 @@ -766,13 +798,14 @@ def run_and_check_correctness( if verbose: print(f"[Eval] Generating Random Input with seed {trial_seed}") + # @Nathan, let's find a cleaner way to handle tilelang dtype etc # if backend.lower() == "tilelang": # torch.set_default_dtype(torch.float16) set_seed(trial_seed) inputs = get_inputs_fn() # Convert inputs to appropriate dtypes for GPU computation - inputs = [_process_input_tensor(x, device, backend) for x in inputs] + inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs] set_seed(trial_seed) # if backend.lower() == "tilelang": @@ -786,7 +819,7 @@ def run_and_check_correctness( # traceback.print_exc() # model = original_model_instance # else: - model = original_model_instance.to(device=device) + model = original_model_instance.to(device=device, dtype=precision) set_seed(trial_seed) # if backend.lower() == "tilelang": @@ -800,7 +833,7 @@ def run_and_check_correctness( # traceback.print_exc() # model_new = new_model_instance # else: - model_new = new_model_instance.to(device=device) + model_new = new_model_instance.to(device=device, dtype=precision) output = model(*inputs) torch.cuda.synchronize(device=device) @@ -824,9 +857,13 @@ def run_and_check_correctness( compiled=True, correctness=False, metadata=metadata ) + # in torchbench, they use both precisions for atol and rtol + # kernelbench v0 and v0.1 uses fp32, atol = rtol = 1e-02 + # now we will return the tolerance from get_tolerance_for_precision + tolerance = get_tolerance_for_precision(precision) # check output value difference if not torch.allclose( - output, output_new, atol=1e-02, rtol=1e-02 + output, output_new, atol=tolerance, rtol=tolerance ): # fail max_diff = torch.max(torch.abs(output - output_new)).item() avg_diff = torch.mean(torch.abs(output - output_new)).item() From 1eaf1ce02d5477c7573efd68ca50c06c58a77812 Mon Sep 17 00:00:00 2001 From: Nathan Paek Date: Mon, 3 Nov 2025 22:24:53 -0800 Subject: [PATCH 2/5] add tilelang --- scripts/generate_and_eval_single_sample.py | 4 +- .../generate_and_eval_single_sample_modal.py | 9 +- src/eval.py | 127 +--------- src/prompt_constructor_multilang.py | 226 +++++++++--------- src/prompts/model_new_ex_add_tilelang.py | 53 ++++ 5 files changed, 183 insertions(+), 236 deletions(-) create mode 100644 src/prompts/model_new_ex_add_tilelang.py diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index c2837ab8..1cc2a4a1 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -149,11 +149,11 @@ def main(config: EvalConfig): # Use appropriate prompt constructor based on backend if config.backend == "cuda": custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "cute"]: # removed "tilelang" + elif config.backend in ["triton", "tilelang", "cute"]: custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) else: raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." + f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'." ) if config.log_prompt: diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 34c7fb22..210508a9 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -108,11 +108,10 @@ def __repr__(self): "pytest", "ninja", "utils", - # "tilelang", # commented out - not working currently - #"apache-tvm", + "tilelang", + "apache-tvm", "python-dotenv", "nvidia-cutlass-dsl", - ) .add_local_python_source("src") ) @@ -195,10 +194,10 @@ def main(config: EvalConfig): # Use appropriate prompt constructor based on backend if config.backend == "cuda": custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "cute"]: # removed "tilelang" + elif config.backend in ["triton", "tilelang", "cute"]: custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) else: - raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.") + raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.") if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: diff --git a/src/eval.py b/src/eval.py index a8bf7ac6..36b92a89 100644 --- a/src/eval.py +++ b/src/eval.py @@ -177,42 +177,6 @@ def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"): return ModelNew, temp_file -# TODO: fix by @nathan -# def load_tilelang_model( -# model_custom_src: str, -# context: dict, -# build_directory: str | None = None -# ): -# """ -# Load TileLang model using linecache instead of tempfile. -# This registers the source code in memory so inspect.getsource() works, -# which is needed for TileLang's JIT decorator. -# """ -# if build_directory: -# model_custom_src = ( -# "import os\n" -# f"os.environ['TORCH_EXTENSIONS_DIR'] = '{build_directory}'\n" -# + model_custom_src -# ) -# -# # Register source so inspect.getsource works -# fake_fname = ( -# f"/tmp/tilelang_kernel_" -# f"{hashlib.md5(model_custom_src.encode()).hexdigest()}.py" -# ) -# # linecache expects a list with trailing newlines -# linecache.cache[fake_fname] = ( -# len(model_custom_src), -# None, -# model_custom_src.splitlines(True), -# fake_fname, -# ) -# -# code_obj = compile(model_custom_src, fake_fname, "exec") -# exec(code_obj, context) -# return context["ModelNew"] - - def load_custom_model( model_custom_src: str, context: dict, build_directory: str = None ) -> nn.Module: @@ -418,17 +382,6 @@ def _process_input_tensor(input, device, backend="cuda", precision=torch.float32 # cast to the desired percision dtype for activations input_tensor = input.to(dtype=precision) - - # @Nathan what is going on here? - # NOTE: to come back to this - # # Preserve integer dtypes for labels/targets (e.g., classification losses) - # if tensor.dtype in [torch.int32, torch.int64, torch.long]: - # return tensor.to(device=device) - - # Apply backend-specific dtype casting for float tensors - # NOTE: for tilelang is it all fp16? - # if backend.lower() == "tilelang": - # return tensor.to(device=device, dtype=torch.float16) # Default for all other backends and float types return input_tensor.to(device=device) @@ -446,7 +399,7 @@ def eval_kernel_against_ref( device: Union[torch.device, int] = ( torch.cuda.current_device() if torch.cuda.is_available() else None ), # have to run on GPU - backend: str = "cuda", # can be 'cuda', 'triton', or 'cute' + backend: str = "cuda", # can be 'cuda', 'triton', 'tilelang', or 'cute' precision: torch.dtype = torch.float32, ) -> KernelExecResult: """ @@ -455,14 +408,14 @@ def eval_kernel_against_ref( num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass num_perf_trials: run the evalutation many times to take the average device: GPU (cuda) device to run the evalutation on - backend: str, one of 'cuda', 'triton', or 'cute' + backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute' + precision: torch.dtype for computation (note: tilelang only supports fp16) """ # TODO: check device is busy assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval" - # SET DEFAULT DTYPE TO FLOAT16 ONLY FOR TILELANG - # if backend.lower() == "tilelang": - # torch.set_default_dtype(torch.float16) + if backend.lower() == "tilelang": + assert precision == torch.float16 or precision == torch.bfloat16, "TileLang only supports fp16 or bfloat16" torch.set_printoptions( precision=4, # Decimal places @@ -475,7 +428,8 @@ def eval_kernel_against_ref( torch.cuda.set_device(device) # Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES - uses_tempfile = backend.lower() in ["triton", "cute"] # removed "tilelang" + # TileLang, Triton, and CuTe all use tempfile for proper module loading + uses_tempfile = backend.lower() in ["triton", "tilelang", "cute"] metadata = {} # for storing result metadata metadata["hardware"] = torch.cuda.get_device_name(device=device) @@ -527,11 +481,9 @@ def eval_kernel_against_ref( # add hash for later to distinguish between multi-turn kernels backend_lower = backend.lower() - # if backend_lower == "tilelang": - # # Use linecache approach for TileLang - # ModelNew = load_tilelang_model(custom_model_src, context, build_dir) - if backend_lower in ["triton", "cute"]: - # Use tempfile approach for triton and cute + if backend_lower in ["triton", "tilelang", "cute"]: + # Use tempfile approach for triton, tilelang, and cute + # These DSLs require proper module import for JIT decorators to work ModelNew, tempfile = load_custom_model_with_tempfile( custom_model_src, entry_point="ModelNew" ) @@ -567,25 +519,6 @@ def eval_kernel_against_ref( set_seed(seed_num) # set seed for reproducible weights custom_model = ModelNew(*init_inputs) assert hasattr(custom_model, "forward") - # Move models to GPU with float16 dtype (only for TileLang) - # if backend.lower() == "tilelang": - # try: - # original_model = original_model.to(device=device, dtype=torch.float16) - # except Exception as e: - # # TileLang JIT kernels may not support .to(), already on GPU - # if verbose: - # print(f"[Info] Could not call .to() on original model (TileLang), using as-is: {e}") - # print("[Traceback]:") - # traceback.print_exc() - # try: - # custom_model = custom_model.to(device=device, dtype=torch.float16) - # except Exception as e: - # # TileLang JIT kernels may not support .to(), already on GPU - # if verbose: - # print(f"[Info] Could not call .to() on custom model (TileLang), using as-is: {e}") - # print("[Traceback]:") - # traceback.print_exc() - # else: original_model = original_model.to(device=device, dtype=precision) custom_model = custom_model.to(device=device, dtype=precision) torch.cuda.synchronize(device=device) @@ -641,18 +574,6 @@ def eval_kernel_against_ref( inputs = get_inputs() # Convert inputs for performance measurement inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs] - - # if backend.lower() == "tilelang": - # try: - # model_new = custom_model.to(device=device, dtype=torch.float16) - # except Exception as e: - # # TileLang JIT kernels may not support .to(), already on GPU - # if verbose: - # print(f"[Info] Line 616 - Could not call .to() on custom model for perf measurement (TileLang): {e}") - # print("[Traceback] From performance measurement - line 616:") - # traceback.print_exc() - # model_new = custom_model - # else: model_new = custom_model.to(device=device, dtype=precision) torch.cuda.synchronize(device=device) @@ -797,10 +718,6 @@ def run_and_check_correctness( trial_seed = correctness_trial_seeds[trial] if verbose: print(f"[Eval] Generating Random Input with seed {trial_seed}") - - # @Nathan, let's find a cleaner way to handle tilelang dtype etc - # if backend.lower() == "tilelang": - # torch.set_default_dtype(torch.float16) set_seed(trial_seed) inputs = get_inputs_fn() @@ -808,31 +725,11 @@ def run_and_check_correctness( inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs] set_seed(trial_seed) - # if backend.lower() == "tilelang": - # try: - # model = original_model_instance.to(device=device, dtype=torch.float16) - # except Exception as e: - # # TileLang JIT kernels may not support .to(), already on GPU - # if verbose: - # print(f"[Info] Line 771 - Could not call .to() on original model (TileLang): {e}") - # print("[Traceback] From run_and_check_correctness - line 771:") - # traceback.print_exc() - # model = original_model_instance - # else: + model = original_model_instance.to(device=device, dtype=precision) set_seed(trial_seed) - # if backend.lower() == "tilelang": - # try: - # model_new = new_model_instance.to(device=device, dtype=torch.float16) - # except Exception as e: - # # TileLang JIT kernels may not support .to(), already on GPU - # if verbose: - # print(f"[Info] Line 777 - Could not call .to() on custom model (TileLang): {e}") - # print("[Traceback] From run_and_check_correctness - line 777:") - # traceback.print_exc() - # model_new = new_model_instance - # else: + model_new = new_model_instance.to(device=device, dtype=precision) output = model(*inputs) diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py index 39d16243..8a520d10 100644 --- a/src/prompt_constructor_multilang.py +++ b/src/prompt_constructor_multilang.py @@ -265,118 +265,116 @@ def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata): ################################################################################ -# TileLang Backend - COMMENTED OUT (not working currently) +# TileLang Backend ################################################################################ -# TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n -# You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -# """ -# -# TILELANG_PROBLEM_INSTRUCTION = """ -# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -# """ -# -# TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -# """ -# -# TILELANG_PROBLEM_INSTRUCTION_CLEANED = """ -# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -# """ -# -# -# def prompt_generate_custom_tilelang( -# arc_src: str, example_arch_src: str, example_new_arch_src: str -# ) -> str: -# prompt = TILELANG_PROBLEM_STATEMENT -# -# if example_arch_src != "" and example_new_arch_src != "": -# prompt += f""" -# Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n -# ``` \n -# {example_arch_src} -# ``` \n -# The example new arch with custom TileLang kernels looks like this: \n -# ``` -# {example_new_arch_src} -# ``` \n -# """ -# -# prompt += f""" -# You are given the following architecture: \n -# ``` -# {arc_src} -# ``` -# """ -# prompt += TILELANG_PROBLEM_INSTRUCTION -# return prompt -# -# -# def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str: -# """ -# Using prompt example for TileLang -# Note: You'll need to create a TileLang example file similar to the Triton one -# """ -# arch = ref_arch_src -# -# # TODO: Create model_new_ex_add_tilelang.py example file -# example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") -# example_new_arch_path = os.path.join( -# REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py" -# ) -# -# if not os.path.exists(example_arch_path): -# raise FileNotFoundError( -# f"Example architecture file not found: {example_arch_path}" -# ) -# if not os.path.exists(example_new_arch_path): -# # For now, use a basic template without examples if file doesn't exist -# return prompt_generate_custom_tilelang(arch, "", "") -# -# example_arch = read_file(example_arch_path) -# example_new_arch = read_file(example_new_arch_path) -# -# return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch) -# -# -# def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata): -# prompt = TILELANG_PROBLEM_STATEMENT -# prompt += f""" -# With the following architecture: -# ``` -# {ref_arch_src} -# ``` -# You generated the following solution and it failed to compile: -# ``` -# {custom_kernel} -# ``` -# Here's the metadata of the compilation error: -# ``` -# {metadata} -# ``` -# -# Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. -# """ -# return prompt -# -# -# def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata): -# prompt = TILELANG_PROBLEM_STATEMENT -# prompt += f""" -# With the following architecture: -# ``` -# {ref_arch_src} -# ``` -# You generated the following solution and it failed correctness: -# ``` -# {custom_kernel} -# ``` -# Here's the metadata of the correctness error: -# ``` -# {metadata} -# ``` -# Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. -# """ -# return prompt +TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n + You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +TILELANG_PROBLEM_INSTRUCTION = """ +Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + +TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +TILELANG_PROBLEM_INSTRUCTION_CLEANED = """ +Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + + +def prompt_generate_custom_tilelang( + arc_src: str, example_arch_src: str, example_new_arch_src: str +) -> str: + prompt = TILELANG_PROBLEM_STATEMENT + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom TileLang kernels looks like this: \n + ``` + {example_new_arch_src} + ``` \n + """ + + prompt += f""" + You are given the following architecture: \n + ``` + {arc_src} + ``` + """ + prompt += TILELANG_PROBLEM_INSTRUCTION + return prompt + + +def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str: + """ + Using prompt example for TileLang + """ + arch = ref_arch_src + + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py" + ) + + if not os.path.exists(example_arch_path): + raise FileNotFoundError( + f"Example architecture file not found: {example_arch_path}" + ) + if not os.path.exists(example_new_arch_path): + # For now, use a basic template without examples if file doesn't exist + return prompt_generate_custom_tilelang(arch, "", "") + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + + return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch) + + +def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata): + prompt = TILELANG_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed to compile: + ``` + {custom_kernel} + ``` + Here's the metadata of the compilation error: + ``` + {metadata} + ``` + + Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata): + prompt = TILELANG_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed correctness: + ``` + {custom_kernel} + ``` + Here's the metadata of the correctness error: + ``` + {metadata} + ``` + Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt ################################################################################ @@ -504,7 +502,7 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str: Args: ref_arch_src: Reference architecture source code - backend: One of 'triton', 'cute' (tilelang removed - not working) + backend: One of 'triton', 'tilelang', 'cute' Returns: Prompt string for the specified backend @@ -513,13 +511,13 @@ def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str: if backend_lower == "triton": return prompt_generate_custom_triton_from_prompt_template(ref_arch_src) - # elif backend_lower == "tilelang": - # return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src) + elif backend_lower == "tilelang": + return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src) elif backend_lower == "cute": return prompt_generate_custom_cute_from_prompt_template(ref_arch_src) else: raise ValueError( - f"Unsupported backend: {backend}. Must be one of: 'triton', 'cute'" + f"Unsupported backend: {backend}. Must be one of: 'triton', 'tilelang', 'cute'" ) diff --git a/src/prompts/model_new_ex_add_tilelang.py b/src/prompts/model_new_ex_add_tilelang.py new file mode 100644 index 00000000..b65fb2e9 --- /dev/null +++ b/src/prompts/model_new_ex_add_tilelang.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import tilelang +import tilelang.language as T + + +def build_elementwise_add_kernel(M: int, N: int, block_M: int = 128, block_N: int = 256, threads: int = 128, dtype: str = "float16"): + + @T.prim_func + def elementwise_add_kernel( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + + for local_y, local_x in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + + C[y, x] = A[y, x] + B[y, x] + + return tilelang.compile(elementwise_add_kernel, out_idx=[2], target="cuda") + + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self._kernel_cache = {} + + def _get_kernel(self, M: int, N: int, tl_dtype: str): + key = (M, N, tl_dtype) + if key not in self._kernel_cache: + self._kernel_cache[key] = build_elementwise_add_kernel(M, N, dtype=tl_dtype) + return self._kernel_cache[key] + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + A_c = A.contiguous() + B_c = B.contiguous() + + # Get original shape for reshaping output + original_shape = A_c.shape + + A_c = A_c.view(-1, A_c.size(-1)) + B_c = B_c.view(-1, B_c.size(-1)) + + M, N = A_c.shape + kernel = self._get_kernel(M, N, "float16") + C = kernel(A_c, B_c) + + return C.view(original_shape) \ No newline at end of file From 89727ecc921be416a6cff2ba090dc9fe7b4b0eb1 Mon Sep 17 00:00:00 2001 From: Nathan Paek Date: Mon, 3 Nov 2025 22:26:37 -0800 Subject: [PATCH 3/5] update requirements for tilelang --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index c912156c..f9a29461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,8 @@ modal # DSLs nvidia-cutlass-dsl +tilelang +apache-tvm # helper tqdm From d9709cd7b4b36b8b0b9233041e31dcece6ab633e Mon Sep 17 00:00:00 2001 From: nathanjp Date: Tue, 4 Nov 2025 22:16:34 -0800 Subject: [PATCH 4/5] add precision to other files --- scripts/eval_from_generations.py | 13 ++++++++++++- scripts/generate_samples.py | 6 ++++-- scripts/run_and_check.py | 7 +++++-- src/eval.py | 1 + 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 787aca2b..c7f6fe61 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -145,6 +145,10 @@ def __init__(self): # Backend to use for kernel implementation (cuda or triton) self.backend = "cuda" + + # Precision for computation: "fp32", "fp16", "bf16" + self.precision = "fp32" + # Number of samples per problem to evaluate for pass@k analysis self.num_samples_per_problem = 1 # Default to 1 sample per problem @@ -188,11 +192,13 @@ def evaluate_single_sample_modal( num_perf_trials: int = 100, measure_performance: bool = True, verbose: bool = False, + backend: str = "cuda", + precision: str = "fp32", ): """ Evaluate a single sample on Modal GPU with automatic retries for GPU attachment failures """ - from src.eval import eval_kernel_against_ref + from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string from src.utils import set_gpu_arch import torch import time @@ -225,6 +231,8 @@ def evaluate_single_sample_modal( num_perf_trials=num_perf_trials, build_dir=None, # Modal doesn't need persistent build dir device=torch.device("cuda:0"), # Modal has one GPU per container + backend=backend, + precision=get_torch_dtype_from_string(precision), ) # Force cleanup and exit to prevent container reuse and memory leaks @@ -321,6 +329,7 @@ def evaluate_single_sample( build_dir=build_dir, device=device, backend=configs.backend, + precision=eval.get_torch_dtype_from_string(configs.precision), ) return eval_result except Exception as e: @@ -491,6 +500,8 @@ def batch_eval_modal( num_perf_trials=config.num_perf_trials, measure_performance=config.measure_performance, verbose=config.verbose, + backend=config.backend, + precision=config.precision, ) futures.append(future) diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 5ee217cf..b3e562f5 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -73,6 +73,8 @@ def __init__(self): self.log_prompt = False self.backend = "cuda" + + self.precision = "fp32" def greedy(self): # For greedy decoding, epsecially baseline eval @@ -124,11 +126,11 @@ def generate_sample_single( custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template( ref_arch_src ) - elif config.backend in ["triton", "cute"]: # removed "tilelang" + elif config.backend in ["triton", "cute", "tilelang"]: custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend) else: raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." + f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'." ) if config.log_prompt: prompt_path = os.path.join( diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 79c00a7e..a863afcd 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -9,7 +9,6 @@ from src import eval as kernel_eval from src import utils as kernel_utils from scripts.generate_baseline_time import measure_program_time - from src.utils import read_file """ @@ -68,6 +67,8 @@ def __init__(self): # Replace with your NVIDIA GPU architecture, e.g. ["Hopper"] self.gpu_arch = ["Ada"] + self.precision = "fp32" + self.backend = "cuda" def __repr__(self): return f"ScriptConfig({self.to_dict()})" @@ -97,7 +98,9 @@ def evaluate_single_sample_src(ref_arch_src: str, kernel_src: str, configs: dict num_correct_trials=num_correct_trials, num_perf_trials=num_perf_trials, build_dir=build_dir, - device=device + device=device, + backend=configs["backend"], + precision=kernel_eval.get_torch_dtype_from_string(configs["precision"]) ) return eval_result except Exception as e: diff --git a/src/eval.py b/src/eval.py index 36b92a89..4a072c89 100644 --- a/src/eval.py +++ b/src/eval.py @@ -574,6 +574,7 @@ def eval_kernel_against_ref( inputs = get_inputs() # Convert inputs for performance measurement inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs] + model_new = custom_model.to(device=device, dtype=precision) torch.cuda.synchronize(device=device) From ca193c2b30cc2bd3d6b5320ae4b9a1e3db74b83c Mon Sep 17 00:00:00 2001 From: Simon Guo Date: Tue, 4 Nov 2025 23:48:34 -0800 Subject: [PATCH 5/5] tested and updated readmeg --- README.md | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9c3232c8..b3db083a 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,19 @@ # KernelBench: Can LLMs Write Efficient GPU Kernels? [ICML '25] -[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) | +A benchmark for evaluating LLMs' ability to generate efficient GPU kernels + +[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) + + ## Versions -The huggingface dataset is updated to v0.1. -- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - Latest version (also main branch) +The latest stable version will be on `main` branch. We continue to update and improve the repo. +- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - See [blog](https://scalingintelligence.stanford.edu/blogs/kernelbenchv01/) - [v0](https://github.com/ScalingIntelligence/KernelBench/tree/v0) - Original Release -A benchmark for evaluating LLMs' ability to generate efficient GPU kernels - +The Huggingface [dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) is updated to v0.1. - +This repo provides core functionality for KernelBench and an easy-to-use set of scripts for evaluation. It is not intended to provide complex agentic scaffolds that solve this task; we recommend cloning and modifying this repo for your experiment, or using it as a git submodule. ## 👋 Task Description We structure the problem for LLM to transpile operators described in PyTorch to CUDA kernels, at whatever level of granularity it desires to. @@ -26,7 +29,7 @@ We construct KernelBench to have 4 Levels of categories: - **Level 4 🤗**: Level Hugging Face Optimize whole model architectures from HuggingFace -We are actively extending KernelBench to other DSLs beyond `cuda` as well. +We are actively extending KernelBench to other DSLs beyond `cuda` as well (see below). ## ⚖️ Evaluation #### Methodology @@ -98,7 +101,12 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev # add .verbose_logging for more visbility ``` -We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support (`cuda`, `triton`, `cute`). +**What you might need to modify** +* **`gpu_arch`** - Depend on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware. +* **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`. +* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`. + +Check the config fields for comprehensive set of options. ### Run on all problems