diff --git a/.gitignore b/.gitignore
index abb125d6..bc05f65e 100644
Binary files a/.gitignore and b/.gitignore differ
diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py
index 2e39e3be..9c2ae3c4 100644
--- a/scripts/eval_from_generations.py
+++ b/scripts/eval_from_generations.py
@@ -60,20 +60,36 @@
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git",
-                 "gcc-10",
-                 "g++-10",
-                 "clang"
-                 )
-    .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
-    .add_local_dir(
-        KERNEL_BENCH_PATH,
-        remote_path="/root/KernelBench"
+# ThunderKittens support - use TK image if directory exists locally
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_DIR, "ThunderKittens")
+SRC_PATH = os.path.join(REPO_TOP_DIR, "src")
+
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )
+else:
+    # Standard image
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
     )
-    .add_local_python_source("src")
-)


 class EvalConfig(Config):
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index 2b2d5301..92c42ef8 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -197,12 +197,16 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
         )
+
+    # ThunderKittens uses fp32 by default
+    if backend == "thunderkittens":
+        config.precision = "fp32"

     if backend == "tilelang":
         config.precision = "fp16"  # tilelang only operates with fp16
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 7628e0bf..77016d49 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -95,16 +95,35 @@ def __repr__(self):
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git",
-                 "gcc-10",
-                 "g++-10",
-                 "clang"  # note i skip a step
-                 )
-    .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
-    .add_local_python_source("src")
-)
+# ThunderKittens support - use TK image if directory exists locally
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_DIR, "ThunderKittens")
+
+SRC_PATH = os.path.join(REPO_TOP_DIR, "src")
+
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )
+else:
+    # Standard image
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+    )

 @app.cls(image=image)
 class EvalFunc:
@@ -215,12 +234,16 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
         )
+
+    # ThunderKittens uses fp32 by default
+    if backend == "thunderkittens":
+        config.precision = "fp32"

     #tilelang only supports fp16 or bf16
     if backend == "tilelang":
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index e47c6e87..82f2c6c6 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -239,7 +239,7 @@ def main(config: GenerationConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware

-    supported_backends = {"cuda", "triton", "cute", "tilelang"}
+    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -248,6 +248,8 @@ def main(config: GenerationConfig):
     config.backend = backend
     if backend == "tilelang":
         config.precision = "fp16"
+    if backend == "thunderkittens":
+        config.precision = "fp32"  # ThunderKittens supports fp32 by default

     config.prompt_option = str(config.prompt_option).lower()
     valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py
index 316b96ee..04a61d57 100644
--- a/scripts/run_and_check.py
+++ b/scripts/run_and_check.py
@@ -26,20 +26,44 @@
 REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")
+THUNDERKITTENS_LOCAL_PATH = os.path.join(REPO_TOP_PATH, "ThunderKittens")
+SRC_PATH = os.path.join(REPO_TOP_PATH, "src")

 cuda_version = "12.8.0"
 flavor = "devel"
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"

-image = (
-    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
-    .apt_install("git", "gcc-10", "g++-10", "clang")
-    .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
-    .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
-    .add_local_python_source("src")
-    .add_local_python_source("scripts")
-)
+# ThunderKittens support - use TK image if directory exists locally
+if os.path.isdir(THUNDERKITTENS_LOCAL_PATH):
+    # ThunderKittens image with TK environment and mounting
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
+        .env({
+            "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+            "THUNDERKITTENS_PATH": "/root/ThunderKittens",
+            "TORCH_CUDA_ARCH_LIST": "9.0",
+            "CXX": "g++-10",
+            "CC": "gcc-10",
+        })
+        .add_local_dir(THUNDERKITTENS_LOCAL_PATH, remote_path="/root/ThunderKittens", copy=True)
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_dir(SRC_PATH, remote_path="/root/src")
+        .add_local_python_source("src")
+        .add_local_python_source("scripts")
+    )
+else:
+    # Standard image without ThunderKittens
+    image = (
+        modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+        .apt_install("git", "gcc-10", "g++-10", "clang")
+        .pip_install_from_requirements(os.path.join(REPO_TOP_PATH, "requirements.txt"))
+        .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
+        .add_local_python_source("src")
+        .add_local_python_source("scripts")
+    )

 """
 Run a pair of KernelBench format (problem, solution) to check if solution is correct and compute speedup
diff --git a/src/prompts/model_new_ex_add_thunderkittens.py b/src/prompts/model_new_ex_add_thunderkittens.py
new file mode 100644
index 00000000..af6ff0fb
--- /dev/null
+++ b/src/prompts/model_new_ex_add_thunderkittens.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+import os
+
+# ThunderKittens header-only library path (set via environment variable)
+# Default to /root/ThunderKittens for Modal containers, or use THUNDERKITTENS_PATH env var
+TK_PATH = os.environ.get("THUNDERKITTENS_PATH", os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens"))
+
+# C++ source: function declaration for binding
+elementwise_add_cpp_source = """
+torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);
+"""
+
+# CUDA source: ThunderKittens kernel implementation
+#
+# IMPORTANT ThunderKittens API notes:
+# 1. Define KITTENS_HOPPER before including kittens.cuh for H100/Hopper GPUs
+# 2. Operations like load, store, zero, mma_AB are NOT free functions!
+#    They are static member functions inside the kittens::group template struct.
+# 3. Create an alias like: using warp = kittens::group<1>;
+# 4. Then call: warp::load(...), warp::zero(...), etc.
+#
+elementwise_add_cuda_source = """
+// IMPORTANT: Define KITTENS_HOPPER before including ThunderKittens headers for H100/Hopper GPUs
+// This enables FP8 types and Hopper-specific features
+#define KITTENS_HOPPER
+
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+// Include ThunderKittens headers
+#include "kittens.cuh"
+
+// ThunderKittens namespace and group aliases
+// Operations are accessed through these group types, NOT as free functions
+using namespace kittens;
+using warp = kittens::group<1>;  // For single-warp operations (32 threads)
+// For multi-warp operations, use: using warpgroup = kittens::group<4>;
+
+// Constants for tile dimensions
+constexpr int TILE_DIM = 16;
+
+// ThunderKittens elementwise add kernel using shared memory tiles
+// This example demonstrates the ThunderKittens API pattern
+__global__ void tk_elementwise_add_kernel(const float* __restrict__ a_ptr,
+                                          const float* __restrict__ b_ptr,
+                                          float* __restrict__ out_ptr,
+                                          int rows, int cols) {
+    // For simple element-wise ops, we use a straightforward approach
+    // ThunderKittens shines for matrix ops with tiles, but here we show the basic pattern
+
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total = rows * cols;
+
+    // Grid-stride loop for simple element-wise addition
+    for (int i = idx; i < total; i += blockDim.x * gridDim.x) {
+        out_ptr[i] = a_ptr[i] + b_ptr[i];
+    }
+}
+
+// Alternative: ThunderKittens tiled version for larger matrices
+// Shows proper usage of ThunderKittens tile types and group operations
+// Uncomment and adapt for matrix operations:
+/*
+__global__ void tk_matmul_kernel(const bf16* A, const bf16* B, bf16* C,
+                                 int M, int N, int K) {
+    // Define aliases for the group - THIS IS REQUIRED for ThunderKittens ops
+    using warpgroup = kittens::group<4>;  // 4 warps = 128 threads
+
+    // ThunderKittens register tiles for accumulation
+    rt_fl<16, 16> acc;  // 16x16 float register tile
+
+    // Shared memory tiles
+    extern __shared__ alignment_dummy __shm[];
+    st_bf<16, 16> (&a_smem)[2] = *reinterpret_cast<st_bf<16, 16>(*)[2]>(__shm);
+    st_bf<16, 16> (&b_smem)[2] = *reinterpret_cast<st_bf<16, 16>(*)[2]>(__shm + sizeof(st_bf<16, 16>) * 2);
+
+    // Initialize accumulator to zero - NOTE: use warpgroup:: prefix!
+    warpgroup::zero(acc);
+
+    // Main loop would go here with:
+    //   warpgroup::load(a_smem[...], ...);       // Load from global to shared
+    //   warpgroup::mma_AB(acc, a_tile, b_tile);  // Matrix multiply-accumulate
+    //   warpgroup::store(C_ptr, acc, ...);       // Store result
+}
+*/
+
+torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
+    TORCH_CHECK(a.is_cuda(), "Input tensor a must be on CUDA");
+    TORCH_CHECK(b.is_cuda(), "Input tensor b must be on CUDA");
+    TORCH_CHECK(a.sizes() == b.sizes(), "Input tensors must have the same shape");
+
+    auto out = torch::empty_like(a);
+    int rows = a.size(0);
+    int cols = a.numel() / rows;
+
+    const int block_size = 256;
+    const int num_blocks = (a.numel() + block_size - 1) / block_size;
+
+    tk_elementwise_add_kernel<<<num_blocks, block_size>>>(
+        a.data_ptr<float>(),
+        b.data_ptr<float>(),
+        out.data_ptr<float>(),
+        rows, cols
+    );
+
+    return out;
+}
+"""
+
+# Compile the ThunderKittens kernel inline
+elementwise_add = load_inline(
+    name="elementwise_add_tk",
+    cpp_sources=elementwise_add_cpp_source,
+    cuda_sources=elementwise_add_cuda_source,
+    functions=["elementwise_add_cuda"],
+    verbose=True,
+    extra_include_paths=[
+        TK_PATH,
+        os.path.join(TK_PATH, "include"),
+    ],
+    extra_cflags=["-std=c++20", "-O3"],
+    extra_cuda_cflags=[
+        "-std=c++20",
+        "-O3",
+        "--expt-relaxed-constexpr",
+        "--expt-extended-lambda",
+        "-Xcompiler", "-fPIC",
+        "-DNDEBUG",
+        "-DKITTENS_HOPPER",
+    ],
+)
+
+
+class ModelNew(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.elementwise_add = elementwise_add
+
+    def forward(self, a, b):
+        return self.elementwise_add.elementwise_add_cuda(a, b)
diff --git a/src/prompts/prompts.toml b/src/prompts/prompts.toml
index bcf4e4ed..1d0c1fed 100644
--- a/src/prompts/prompts.toml
+++ b/src/prompts/prompts.toml
@@ -49,6 +49,11 @@ backend_display = "TileLang kernels"
 one_shot_new_arch = "src/prompts/model_new_ex_add_tilelang.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected

+[backends.thunderkittens]
+backend_display = "ThunderKittens kernels"
+one_shot_new_arch = "src/prompts/model_new_ex_add_thunderkittens.py"
+# No few_shot_examples - will use one-shot when few_shot option is selected
+
 # -------------------------------------------------------------------------
 # Precision: Precision-specific configuration
 # -------------------------------------------------------------------------
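
A minimal smoke test for the new one-shot example file, for reference. This is a sketch, not part of the patch: it assumes a CUDA GPU with a Hopper-compatible toolchain, a ThunderKittens checkout at THUNDERKITTENS_PATH (so the load_inline build succeeds at import time), and that the import is run from the repo root.

import torch
# Importing the module triggers the load_inline JIT build of the TK extension
from src.prompts.model_new_ex_add_thunderkittens import ModelNew

model = ModelNew().cuda()
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)  # fp32, matching the backend default
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
out = model(a, b)  # elementwise a + b via the compiled kernel
assert torch.allclose(out, a + b)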