diff --git a/original_files.txt b/KernelBench/changelog/original_files.txt similarity index 100% rename from original_files.txt rename to KernelBench/changelog/original_files.txt diff --git a/KernelBench/test.py b/KernelBench/test.py deleted file mode 100644 index 4c4d5451..00000000 --- a/KernelBench/test.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch - -print("Before patch:", torch.randn) # built-in - -import src.utils # or: import src.utils -print("After patch:", torch.randn) # - -print("Module :", torch.randn.__module__) # should be 'src.utils' \ No newline at end of file diff --git a/README.md b/README.md index 99db491b..8ec45e8a 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ We construct KernelBench to have 4 Levels of categories: - **Level 4 🤗**: Level Hugging Face Optimize whole model architectures from HuggingFace +We are actively extending KernelBench to other DSLs beyond `cuda` as well. + ## ⚖️ Evaluation #### Methodology To evaluate model-generated kernels, we need to check if they: @@ -47,6 +49,7 @@ Some examples to illustrate this metric that filters based on speedups: You can increase speedup threshold `p` to make the task more challenging. + #### Compute Overall Benchmark Performance We provide a script `scripts/greedy_analysis.py` to compute the overall benchmark performance. @@ -95,6 +98,8 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev # add .verbose_logging for more visbility ``` +We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support (`cuda`, `triton`, `cute`). + ### Run on all problems ``` @@ -120,25 +125,10 @@ We provide some reference baseline times a variety of NVIDIA GPUs across generat We have also releaed the test-time framework [Caesar](https://github.com/simonguozirui/caesar) that are used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems. ## 🛣️ Upcoming Roadmap -- [ ] Triton Variant (To be merged) -- [ ] Easy to use CoLab Notebook Example -- [ ] Push button flow on Modal / Cloud Provider -- [ ] Integrate with more frameworks, such as [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) -- [ ] Add backward pass -- [ ] Integrate with toolchains such as NCU -See Issues for the ongoing roadmap and directions. - - +Check out our [roadmap](https://github.com/ScalingIntelligence/KernelBench/issues/74) for what we plan to add as features. We welcome community contirbutions in these directions. ## 🔍 Known Usage -- [NVIDIA](https://developer.nvidia.com/blog/automating-gpu-kernel-generation-with-deepseek-r1-and-inference-time-scaling/) - Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling -- [METR](https://metr.org/blog/2025-02-14-measuring-automated-kernel-engineering/) - Measuring Automated Kernel Engineering -- [Sakana AI](https://sakana.ai/ai-cuda-engineer/) - AI Cuda Engineer -- [Project Popcorn](https://www.youtube.com/watch?v=mdDVkBeFy9A) - Triton Support for KernelBench, Data Scaling + SFT'd Kernel LLM -- [Kevin](https://cognition.ai/blog/kevin-32b) - Kevin-32B: Multi-Turn RL for Writing CUDA Kernels -- [Simple Test-Time Search](https://scalingintelligence.stanford.edu/blogs/fastkernels/) - by @anneouyang - -If you are using KernelBench, we love to hear more about it! +Since release, we have gotten a lot of interest from researchers, research labs, and companies that use KernelBench to explore this direction. We have documented [known usage](https://docs.google.com/document/d/e/2PACX-1vTjS-UMH1HB5n_PENq2k-3YRfXIXkqKIKeNC2zcWMyLPdl4Jrwvdk4dNDVSsM8ybKrCxZB7GJq1slZF/pub) of KernelBench and related efforts towards automated kernel generations. If you are using KernelBench, we love to hear more about it! ## 🪪 License MIT. Check `LICENSE.md` for more details. diff --git a/requirements.txt b/requirements.txt index 8b6e866f..c912156c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,28 @@ -anthropic +# Frameworks +torch==2.5.0 +# we shall upgrade torch for blackwell when it is stable +transformers +datasets modal -numpy -openai + +# DSLs +nvidia-cutlass-dsl + +# helper +tqdm packaging pydra_config -torch==2.5.0 -tqdm -datasets -transformers -google-generativeai -together pytest ninja -archon-ai + +# Numerics einops -dotenv \ No newline at end of file +dotenv +numpy + +# to deprecate with litellm +google-generativeai +together +openai +anthropic + diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 6c94a71c..787aca2b 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -3,6 +3,8 @@ import os import shutil import time +from dataclasses import dataclass + from collections import defaultdict from dataclasses import dataclass @@ -12,15 +14,19 @@ from datasets import load_dataset from pydra import Config, REQUIRED + +# Import only what we need from src import compile, eval, utils from src.dataset import construct_kernelbench_dataset from src.eval import ( build_compile_cache, + get_error_name, check_metadata_serializable_all_types, eval_kernel_against_ref, KernelExecResult, ) + from src.utils import read_file, set_gpu_arch from tqdm import tqdm @@ -137,6 +143,8 @@ def __init__(self): # number of GPUs to do batch evaluation self.num_gpu_devices = 1 + # Backend to use for kernel implementation (cuda or triton) + self.backend = "cuda" # Number of samples per problem to evaluate for pass@k analysis self.num_samples_per_problem = 1 # Default to 1 sample per problem @@ -312,6 +320,7 @@ def evaluate_single_sample( num_perf_trials=configs.num_perf_trials, build_dir=build_dir, device=device, + backend=configs.backend, ) return eval_result except Exception as e: @@ -322,6 +331,7 @@ def evaluate_single_sample( # NOTE: count this as compilation failure as it is not runnable code metadata = { "cuda_error": f"CUDA Error: {str(e)}", + "cuda_error_name": get_error_name(e), "hardware": torch.cuda.get_device_name(device=device), "device": str(device), } # log this for debugging as this usually signifies illegal memory access @@ -332,6 +342,7 @@ def evaluate_single_sample( else: metadata = { "other_error": f"error: {str(e)}", + "other_error_name": get_error_name(e), "hardware": torch.cuda.get_device_name(device=device), "device": str(device), } # for debugging @@ -387,10 +398,9 @@ def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: dict, dataset, run_di pool.terminate() pool.join() raise - except mp.TimeoutError: + except mp.TimeoutError as e: print( - f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}," - f" Sample ID: {curr_work.sample_id}" + f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}" ) print( @@ -691,7 +701,7 @@ def add_to_eval_results_file( os.makedirs(os.path.dirname(eval_file_path), exist_ok=True) with open(eval_file_path, "w") as f: - json.dump(eval_results, f) + json.dump(eval_results, f, indent=4) def single_eval_example( diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 3fdb14b5..ff71e4bc 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -3,13 +3,21 @@ import os, sys import torch import json +import modal from datasets import load_dataset from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets +from src.prompt_constructor_multilang import get_prompt_for_backend +from src.utils import ( + create_inference_server_from_presets, + extract_first_code, + query_server, + read_file, + set_gpu_arch, +) """ Generate and evaluate a single sample @@ -20,15 +28,15 @@ torch.set_printoptions(precision=4, threshold=10) + class EvalConfig(Config): def __init__(self): - - self.dataset_src = REQUIRED # either huggingface or local + + self.dataset_src = REQUIRED # either huggingface or local # name of dataset name on Hugging Face self.dataset_name = "ScalingIntelligence/KernelBench" - # Problem Specification self.level = REQUIRED # NOTE: this is the logical index (problem id the problem_name)\ @@ -56,6 +64,8 @@ def __init__(self): self.log_generated_kernel = False self.log_eval_result = False + self.backend = "cuda" + def verbose_logging(self): self.log = True self.log_prompt = True @@ -86,24 +96,31 @@ def main(config: EvalConfig): if config.log: os.makedirs(config.logdir, exist_ok=True) - + # Problem Checks num_problems = len(curr_level_dataset) print(f"Number of problems in Level {config.level}: {num_problems}") - print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}") - - assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}" + print( + f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}" + ) + assert ( + config.problem_id <= num_problems + ), f"Problem ID {config.problem_id} out of range for Level {config.level}" # 1. Fetch Problem if config.dataset_src == "huggingface": - curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id) + curr_problem_row = curr_level_dataset.filter( + lambda x: x["problem_id"] == config.problem_id + ) ref_arch_src = curr_problem_row["code"][0] problem_name = curr_problem_row["name"][0] elif config.dataset_src == "local": - problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally + problem_idx_in_dataset = ( + config.problem_id - 1 + ) # due to dataset list being 0-indexed locally ref_arch_path = curr_level_dataset[problem_idx_in_dataset] problem_name = os.path.basename(ref_arch_path) @@ -112,52 +129,90 @@ def main(config: EvalConfig): # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") problem_number = int(problem_name.split("_")[0]) - assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" - - + assert ( + problem_number == config.problem_id + ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + # 2. Generate Sample # Create inference function with config parameters # We provide some presets in utils but you can also pass in your own, see query_server for more details - inference_server = create_inference_server_from_presets(server_type=config.server_type, - model_name=config.model_name, - temperature=config.temperature, - max_tokens=config.max_tokens, - verbose=config.verbose, - time_generation=True) - + inference_server = create_inference_server_from_presets( + server_type=config.server_type, + model_name=config.model_name, + temperature=config.temperature, + max_tokens=config.max_tokens, + verbose=config.verbose, + time_generation=True, + ) + # Use appropriate prompt constructor based on backend + if config.backend == "cuda": + custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + elif config.backend in ["triton", "cute"]: # removed "tilelang" + custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) + else: + raise ValueError( + f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." + ) - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) if config.log_prompt: - with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + with open( + os.path.join( + config.logdir, + f"prompt_level_{config.level}_problem_{config.problem_id}.txt", + ), + "w", + ) as f: + f.write(custom_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) - # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" - + custom_kernel = inference_server(custom_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) + + # check LLM is able to generate custom kernel code + assert ( + custom_kernel is not None + ), f"Custom {config.backend} kernel code generation failed" + # this should be optional if config.log: - with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + with open( + os.path.join( + config.logdir, + f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py", + ), + "w", + ) as f: + f.write(custom_kernel) # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation kernel_exec_result = eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, + custom_kernel, + verbose=config.verbose, + measure_performance=True, + num_correct_trials=5, + num_perf_trials=100, + backend=config.backend, + ) + + print( + f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}" ) - - print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") if config.log: - with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a") as f: + with open( + os.path.join( + config.logdir, + f"eval_result_level_{config.level}_problem_{config.problem_id}.txt", + ), + "a", + ) as f: f.write(f"Problem Name: {problem_name}\n") f.write(str(kernel_exec_result)) if __name__ == "__main__": - main() - + main() \ No newline at end of file diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 743f3b89..e9e0866a 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -16,6 +16,7 @@ #from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor_multilang import get_prompt_for_backend from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -69,6 +70,8 @@ def __init__(self): self.log_generated_kernel = False self.log_eval_result = False + self.backend = "cuda" + def verbose_logging(self): self.log = True self.log_prompt = True @@ -105,7 +108,11 @@ def __repr__(self): "pytest", "ninja", "utils", + # "tilelang", # commented out - not working currently + #"apache-tvm", "python-dotenv", + "nvidia-cutlass-dsl", + ) .add_local_python_source("src") ) @@ -114,15 +121,17 @@ def __repr__(self): class EvalFunc: @modal.method() - def eval_single_sample_modal(self, ref_arch_src, custom_cuda, verbose, gpu_arch): + def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend): # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation from src.eval import eval_kernel_against_ref - from src.utils import set_gpu_arch - set_gpu_arch(gpu_arch) + # Use utility function to set the GPU architecture in the modal environment + from src.utils import set_gpu_arch as modal_set_gpu_arch + modal_set_gpu_arch(gpu_arch) return eval_kernel_against_ref( - ref_arch_src, custom_cuda, verbose=verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100 + ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, + num_correct_trials=5, num_perf_trials=100, backend=backend ) @pydra.main(base=EvalConfig) @@ -182,24 +191,33 @@ def main(config: EvalConfig): - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + # Use appropriate prompt constructor based on backend + if config.backend == "cuda": + custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + elif config.backend in ["triton", "cute"]: # removed "tilelang" + custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) + else: + raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.") + if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: - f.write(custom_cuda_prompt) + f.write(custom_prompt) # Query server with constructed prompt - custom_cuda = inference_server(custom_cuda_prompt) - custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"]) - # check LLM is able to generate custom CUDA code - assert custom_cuda is not None, "Custom CUDA code generation failed" + custom_kernel = inference_server(custom_prompt) + custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"]) + # check LLM is able to generate custom kernel code + assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed" # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: - f.write(custom_cuda) + f.write(custom_kernel) with app.run(): - kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote(ref_arch_src, custom_cuda, config.verbose, gpu_arch_mapping[config.gpu]) + kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote( + ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend + ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index b8650131..5ee217cf 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -11,6 +11,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template +from src.prompt_constructor_multilang import get_prompt_for_backend from src.utils import ( create_inference_server_from_presets, extract_first_code, @@ -71,6 +72,8 @@ def __init__(self): self.log_prompt = False + self.backend = "cuda" + def greedy(self): # For greedy decoding, epsecially baseline eval self.greedy_sample = True @@ -117,7 +120,16 @@ def generate_sample_single( ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" # Construct Prompt - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) + if config.backend == "cuda": + custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template( + ref_arch_src + ) + elif config.backend in ["triton", "cute"]: # removed "tilelang" + custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend) + else: + raise ValueError( + f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." + ) if config.log_prompt: prompt_path = os.path.join( run_dir, diff --git a/src/eval.py b/src/eval.py index 4532154e..e2411f3c 100644 --- a/src/eval.py +++ b/src/eval.py @@ -2,17 +2,24 @@ Helpers for Evaluations """ +import hashlib +import importlib +import json +import linecache +import os, subprocess +import random +import sys +import tempfile +import traceback +from contextlib import redirect_stderr, redirect_stdout +from io import StringIO +from typing import Union + +import numpy as np import requests import torch import torch.nn as nn -import os, subprocess from pydantic import BaseModel -import numpy as np -import random -import json -from contextlib import redirect_stdout, redirect_stderr -from io import StringIO -import sys from . import utils @@ -25,6 +32,11 @@ KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") +def get_error_name(e: Exception) -> str: + + return f"{e.__class__.__module__}.{e.__class__.__name__}" + + def fetch_kernel_from_database( run_name: str, problem_id: int, sample_id: int, server_url: str ): @@ -113,6 +125,74 @@ def load_original_model_and_inputs( return (Model, get_init_inputs_fn, get_inputs_fn) +def load_custom_model_with_tempfile(model_custom_src, entry_point="ModelNew"): + """ + Writes the provided Python code string to a temporary .py file, + dynamically imports the module so we can access the modified model class. + + Returns both a Model class and the temporary file. The temporary file must be + deleted manually be the caller. + + This is a hack that is needed for triton code as compile / exec do not play well + with the @triton.jit decorator. + """ + + # Create a temporary named file with a .py extension + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file: + # Write the code string into the file + tmp_file.write(model_custom_src) + # Capture the path to the file + tempfile_path = tmp_file.name + temp_file = tmp_file + + # Create a module specification pointing to our temp file + spec = importlib.util.spec_from_file_location("temp_module", tempfile_path) + # Create a new module based on that spec + temp_module = importlib.util.module_from_spec(spec) + # Execute the code in the module's namespace + spec.loader.exec_module(temp_module) + + ModelNew = getattr(temp_module, entry_point) + + # Return the object (class, function, etc.) that was defined in the code + return ModelNew, temp_file + + +# def load_tilelang_model( +# model_custom_src: str, +# context: dict, +# build_directory: str | None = None +# ): +# """ +# Load TileLang model using linecache instead of tempfile. +# This registers the source code in memory so inspect.getsource() works, +# which is needed for TileLang's JIT decorator. +# """ +# if build_directory: +# model_custom_src = ( +# "import os\n" +# f"os.environ['TORCH_EXTENSIONS_DIR'] = '{build_directory}'\n" +# + model_custom_src +# ) +# +# # Register source so inspect.getsource works +# fake_fname = ( +# f"/tmp/tilelang_kernel_" +# f"{hashlib.md5(model_custom_src.encode()).hexdigest()}.py" +# ) +# # linecache expects a list with trailing newlines +# linecache.cache[fake_fname] = ( +# len(model_custom_src), +# None, +# model_custom_src.splitlines(True), +# fake_fname, +# ) +# +# code_obj = compile(model_custom_src, fake_fname, "exec") +# exec(code_obj, context) +# return context["ModelNew"] + + def load_custom_model( model_custom_src: str, context: dict, build_directory: str = None ) -> nn.Module: @@ -151,7 +231,11 @@ def _cleanup_cuda_extensions(): shutil.rmtree(torch_extensions_path) -def graceful_eval_cleanup(curr_context: dict, device: torch.device): +def graceful_eval_cleanup( + curr_context: dict, + device: torch.device, + tempfile: tempfile.NamedTemporaryFile = None, +): """ Clean up env, gpu cache, and compiled CUDA extensions after evaluation """ # delete ran-specific function definitions before next eval run @@ -166,9 +250,13 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device): torch.cuda.synchronize( device=device ) # Wait for all CUDA operations to complete + if tempfile: + tempfile.close() + os.remove(tempfile.name) # _cleanup_cuda_extensions() # SIMON NOTE: is this necessary? + def build_compile_cache_legacy( custom_model_src: str, verbose: bool = False, @@ -202,11 +290,12 @@ def build_compile_cache_legacy( if verbose: print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}") except Exception as e: - print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}") + print( + f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}" + ) return False, stdout_buffer.getvalue(), str(e) - - return True, stdout_buffer.getvalue(), None + return True, stdout_buffer.getvalue(), None def build_compile_cache( @@ -242,16 +331,16 @@ def build_compile_cache( if verbose: print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}") except Exception as e: - print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}") + print( + f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}" + ) return False, stdout_buffer.getvalue(), str(e) return True, stdout_buffer.getvalue(), None def build_compile_cache_with_capturing( - custom_model_src: str, - verbose: bool = False, - build_dir: os.PathLike = None + custom_model_src: str, verbose: bool = False, build_dir: os.PathLike = None ) -> tuple[int, str, str]: """ Write a temporary python file to compile the custom model on CPU @@ -273,22 +362,48 @@ def build_compile_cache_with_capturing( f.write(custom_model_src) # Execute the temporary Python file and capture output - process = subprocess.Popen(['python', tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process = subprocess.Popen( + ["python", tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout, stderr = process.communicate() returncode = process.returncode # Clean up temporary file os.remove(tmp) - if verbose: print("[CPU Precompile] return code: ", returncode) - print("[CPU Precompile] stdout: \n", stdout.decode('utf-8')) - print("[CPU Precompile] stderr: \n", stderr.decode('utf-8')) + print("[CPU Precompile] stdout: \n", stdout.decode("utf-8")) + print("[CPU Precompile] stderr: \n", stderr.decode("utf-8")) - return returncode, stdout.decode('utf-8'), stderr.decode('utf-8') + return returncode, stdout.decode("utf-8"), stderr.decode("utf-8") +def _process_input_tensor(tensor, device, backend): + """ + Helper function to move tensors to the correct device and apply backend-specific dtype casting. + + Args: + tensor: Input tensor or non-tensor value + device: Target CUDA device + backend: Backend type (e.g., 'cuda', 'triton', 'cute') + + Returns: + Processed tensor on correct device with correct dtype, or original value if not a tensor + """ + if not isinstance(tensor, torch.Tensor): + return tensor + + # Preserve integer dtypes for labels/targets (e.g., classification losses) + if tensor.dtype in [torch.int32, torch.int64, torch.long]: + return tensor.to(device=device) + + # Apply backend-specific dtype casting for float tensors + # if backend.lower() == "tilelang": + # return tensor.to(device=device, dtype=torch.float16) + + # Default for all other backends and float types + return tensor.to(device=device) def eval_kernel_against_ref( @@ -300,7 +415,10 @@ def eval_kernel_against_ref( verbose: bool = False, measure_performance: bool = False, build_dir: os.PathLike = None, - device: torch.device = torch.cuda.current_device() if torch.cuda.is_available() else None, # have to run on GPU + device: Union[torch.device, int] = ( + torch.cuda.current_device() if torch.cuda.is_available() else None + ), # have to run on GPU + backend: str = "cuda", # can be 'cuda', 'triton', or 'cute' ) -> KernelExecResult: """ Evaluate the custom kernel against the original model @@ -308,9 +426,15 @@ def eval_kernel_against_ref( num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass num_perf_trials: run the evalutation many times to take the average device: GPU (cuda) device to run the evalutation on + backend: str, one of 'cuda', 'triton', or 'cute' """ # TODO: check device is busy assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval" + + # SET DEFAULT DTYPE TO FLOAT16 ONLY FOR TILELANG + # if backend.lower() == "tilelang": + # torch.set_default_dtype(torch.float16) + torch.set_printoptions( precision=4, # Decimal places threshold=10, # Total number of elements before truncating @@ -320,7 +444,28 @@ def eval_kernel_against_ref( # set CUDA device torch.cuda.set_device(device) + + # Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES + uses_tempfile = backend.lower() in ["triton", "cute"] # removed "tilelang" + + metadata = {} # for storing result metadata + metadata["hardware"] = torch.cuda.get_device_name(device=device) + metadata["device"] = str(device) # for debugging + if uses_tempfile: + # need to set env var for triton/cute code to guarantee no wrong device shenanigans + if isinstance(device, int): + device_num = device + elif isinstance(device, torch.device): + assert ( + device.type == "cuda" + ), "CUDA is not availible on device, cannot run Eval" + device_num = device.index + else: + raise ValueError( + f"device must be an int or torch.device, got {type(device)}" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_num) context = {} if verbose: @@ -332,28 +477,38 @@ def eval_kernel_against_ref( ) set_seed(seed_num) # set seed for reproducible input init_inputs = get_init_inputs() - init_inputs = [ - x.cuda(device=device) if isinstance(x, torch.Tensor) else x for x in init_inputs - ] - + + # Convert inputs to appropriate dtypes for GPU computation + init_inputs = [_process_input_tensor(x, device, backend) for x in init_inputs] + with torch.no_grad(): set_seed(seed_num) # set seed for reproducible weights original_model = Model(*init_inputs) assert hasattr(original_model, "forward") if verbose: print("[Eval] Original Model Loaded") + if verbose: print("[Eval] Loading and Compiling New Model with Custom CUDA Kernel") - metadata = {} # for storing result metadata - metadata["hardware"] = torch.cuda.get_device_name(device=device) - metadata["device"] = str(device) # for debugging - # this is where compilation happens try: os.environ["TORCH_USE_CUDA_DSA"] = "1" # compile with device side assertion + tempfile = None # add hash for later to distinguish between multi-turn kernels - ModelNew = load_custom_model(custom_model_src, context, build_dir) + + backend_lower = backend.lower() + # if backend_lower == "tilelang": + # # Use linecache approach for TileLang + # ModelNew = load_tilelang_model(custom_model_src, context, build_dir) + if backend_lower in ["triton", "cute"]: + # Use tempfile approach for triton and cute + ModelNew, tempfile = load_custom_model_with_tempfile( + custom_model_src, entry_point="ModelNew" + ) + else: + # Default CUDA backend + ModelNew = load_custom_model(custom_model_src, context, build_dir) torch.cuda.synchronize(device=device) # not sure if this is too much except Exception as e: print( @@ -367,11 +522,12 @@ def eval_kernel_against_ref( print( f"[Eval] Lock file error during compilation, Please retry. Error: {e}" ) - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return None else: + metadata["compilation_error_name"] = get_error_name(e) metadata["compilation_error"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return KernelExecResult( compiled=False, metadata=metadata ) # skip further steps @@ -382,6 +538,27 @@ def eval_kernel_against_ref( set_seed(seed_num) # set seed for reproducible weights custom_model = ModelNew(*init_inputs) assert hasattr(custom_model, "forward") + # Move models to GPU with float16 dtype (only for TileLang) + # if backend.lower() == "tilelang": + # try: + # original_model = original_model.to(device=device, dtype=torch.float16) + # except Exception as e: + # # TileLang JIT kernels may not support .to(), already on GPU + # if verbose: + # print(f"[Info] Could not call .to() on original model (TileLang), using as-is: {e}") + # print("[Traceback]:") + # traceback.print_exc() + # try: + # custom_model = custom_model.to(device=device, dtype=torch.float16) + # except Exception as e: + # # TileLang JIT kernels may not support .to(), already on GPU + # if verbose: + # print(f"[Info] Could not call .to() on custom model (TileLang), using as-is: {e}") + # print("[Traceback]:") + # traceback.print_exc() + # else: + original_model = original_model.to(device=device) + custom_model = custom_model.to(device=device) torch.cuda.synchronize(device=device) if verbose: print("[Eval] New Model with Custom CUDA Kernel Loaded") @@ -390,8 +567,9 @@ def eval_kernel_against_ref( f"Failed to load custom CUDA kernel; Compiled but not able to run, count as runtime error. \nError: {e}" ) # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) metadata["runtime_error"] = e + metadata["runtime_error_name"] = get_error_name(e) return KernelExecResult( compiled=True, correctness=False, metadata=metadata ) # skip further steps @@ -411,10 +589,12 @@ def eval_kernel_against_ref( verbose=verbose, seed=seed_num, device=device, + backend=backend, ) except Exception as e: # TODO: add metadata for runtime error e.g. error in launching kernel, illegal memory access, ... metadata["runtime_error"] = e + metadata["runtime_error_name"] = get_error_name(e) kernel_exec_result = KernelExecResult( compiled=True, correctness=False, metadata=metadata ) @@ -429,11 +609,21 @@ def eval_kernel_against_ref( torch.cuda.synchronize(device=device) set_seed(seed_num) inputs = get_inputs() - inputs = [ - x.cuda(device=device) if isinstance(x, torch.Tensor) else x - for x in inputs - ] - model_new = custom_model.cuda(device=device) + # Convert inputs for performance measurement + inputs = [_process_input_tensor(x, device, backend) for x in inputs] + + # if backend.lower() == "tilelang": + # try: + # model_new = custom_model.to(device=device, dtype=torch.float16) + # except Exception as e: + # # TileLang JIT kernels may not support .to(), already on GPU + # if verbose: + # print(f"[Info] Line 616 - Could not call .to() on custom model for perf measurement (TileLang): {e}") + # print("[Traceback] From performance measurement - line 616:") + # traceback.print_exc() + # model_new = custom_model + # else: + model_new = custom_model.to(device=device) torch.cuda.synchronize(device=device) elapsed_times = time_execution_with_cuda_event( @@ -454,7 +644,7 @@ def eval_kernel_against_ref( print(f"[Eval] Error in Measuring Performance: {e}") kernel_exec_result.metadata["error_during_performance"] = e - graceful_eval_cleanup(context, device) + graceful_eval_cleanup(context, device, tempfile) return kernel_exec_result @@ -550,6 +740,7 @@ def run_and_check_correctness( verbose=False, seed=42, device=None, + backend="cuda", ) -> KernelExecResult: """ run the model and check correctness, @@ -557,6 +748,7 @@ def run_and_check_correctness( this is all on GPU, requiring cuda device and transfer .cuda() num_correct_trials: run the evalutation multiple times with (ideally) different random inputs to ensure correctness + backend: backend type for handling dtype conversions """ pass_count = 0 @@ -573,19 +765,42 @@ def run_and_check_correctness( trial_seed = correctness_trial_seeds[trial] if verbose: print(f"[Eval] Generating Random Input with seed {trial_seed}") + + # if backend.lower() == "tilelang": + # torch.set_default_dtype(torch.float16) set_seed(trial_seed) inputs = get_inputs_fn() - inputs = [ - x.cuda(device=device) if isinstance(x, torch.Tensor) else x - for x in inputs - ] + # Convert inputs to appropriate dtypes for GPU computation + inputs = [_process_input_tensor(x, device, backend) for x in inputs] set_seed(trial_seed) - model = original_model_instance.cuda(device=device) + # if backend.lower() == "tilelang": + # try: + # model = original_model_instance.to(device=device, dtype=torch.float16) + # except Exception as e: + # # TileLang JIT kernels may not support .to(), already on GPU + # if verbose: + # print(f"[Info] Line 771 - Could not call .to() on original model (TileLang): {e}") + # print("[Traceback] From run_and_check_correctness - line 771:") + # traceback.print_exc() + # model = original_model_instance + # else: + model = original_model_instance.to(device=device) set_seed(trial_seed) - model_new = new_model_instance.cuda(device=device) + # if backend.lower() == "tilelang": + # try: + # model_new = new_model_instance.to(device=device, dtype=torch.float16) + # except Exception as e: + # # TileLang JIT kernels may not support .to(), already on GPU + # if verbose: + # print(f"[Info] Line 777 - Could not call .to() on custom model (TileLang): {e}") + # print("[Traceback] From run_and_check_correctness - line 777:") + # traceback.print_exc() + # model_new = new_model_instance + # else: + model_new = new_model_instance.to(device=device) output = model(*inputs) torch.cuda.synchronize(device=device) @@ -600,6 +815,7 @@ def run_and_check_correctness( f"Output shape mismatch: Expected {output.shape}, got {output_new.shape}", metadata, ) + metadata["correctness_issue_name"] = "correctness_issue" if verbose: print( f"[FAIL] trial {trial}: Output shape mismatch: Expected {output.shape}, got {output_new.shape}" @@ -627,10 +843,16 @@ def run_and_check_correctness( except Exception as e: print("[Error] Exception happens during correctness check") print(f"Error in launching kernel for ModelNew: {e}") + print("\n[Full Traceback]:") + traceback.print_exc() + print("\n") metadata = register_and_format_exception( "runtime_error", e, metadata, truncate=True ) + metadata["runtime_error_name"] = get_error_name(e) + # Also store the full traceback in metadata for debugging + metadata["runtime_error_traceback"] = traceback.format_exc() return KernelExecResult( compiled=True, correctness=False, metadata=metadata ) @@ -678,11 +900,13 @@ def check_metadata_serializable(metadata: dict): return metadata + def check_metadata_serializable_all_types(metadata: dict): """ Ensure metadata is JSON serializable, if not, convert non-serializable values to strings recursively """ + def convert_to_serializable(obj): if isinstance(obj, dict): return {k: convert_to_serializable(v) for k, v in obj.items()} @@ -759,4 +983,4 @@ def get_timing_stats(elapsed_times: list[float], device: torch.device = None) -> # if __name__ == "__main__": # fetch_kernel_from_database("kernelbench_prompt_v2_level_2", 1, 1, "http://localhost:9091") # print(fetch_ref_arch_from_level_problem_id("2", 1, with_name=True)) -# fetch_baseline_time("level1", 0, ["1_Square_matrix_multiplication_.py"], "tests/baseline_time_matx3.json") +# fetch_baseline_time("level1", 0, ["1_Square_matrix_multiplication_.py"], "tests/baseline_time_matx3.json") \ No newline at end of file diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py new file mode 100644 index 00000000..39d16243 --- /dev/null +++ b/src/prompt_constructor_multilang.py @@ -0,0 +1,553 @@ +import os +from .utils import read_file + +""" +Multi-Language Prompt Constructor + +Supports: Triton, CuTe (TileLang currently disabled/commented out) + +Design principles: +- To evaluate base model performance on KernelBench, we use the simplest prompt possible to guide model output to generated desired output format. +- However, we do not do extensive prompt engineering or few-shot examples in the LLM to steer behaviour. +""" + +REPO_TOP_PATH = os.path.abspath( + os.path.join( + os.path.dirname(__file__), + "..", + ) +) +KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") + + +def get_arch_definition_from_file(arch_path): + arch_src = read_file(arch_path) + return get_arch_definition(arch_src) + + +def get_arch_definition(arch_src): + """ + Construct torch definition from original torch nn.Module definition + """ + prompt = f"Here is a pytorch defintion of a neural network architecture in the file model.py: ```{arch_src}```\n" + return prompt + + +################################################################################ +# Triton Backend +################################################################################ + +TRITON_PROBLEM_STATEMENT = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n + You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +TRITON_PROBLEM_INSTRUCTION = """ +Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + +TRITON_PROBLEM_STATEMENT_CLEANED = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +TRITON_PROBLEM_INSTRUCTION_CLEANED = """ +Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + + +def prompt_generate_custom_triton( + arc_src: str, example_arch_src: str, example_new_arch_src: str +) -> str: + prompt = TRITON_PROBLEM_STATEMENT + + assert ( + "@triton.jit" in example_new_arch_src + ), "Example new arch must contain Triton kernel" + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom Triton kernels looks like this: \n + ``` + {example_new_arch_src} + ``` \n + """ + + prompt += f""" + You are given the following architecture: \n + ``` + {arc_src} + ``` + """ + prompt += TRITON_PROBLEM_INSTRUCTION + return prompt + + +def prompt_generate_custom_triton_fewshot_and_template( + ref_arch_src: str, shots: list +) -> str: + raise NotImplementedError("This function has not been implemented yet") + + +def prompt_generate_ex_with_CoT_template_triton(ref_arch_src: str, cot_example: str) -> str: + raise NotImplementedError("This function has not been implemented yet") + + +def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: + """ + Using prompt example (an element-wise addition) for prompt templates + The most basic form of example just to show LLM the task and the expected output format + """ + arch = ref_arch_src + + # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom Triton kernels) + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" + ) + + if not os.path.exists(example_arch_path): + raise FileNotFoundError( + f"Example architecture file not found: {example_arch_path}" + ) + if not os.path.exists(example_new_arch_path): + raise FileNotFoundError( + f"Example new architecture file not found: {example_new_arch_path}" + ) + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + + return prompt_generate_custom_triton(arch, example_arch, example_new_arch) + + +def prompt_generate_prompt_with_hardware_info_from_template_triton( + ref_arch_src: str, gpu_name: str +) -> str: + """ + Similar to prompt_generate_custom_triton_from_prompt_template, + but with hardware information for the given GPU + """ + arch = ref_arch_src + + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" + ) + gpu_spec_file_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py" + ) + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + gpu_spec_info = read_file(gpu_spec_file_path) + + return prompt_generate_prompt_with_hardware_info_triton( + ref_arch_src=arch, + gpu_name=gpu_name, + example_arch_src=example_arch, + example_new_arch_src=example_new_arch, + gpu_spec_info_src=gpu_spec_info, + ) + + +def prompt_generate_prompt_with_hardware_info_triton( + ref_arch_src: str, + gpu_name: str, + example_arch_src: str, + example_new_arch_src: str, + gpu_spec_info_src: str, +) -> str: + """ + Generate a prompt with hardware information for the given GPU + gpu_spec_info_src: str of the gpu spec src file + """ + local_dict = {} + exec(gpu_spec_info_src, {}, local_dict) + + GPU_SPEC_INFO = local_dict.get("GPU_SPEC_INFO") + GPU_DEFINITIONS = local_dict.get("GPU_DEFINITIONS") + GPU_BEST_PRACTICES = local_dict.get("GPU_BEST_PRACTICES") + + if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: + raise ValueError( + "GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src" + ) + + assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" + + prompt = TRITON_PROBLEM_STATEMENT + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom Triton kernels looks like this: + ``` + {example_new_arch_src} + ``` \n + """ + + curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] + gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") + prompt += f""" + Here is some information about the underlying hardware that you should keep in mind. \n\n +The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" + + for key, value in curr_gpu_spec_info.items(): + if key == "GPU Architecture": + continue + prompt += f"""- We have {value} of {key}.\n""" + + prompt += f"""\n\n +Here are some concepts about the GPU architecture that could be helpful: \n\n""" + for key, value in GPU_DEFINITIONS.items(): + prompt += f"""- {key}: {value}\n""" + + prompt += f"""\n\n +Here are some best practices for writing Triton kernels on GPU: \n\n""" + for best_practice in GPU_BEST_PRACTICES: + prompt += f"""- {best_practice}\n""" + + prompt += f""" + You are given the following architecture: \n + ``` + {ref_arch_src} + ``` + """ + + prompt += TRITON_PROBLEM_INSTRUCTION + return prompt + + +def prompt_fix_compile_triton(ref_arch_src, custom_kernel, metadata): + prompt = TRITON_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed to compile: + ``` + {custom_kernel} + ``` + Here's the metadata of the compilation error: + ``` + {metadata} + ``` + + Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata): + prompt = TRITON_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed correctness: + ``` + {custom_kernel} + ``` + Here's the metadata of the correctness error: + ``` + {metadata} + ``` + Please consider how your custom Triton kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +################################################################################ +# TileLang Backend - COMMENTED OUT (not working currently) +################################################################################ + +# TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n +# You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +# """ +# +# TILELANG_PROBLEM_INSTRUCTION = """ +# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +# """ +# +# TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +# """ +# +# TILELANG_PROBLEM_INSTRUCTION_CLEANED = """ +# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +# """ +# +# +# def prompt_generate_custom_tilelang( +# arc_src: str, example_arch_src: str, example_new_arch_src: str +# ) -> str: +# prompt = TILELANG_PROBLEM_STATEMENT +# +# if example_arch_src != "" and example_new_arch_src != "": +# prompt += f""" +# Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n +# ``` \n +# {example_arch_src} +# ``` \n +# The example new arch with custom TileLang kernels looks like this: \n +# ``` +# {example_new_arch_src} +# ``` \n +# """ +# +# prompt += f""" +# You are given the following architecture: \n +# ``` +# {arc_src} +# ``` +# """ +# prompt += TILELANG_PROBLEM_INSTRUCTION +# return prompt +# +# +# def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str: +# """ +# Using prompt example for TileLang +# Note: You'll need to create a TileLang example file similar to the Triton one +# """ +# arch = ref_arch_src +# +# # TODO: Create model_new_ex_add_tilelang.py example file +# example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") +# example_new_arch_path = os.path.join( +# REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py" +# ) +# +# if not os.path.exists(example_arch_path): +# raise FileNotFoundError( +# f"Example architecture file not found: {example_arch_path}" +# ) +# if not os.path.exists(example_new_arch_path): +# # For now, use a basic template without examples if file doesn't exist +# return prompt_generate_custom_tilelang(arch, "", "") +# +# example_arch = read_file(example_arch_path) +# example_new_arch = read_file(example_new_arch_path) +# +# return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch) +# +# +# def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata): +# prompt = TILELANG_PROBLEM_STATEMENT +# prompt += f""" +# With the following architecture: +# ``` +# {ref_arch_src} +# ``` +# You generated the following solution and it failed to compile: +# ``` +# {custom_kernel} +# ``` +# Here's the metadata of the compilation error: +# ``` +# {metadata} +# ``` +# +# Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. +# """ +# return prompt +# +# +# def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata): +# prompt = TILELANG_PROBLEM_STATEMENT +# prompt += f""" +# With the following architecture: +# ``` +# {ref_arch_src} +# ``` +# You generated the following solution and it failed correctness: +# ``` +# {custom_kernel} +# ``` +# Here's the metadata of the correctness error: +# ``` +# {metadata} +# ``` +# Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. +# """ +# return prompt + + +################################################################################ +# CuTe Backend +################################################################################ + +CUTE_PROBLEM_STATEMENT = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups. \n + You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +CUTE_PROBLEM_INSTRUCTION = """ +Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + +CUTE_PROBLEM_STATEMENT_CLEANED = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n +""" + +CUTE_PROBLEM_INSTRUCTION_CLEANED = """ +Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n +""" + + +def prompt_generate_custom_cute( + arc_src: str, example_arch_src: str, example_new_arch_src: str +) -> str: + prompt = CUTE_PROBLEM_STATEMENT + + if example_arch_src != "" and example_new_arch_src != "": + prompt += f""" + Here's an example to show you the syntax of inline embedding custom CuTe (CUTLASS) kernels in torch: The example given architecture is: \n + ``` \n + {example_arch_src} + ``` \n + The example new arch with custom CuTe kernels looks like this: \n + ``` + {example_new_arch_src} + ``` \n + """ + + prompt += f""" + You are given the following architecture: \n + ``` + {arc_src} + ``` + """ + prompt += CUTE_PROBLEM_INSTRUCTION + return prompt + + +def prompt_generate_custom_cute_from_prompt_template(ref_arch_src: str) -> str: + """ + Using prompt example for CuTe + Note: You'll need to create a CuTe example file + """ + arch = ref_arch_src + + # TODO: Create model_new_ex_add_cute.py example file + example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") + example_new_arch_path = os.path.join( + REPO_TOP_PATH, f"src/prompts/model_new_ex_add_cute.py" + ) + + if not os.path.exists(example_arch_path): + raise FileNotFoundError( + f"Example architecture file not found: {example_arch_path}" + ) + if not os.path.exists(example_new_arch_path): + # For now, use a basic template without examples if file doesn't exist + return prompt_generate_custom_cute(arch, "", "") + + example_arch = read_file(example_arch_path) + example_new_arch = read_file(example_new_arch_path) + + return prompt_generate_custom_cute(arch, example_arch, example_new_arch) + + +def prompt_fix_compile_cute(ref_arch_src, custom_kernel, metadata): + prompt = CUTE_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed to compile: + ``` + {custom_kernel} + ``` + Here's the metadata of the compilation error: + ``` + {metadata} + ``` + + Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata): + prompt = CUTE_PROBLEM_STATEMENT + prompt += f""" + With the following architecture: + ``` + {ref_arch_src} + ``` + You generated the following solution and it failed correctness: + ``` + {custom_kernel} + ``` + Here's the metadata of the correctness error: + ``` + {metadata} + ``` + Please consider how your custom CuTe kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. + """ + return prompt + + +################################################################################ +# Unified API +################################################################################ + +def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str: + """ + Unified API to get prompt for any supported backend + + Args: + ref_arch_src: Reference architecture source code + backend: One of 'triton', 'cute' (tilelang removed - not working) + + Returns: + Prompt string for the specified backend + """ + backend_lower = backend.lower() + + if backend_lower == "triton": + return prompt_generate_custom_triton_from_prompt_template(ref_arch_src) + # elif backend_lower == "tilelang": + # return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src) + elif backend_lower == "cute": + return prompt_generate_custom_cute_from_prompt_template(ref_arch_src) + else: + raise ValueError( + f"Unsupported backend: {backend}. Must be one of: 'triton', 'cute'" + ) + + +################################################################################ +# Main (for testing) +################################################################################ + +def main(): + gpu_name = "L40S" + backend = "triton" # Change this to test different backends + + ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) + assert len(ref_arch_src) > 0, "ref_arch_src is empty" + + prompt = get_prompt_for_backend(ref_arch_src, backend) + print(f"\n{'='*80}\n{backend.upper()} PROMPT:\n{'='*80}\n") + print(prompt) + + # Write prompt to temp file + temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", f"prompt_{backend}_draft.txt") + os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) + with open(temp_file_path, "w") as f: + f.write(prompt) + print(f"\nPrompt written to: {temp_file_path}") + + +if __name__ == "__main__": + main() + + + diff --git a/src/prompts/model_new_ex_add_cute.py b/src/prompts/model_new_ex_add_cute.py new file mode 100644 index 00000000..3cd3d63f --- /dev/null +++ b/src/prompts/model_new_ex_add_cute.py @@ -0,0 +1,57 @@ +import torch +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + +@cute.kernel +def elementwise_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + + thread_idx = bidx * bdim + tidx + + m, n = gA.shape + ni = thread_idx % n + mi = thread_idx // n + + a_val = gA[mi, ni] + b_val = gB[mi, ni] + + gC[mi, ni] = a_val + b_val + +@cute.jit +def elementwise_add_host(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor): + M = mA.shape[0] + N = mA.shape[1] + + threads_per_block = 256 + total_elems = M * N + grid_x = cute.ceil_div(total_elems, threads_per_block) + + elementwise_add_kernel(mA, mB, mC).launch(grid=(grid_x, 1, 1), block=(threads_per_block, 1, 1)) + + +class ModelNew(torch.nn.Module): + def __init__(self): + super().__init__() + self.compiled = {} + + def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + M, N = A.shape + A = A.contiguous().cuda() + B = B.contiguous().cuda() + C = torch.empty((M, N), dtype=A.dtype, device=A.device) + + mA = from_dlpack(A, assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + mB = from_dlpack(B, assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + mC = from_dlpack(C, assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + + key = (A.dtype,) + compiled = self.compiled.get(key) + if compiled is None: + compiled = cute.compile(elementwise_add_host, mA, mB, mC) + self.compiled[key] = compiled + + compiled(mA, mB, mC) + return C diff --git a/src/prompts/model_new_ex_add_triton.py b/src/prompts/model_new_ex_add_triton.py new file mode 100644 index 00000000..43a3f712 --- /dev/null +++ b/src/prompts/model_new_ex_add_triton.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # Pointer to first input + y_ptr, # Pointer to second input + out_ptr, # Pointer to output + n_elements, # Total number of elements in input/output + BLOCK_SIZE: tl.constexpr, +): + # Each program handles a contiguous block of data of size BLOCK_SIZE + block_start = tl.program_id(0) * BLOCK_SIZE + # Create a range of offsets [0..BLOCK_SIZE-1] + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Mask to ensure we don't go out of bounds + mask = offsets < n_elements + # Load input values + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + y = tl.load(y_ptr + offsets, mask=mask, other=0.0) + # Perform the elementwise addition + out = x + y + # Store the result + tl.store(out_ptr + offsets, out, mask=mask) + + +def triton_add(x: torch.Tensor, y: torch.Tensor): + """ + This function wraps the Triton kernel call. It: + 1. Ensures the inputs are contiguous on GPU. + 2. Calculates the grid (blocks) needed. + 3. Launches the Triton kernel. + """ + assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." + x = x.contiguous() + y = y.contiguous() + + # Prepare output tensor + out = torch.empty_like(x) + + # Number of elements in the tensor + n_elements = x.numel() + BLOCK_SIZE = 128 # Tunable parameter for block size + + # Determine the number of blocks needed + grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) + + # Launch the Triton kernel + add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + # Instead of "return a + b", call our Triton-based addition + return triton_add(a, b) \ No newline at end of file