From d9c42f800a529ca485457c402860573d60dc6c2a Mon Sep 17 00:00:00 2001 From: William Hu Date: Sun, 11 Jan 2026 14:33:41 -0800 Subject: [PATCH] add new backend regexes --- src/kernelbench/kernel_static_checker.py | 342 +++++++++++++++++- .../unit_tests/test_validate_kernel_static.py | 316 +++++++++++++++- 2 files changed, 640 insertions(+), 18 deletions(-) diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py index c8832a1a..deaf1d4a 100644 --- a/src/kernelbench/kernel_static_checker.py +++ b/src/kernelbench/kernel_static_checker.py @@ -169,6 +169,24 @@ def check_torch_computation_ops(code: str) -> Tuple[bool, str]: # use load_inline or cpp_extension (PyTorch's inline compilation). CUDA_COMPILE_PATTERNS = ["load_inline", "cpp_extension"] +# Core CUDA patterns that indicate actual kernel implementation +CUDA_THREAD_PATTERNS = [ + r"\bthreadIdx\.", # threadIdx.x, threadIdx.y, threadIdx.z + r"\bblockIdx\.", # blockIdx.x, blockIdx.y, blockIdx.z + r"\bblockDim\.", # blockDim.x, blockDim.y, blockDim.z + r"\bgridDim\.", # gridDim.x, gridDim.y, gridDim.z +] + +# CUDA synchronization and memory patterns +CUDA_KERNEL_PATTERNS = [ + r"__syncthreads\s*\(", # Thread synchronization + r"__shared__\s+", # Shared memory declaration + r"__device__\s+", # Device function + r"atomicAdd\s*\(", # Atomic operations + r"atomicMax\s*\(", + r"atomicMin\s*\(", +] + def check_cuda_impl(code: str) -> Tuple[bool, str]: """ Check for valid CUDA kernel implementation. @@ -176,12 +194,29 @@ def check_cuda_impl(code: str) -> Tuple[bool, str]: Requirements: - Must have __global__ void kernel_name (kernel definition) - Must have load_inline or cpp_extension (PyTorch inline compilation) + - Must use CUDA thread indexing (threadIdx, blockIdx, blockDim, or gridDim) + OR CUDA kernel features (__syncthreads, __shared__, __device__, atomics) + + Rationale: Ensures code actually implements a CUDA kernel rather than + just wrapping PyTorch operations. 
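+
+    Illustrative sketch of source these checks accept (kernel and variable
+    names are placeholders, not part of the checker; not a complete extension):
+
+        from torch.utils.cpp_extension import load_inline
+
+        cuda_src = '''
+        __global__ void scale_kernel(const float* x, float* out, int n) {
+            int i = blockIdx.x * blockDim.x + threadIdx.x;  // thread indexing
+            if (i < n) out[i] = 2.0f * x[i];
+        }
+        '''
+        module = load_inline(name="scale", cpp_sources="", cuda_sources=cuda_src)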
""" code = _strip_comments(code) + + # Check for kernel definition if "__global__" not in code: return (True, "Missing __global__ kernel definition") + + # Check for compilation method if not any(p in code for p in CUDA_COMPILE_PATTERNS): return (True, "Missing load_inline or cpp_extension for compilation") + + # Check for actual CUDA kernel features (thread indexing or kernel patterns) + has_thread_patterns = any(re.search(p, code) for p in CUDA_THREAD_PATTERNS) + has_kernel_patterns = any(re.search(p, code) for p in CUDA_KERNEL_PATTERNS) + + if not (has_thread_patterns or has_kernel_patterns): + return (True, "Missing CUDA thread indexing or kernel features (threadIdx, blockIdx, __syncthreads, __shared__, etc.)") + return (False, "") # <========= TRITON CHECKS =========> @@ -190,6 +225,26 @@ def check_cuda_impl(code: str) -> Tuple[bool, str]: TRITON_JIT_PATTERN = r"@triton\.(jit|autotune)" TRITON_OPS_PATTERN = r"\btl\.\w+" +# Core Triton memory operations (must-have) +TRITON_MEMORY_OPS = [ + r"tl\.load\s*\(", # Memory load + r"tl\.store\s*\(", # Memory store +] + +# Core Triton kernel patterns +TRITON_KERNEL_PATTERNS = [ + r"tl\.program_id\s*\(", # Program/block ID + r"tl\.num_programs\s*\(", # Number of programs + r"tl\.constexpr", # Compile-time constants + r"tl\.arange\s*\(", # Index generation + r"tl\.cdiv\s*\(", # Ceiling division +] + +# Triton data types +TRITON_DTYPE_PATTERNS = [ + r"tl\.(float16|float32|float64|int32|int64|bfloat16)", +] + def check_triton_impl(code: str) -> Tuple[bool, str]: """ Check for valid Triton kernel implementation. @@ -197,41 +252,137 @@ def check_triton_impl(code: str) -> Tuple[bool, str]: Requirements: - Must have @triton.jit or @triton.autotune decorator - Must have tl.* operations (enforces actual Triton code, not wrapper) + - Must have tl.load or tl.store (core memory operations) + - Should have tl.program_id or other kernel patterns (for proper indexing) Note: Triton's compiler itself prevents PyTorch ops inside @triton.jit. """ code = _strip_comments(code) + + # Check for decorator if not re.search(TRITON_JIT_PATTERN, code): return (True, "Missing @triton.jit or @triton.autotune") + + # Check for any tl.* operations if not re.search(TRITON_OPS_PATTERN, code): return (True, "No tl.* operations found in Triton kernel") + + # Check for memory operations (load or store) + has_memory_ops = any(re.search(p, code) for p in TRITON_MEMORY_OPS) + if not has_memory_ops: + return (True, "Missing Triton memory operations (tl.load or tl.store)") + + # Check for kernel patterns (program_id is essential for indexing) + has_kernel_patterns = any(re.search(p, code) for p in TRITON_KERNEL_PATTERNS) + if not has_kernel_patterns: + return (True, "Missing Triton kernel patterns (tl.program_id, tl.arange, etc.)") + return (False, "") # <========= THUNDERKITTENS CHECKS =========> # Rationale: ThunderKittens uses warp/warpgroup primitives and tile abstractions. # Valid TK code must have namespace patterns and tile declarations. 
+# Reference: https://github.com/HazyResearch/ThunderKittens/ +TK_NAMESPACE_PATTERNS = [ + r"kittens::", # Core namespace + r"using namespace kittens", # Using declaration +] + TK_WARP_PATTERNS = [ - r"kittens::warp\b", r"kittens::warpgroup\b", - r"::warpgroup::", r"::warp::", r"warpgroup::", r"warp::" + r"kittens::warp\b", + r"kittens::warpgroup\b", + r"kittens::group\s*<\s*\d+\s*>", # kittens::group<4> for warpgroup operations + r"::warpgroup::", + r"::warp::", + r"warpgroup::", + r"warp::" +] + +# ThunderKittens tile types: rt (register tile), st (shared tile) +# Examples: kittens::rt_bf<32,16>, kittens::st_hf<32,64>, rt_fl<32,64> +TK_TILE_PATTERN = r"(?:kittens::)?(?:st|rt)_(?:bf|fl|hf|i8|i32)\s*<[^>]+>" + +# ThunderKittens vector types (associated with tiles) +TK_VECTOR_PATTERN = r"::(?:col_vec|row_vec)\b" + +# ThunderKittens memory operations (often namespaced) +TK_MEMORY_OPS = [ + r"kittens::load\s*\(", # Namespaced load + r"kittens::store\s*\(", # Namespaced store + r"\bload\s*\(", # Tile load (in using namespace context) + r"\bstore\s*\(", # Tile store (in using namespace context) + r"load_async\s*\(", # Async load +] + +# ThunderKittens compute operations (from the manual) +TK_COMPUTE_OPS = [ + r"kittens::(?:warpgroup::)?mma_AB\s*\(", # Warpgroup MMA: mma_AB + r"kittens::(?:warpgroup::)?mma_ABt\s*\(", # MMA variants + r"kittens::(?:warpgroup::)?mma_AtB\s*\(", + r"(?:warpgroup::)?mma_AB[t]?\s*\(", # Without namespace (in using context) + r"kittens::mul\s*\(", # Namespaced element-wise ops + r"kittens::add\s*\(", + r"kittens::sub\s*\(", + r"kittens::copy\s*\(", + r"kittens::zero\s*\(", + r"\bmul\s*\(", # Element-wise multiply (in using namespace) + r"\badd\s*\(", # Element-wise add + r"\bsub\s*\(", # Element-wise subtract + r"\bcopy\s*\(", # Copy operation + r"\bzero\s*\(", # Zero initialization +] + +# ThunderKittens control and utilities +TK_CONTROL_PATTERNS = [ + r"kittens::warpid\s*\(", # Get warp ID + r"tma::", # Tensor Memory Accelerator namespace + r"__syncthreads\s*\(", # Thread synchronization (CUDA primitive often used) + r"__syncwarp\s*\(", # Warp synchronization ] -TK_TILE_PATTERN = r"(?:kittens::)?(?:st|rt)_\w+\s*<[^>]+>" def check_tk_impl(code: str) -> Tuple[bool, str]: """ Check for valid ThunderKittens kernel implementation. Requirements: - - Must have warp/warpgroup namespace patterns (kittens::warp, etc.) - - Must have tile declarations (st_bf<...>, rt_fl<...>, etc.) - - TODO: Add producer-consumer pattern check for complex kernels. + - Must have kittens namespace (kittens::, using namespace kittens) + - Must have tile declarations (st_bf, rt_fl, st_hf, rt_i8, etc.) + - Must have memory operations (kittens::load, kittens::store, load_async) + - Should have compute operations (mma_AB, mul, add, copy, zero, etc.) + - Optional: warp/warpgroup patterns (kittens::warpgroup, kittens::group) + for warpgroup-specific operations + + ThunderKittens is a tile-based programming model that abstracts + warp-level operations with register (rt) and shared (st) tiles. + By default, operations exist at warp-level, so explicit warp/warpgroup + scope is only needed for warpgroup operations like mma_AB. 
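+
+    Illustrative sketch of source these pattern checks accept (tile shapes and
+    names are placeholders; not a complete, compilable kernel):
+
+        using namespace kittens;
+        __global__ void example_kernel() {
+            rt_fl<32, 64> a, b, c;       // register tiles
+            __shared__ st_hf<32, 64> s;  // shared tile
+            kittens::mul(c, a, b);       // compute op
+            kittens::store(s, c);        // memory op
+        }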
+ + Reference: https://github.com/HazyResearch/ThunderKittens/ """ code = _strip_comments(code) - if not any(re.search(p, code) for p in TK_WARP_PATTERNS): - return (True, "Missing ThunderKittens warp/warpgroup patterns") - if not re.search(TK_TILE_PATTERN, code): - return (True, "Missing ThunderKittens tile declarations (st_*/rt_*)") + + # Check for kittens namespace (fundamental requirement) + has_namespace = any(re.search(p, code) for p in TK_NAMESPACE_PATTERNS) + if not has_namespace: + return (True, "Missing kittens namespace (kittens:: or using namespace kittens)") + + # Check for tile declarations (rt_* or st_*) + has_tiles = re.search(TK_TILE_PATTERN, code) + has_vectors = re.search(TK_VECTOR_PATTERN, code) + if not (has_tiles or has_vectors): + return (True, "Missing ThunderKittens tile/vector declarations (rt_bf, st_fl, ::col_vec, etc.)") + + # Check for memory operations + has_memory_ops = any(re.search(p, code) for p in TK_MEMORY_OPS) + if not has_memory_ops: + return (True, "Missing ThunderKittens memory operations (kittens::load, kittens::store, load_async)") + + # Check for compute operations + has_compute_ops = any(re.search(p, code) for p in TK_COMPUTE_OPS) + if not has_compute_ops: + return (True, "Missing ThunderKittens compute operations (mma_AB, mul, add, copy, zero, etc.)") + return (False, "") @@ -244,11 +395,67 @@ def check_tk_impl(code: str) -> Tuple[bool, str]: r"from cutlass", # Python CUTLASS bindings ] +# CuTe tensor operations +CUTE_TENSOR_OPS = [ + r"make_tensor\s*\(", # Tensor creation + r"make_layout\s*\(", # Layout creation + r"make_shape\s*\(", # Shape creation + r"make_stride\s*\(", # Stride creation +] + +# CuTe/CUTLASS copy operations +CUTE_COPY_OPS = [ + r"copy\s*\(", # Generic copy + r"copy_if\s*\(", # Conditional copy + r"cute::copy", # Namespaced copy + r"Copy_Atom", # Copy atom template +] + +# CUTLASS GEMM patterns +CUTLASS_GEMM_PATTERNS = [ + r"cutlass::gemm", # GEMM namespace + r"cutlass::epilogue", # Epilogue operations + r"Gemm\w*<", # GEMM templates (Gemm, GemmUniversal, etc.) + r"GemmConfiguration", # GEMM configuration + r"ThreadblockSwizzle", # Threadblock scheduling +] + +# CUTLASS kernel patterns +CUTLASS_KERNEL_PATTERNS = [ + r"cutlass::arch", # Architecture-specific code + r"cutlass::layout", # Layout specifications + r"RowMajor|ColumnMajor", # Layout types + r"TensorRef\s*<", # Tensor reference template +] + def check_cute_impl(code: str) -> Tuple[bool, str]: - """Check for valid CUTLASS/CuTe kernel implementation.""" + """ + Check for valid CUTLASS/CuTe kernel implementation. + + Requirements: + - Must have cute:: or cutlass:: namespace (or Python bindings) + - Must have tensor operations (make_tensor, make_layout) OR + copy operations (copy, Copy_Atom) OR + CUTLASS GEMM patterns (cutlass::gemm, Gemm templates) + + CuTe is a layout/tensor abstraction library used by CUTLASS 3.x. + We check for both high-level CUTLASS templates and low-level CuTe ops. 
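+
+    Illustrative sketch of CuTe-style source these checks accept (src_ptr and
+    dst_ptr are placeholder pointers; not a complete kernel):
+
+        auto layout = cute::make_layout(cute::make_shape(16, 16), cute::make_stride(1, 16));
+        auto src = cute::make_tensor(src_ptr, layout);
+        auto dst = cute::make_tensor(dst_ptr, layout);
+        cute::copy(src, dst);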
+ """ code = _strip_comments(code) + + # Check for namespace if not any(p in code for p in ["cute::", "cutlass::", "from cutlass"]): return (True, "Missing cute:: or cutlass:: namespace") + + # Check for actual operations (tensor, copy, or GEMM) + has_tensor_ops = any(re.search(p, code) for p in CUTE_TENSOR_OPS) + has_copy_ops = any(re.search(p, code) for p in CUTE_COPY_OPS) + has_gemm_patterns = any(re.search(p, code) for p in CUTLASS_GEMM_PATTERNS) + has_kernel_patterns = any(re.search(p, code) for p in CUTLASS_KERNEL_PATTERNS) + + if not (has_tensor_ops or has_copy_ops or has_gemm_patterns or has_kernel_patterns): + return (True, "Missing CUTLASS/CuTe operations (make_tensor, copy, gemm patterns, etc.)") + return (False, "") @@ -257,15 +464,116 @@ def check_cute_impl(code: str) -> Tuple[bool, str]: # https://github.com/tile-ai/tilelang TILELANG_PATTERNS = [ r"@T\.prim_func", # TVM primitive function decorator - r"tvm\.build", # TVM build call - r"T\.grid", # TileLang grid + r"T\.Kernel", # TileLang kernel +] + +# TileLang/TVM iteration patterns +TILELANG_ITERATION = [ + r"T\.serial\s*\(", # Serial loop + r"T\.Parallel\s*\(", # Parallel loop + r"T\.Pipelined\s*\(", # Pipelined loop + r"T\.unroll\s*\(", # Unrolled loop +] + +# TileLang/TVM memory management +TILELANG_MEM_MGMT = [ + r"T\.alloc_shared\s*\(", + r"T\.alloc_fragment\s*\(", + r"T\.alloc_var\s*\(", + r"T\.alloc_barrier\s*\(", + r"T\.alloc_tmem\s*\(", + r"T\.alloc_reducer\s*\(", + r"T\.alloc_descriptor\s*\(", + r"T\.alloc_wgmma_desc\s*\(", + r"T\.alloc_tcgen05_smem_desc\s*\(", + r"T\.alloc_tcgen05_instr_desc\s*\(", + r"T\.empty\s*\(", +] + +# TileLang/TVM data movement +TILELANG_DATA_MOVE = [ + r"T\.copy\s*\(", + r"T\.c2d_im2col\s*\(", +] + +# TileLang/TVM compute primitives +TILELANG_COMPUTE = [ + r"T\.gemm\s*\(", + r"T\.gemm_sp\s*\(", + r"T\.reduce_(sum|max|min|abssum|absmax)\s*\(", + r"T\.cumsum\s*\(", + r"T\.finalize_reducer\s*\(", + r"T\.warp_reduce_(sum|max|min|bitand|bitor)\s*\(", + r"T\.(exp|log|max|min|rsqrt)\s*\(", + r"T\.ieee_(add|sub|mul|fmaf)\s*\(", + r"T\.clear\s*\(", + r"T\.fill\s*\(", + r"T\.reshape\s*\(", + r"T\.view\s*\(", +] + +# TileLang/TVM synchronization and hardware control +TILELANG_HARDWARE = [ + r"T\.pdl_trigger\s*\(", + r"T\.pdl_sync\s*\(", + r"T\.create_list_of_mbarrier\s*\(", + r"T\.get_mbarrier\s*\(", + r"T\.mbarrier_(wait_parity|arrive|expect_tx)\s*\(", + r"T\.barrier_wait\s*\(", + r"T\.fence_proxy_async\s*\(", + r"T\.warpgroup_fence_operand\s*\(", + r"T\.warpgroup_(arrive|commit_batch|wait)\s*\(", + r"T\.wait_wgmma\s*\(", + r"T\.atomic_(add|addx2|addx4|max|min|load|store)\s*\(", +] + +# TileLang/TVM intrinsics and indexing +TILELANG_INTRINSICS = [ + r"T\.dp4a\s*\(", + r"T\.clamp\s*\(", + r"T\.loop_break\s*\(", + r"T\.get_(lane|warp|warp_group)_idx(_sync)?\s*\(", +] + +# TileLang register control +TILELANG_REGS = [ + r"T\.(set|inc|dec)_max_nreg\s*\(", + r"T\.annotate_(producer_reg_dealloc|consumer_reg_alloc)\s*\(", + r"T\.no_set_max_nreg\s*\(", + r"T\.disable_warp_group_reg_alloc\s*\(", ] def check_tilelang_impl(code: str) -> Tuple[bool, str]: - """Check for valid TileLang kernel implementation.""" + """ + Check for valid TileLang kernel implementation. + + Requirements: + - Must have @T.prim_func decorator or T.Kernel + - Must have iteration constructs (T.serial, T.Parallel, T.unroll, etc.) + - Must have at least one TileLang operation (Memory, Data Move, Compute, etc.) + + TileLang is a tensor program DSL built on TVM that uses structured + iteration spaces and explicit buffer operations. 
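+
+    Illustrative sketch of source these checks accept (loosely modeled on the
+    TileLang examples; the T.Tensor annotations and block sizes are illustrative):
+
+        @T.prim_func
+        def kernel(A: T.Tensor((128, 128), "float16"),
+                   B: T.Tensor((128, 128), "float16")):
+            with T.Kernel(4, 4, threads=128) as (bx, by):
+                A_s = T.alloc_shared((32, 32), "float16")
+                T.copy(A[by * 32, bx * 32], A_s)
+                for i, j in T.Parallel(32, 32):
+                    B[by * 32 + i, bx * 32 + j] = A_s[i, j] * 2.0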
+    """
     code = _strip_comments(code)
-    if not re.search(r"@T\.prim_func", code):
-        return (True, "Missing @T.prim_func decorator")
+
+    # Check for decorator
+    if not any(re.search(p, code) for p in TILELANG_PATTERNS):
+        return (True, "Missing @T.prim_func decorator or T.Kernel")
+
+    # Check for iteration constructs
+    has_iteration = any(re.search(p, code) for p in TILELANG_ITERATION)
+    if not has_iteration:
+        return (True, "Missing TileLang iteration constructs (T.Parallel, T.serial, etc.)")
+
+    # Check for any TileLang operations
+    all_ops = TILELANG_MEM_MGMT + TILELANG_DATA_MOVE + TILELANG_COMPUTE + \
+              TILELANG_HARDWARE + TILELANG_INTRINSICS + TILELANG_REGS
+    has_ops = any(re.search(p, code) for p in all_ops)
+
+    if not has_ops:
+        return (True, "Missing TileLang operations (T.alloc_shared, T.copy, T.gemm, etc.)")
+
     return (False, "")
diff --git a/src/kernelbench/unit_tests/test_validate_kernel_static.py b/src/kernelbench/unit_tests/test_validate_kernel_static.py
index 55d97a2c..80e74098 100644
--- a/src/kernelbench/unit_tests/test_validate_kernel_static.py
+++ b/src/kernelbench/unit_tests/test_validate_kernel_static.py
@@ -7,6 +7,23 @@
 - Handles backend-specific checks
 - Respects forbidden/warnings parameters
 - Returns correct output format
+- Validates backend-specific patterns with real kernel code
+
+Test Coverage:
+1. API/Infrastructure Tests
+   - Function signature and return values
+   - Precision parameter handling
+   - Error vs warning categorization
+   - Custom forbidden/warnings lists
+   - Backend parameter processing
+   - Edge cases and integration
+
+2. Backend Pattern Validation Tests
+   - CUDA: Valid kernels, shared memory, thread indexing
+   - Triton: Valid kernels, autotune, memory operations
+   - ThunderKittens: Simple kernels, warpgroup MMA, compute ops
+   - CUTLASS/CuTe: GEMM kernels, tensor operations
+   - TileLang: Complete kernels, buffer allocation, iteration
 
 Run with pytest:
     pytest src/kernelbench/unit_tests/test_validate_kernel_static.py -v
@@ -418,6 +435,303 @@ def forward(self, x):
     assert len(errors) > 0, "Should have errors"
 
 
+# ============================================================================
+# Test Backend-Specific Pattern Validation
+# These tests validate that backend checks correctly identify valid kernels
+# from official documentation and reject wrapper/incomplete code
+# ============================================================================
+
+# -----------------------------------------------------------------------------
+# CUDA Backend Tests
+# -----------------------------------------------------------------------------
+
+def test_cuda_valid_kernel_with_thread_indexing():
+    """Test that valid CUDA kernel with threadIdx/blockIdx passes"""
+    code = """
+    #include <torch/extension.h>
+
+    __global__ void vector_add(float* a, float* b, float* c, int n) {
+        int idx = blockIdx.x * blockDim.x + threadIdx.x;
+        if (idx < n) {
+            c[idx] = a[idx] + b[idx];
+        }
+    }
+
+    torch::Tensor forward(torch::Tensor a, torch::Tensor b) {
+        auto c = torch::empty_like(a);
+        int threads = 256;
+        int blocks = (a.numel() + threads - 1) / threads;
+        vector_add<<<blocks, threads>>>(
+            a.data_ptr<float>(), b.data_ptr<float>(), c.data_ptr<float>(), a.numel()
+        );
+        return c;
+    }
+
+    auto module = torch::utils::cpp_extension::load_inline(
+        "vector_add", cuda_src, cuda_src, {"vector_add.cu"}, {}, {}, true
+    );
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cuda")
+    assert valid, f"Expected valid CUDA kernel to pass, got errors: {errors}"
+
+
+def test_cuda_valid_kernel_with_shared_memory():
+    """Test that CUDA kernel with __shared__ memory passes"""
+    code = """
+    #include <torch/extension.h>
+
+    __global__ void matmul_shared(float* A, float* B, float* C, int N) {
+        __shared__ float shared_A[16][16];
+        __shared__ float shared_B[16][16];
+
+        int tx = threadIdx.x;
+        int ty = threadIdx.y;
+        int bx = blockIdx.x;
+        int by = blockIdx.y;
+
+        __syncthreads();
+        // ... computation using shared memory
+    }
+
+    torch::Tensor matmul(torch::Tensor a, torch::Tensor b) {
+        return torch::utils::cpp_extension::load_inline("matmul", cuda_src, {}, {});
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cuda")
+    assert valid, f"Expected CUDA kernel with __shared__ to pass, got errors: {errors}"
+
+
+def test_cuda_missing_thread_indexing():
+    """Test that CUDA kernel without thread indexing is rejected"""
+    code = """
+    #include <torch/extension.h>
+
+    __global__ void fake_kernel() {
+        // No threadIdx, no blockIdx, no actual CUDA features
+        float x = 1.0f;
+    }
+
+    void wrapper() {
+        torch::utils::cpp_extension::load_inline("fake", src, {}, {});
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cuda")
+    assert not valid, "Expected CUDA kernel without thread indexing to fail"
+    assert any("thread indexing" in err.lower() or "kernel features" in err.lower() for err in errors)
+
+
+# -----------------------------------------------------------------------------
+# Triton Backend Tests
+# -----------------------------------------------------------------------------
+
+def test_triton_valid_kernel():
+    """Test that valid Triton kernel with tl.load/store/program_id passes"""
+    code = """
+    import triton
+    import triton.language as tl
+
+    @triton.jit
+    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.store(output_ptr + offsets, output, mask=mask)
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="triton")
+    assert valid, f"Expected valid Triton kernel to pass, got errors: {errors}"
+
+
+def test_triton_autotune():
+    """Test that Triton kernel with @triton.autotune passes"""
+    code = """
+    import triton
+    import triton.language as tl
+
+    @triton.autotune(
+        configs=[
+            triton.Config({'BLOCK_SIZE': 128}),
+            triton.Config({'BLOCK_SIZE': 256}),
+        ],
+        key=['n_elements'],
+    )
+    @triton.jit
+    def optimized_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(0)
+        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(x_ptr + offsets, mask=mask)
+        tl.store(out_ptr + offsets, x * 2, mask=mask)
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="triton")
+    assert valid, f"Expected Triton autotune kernel to pass, got errors: {errors}"
+
+
+def test_triton_missing_memory_ops():
+    """Test that Triton kernel without memory operations is rejected"""
+    code = """
+    import triton
+
+    @triton.jit
+    def fake_kernel(x):
+        # Missing memory operations
+        return x * 2
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="triton")
+    assert not valid, f"Expected Triton kernel without memory ops to fail, but got valid={valid}"
+    assert any("load" in err.lower() or "store" in err.lower() or "tl" in err.lower() for err in errors), \
+        f"Expected error about tl operations, got: {errors}"
+
+
+# -----------------------------------------------------------------------------
+# ThunderKittens Backend Tests
+# -----------------------------------------------------------------------------
+
+def test_thunderkittens_valid_simple():
+    """Test that simple ThunderKittens kernel passes (from official manual)"""
+    code = """
+    #include <kittens.cuh>
+
+    using namespace kittens;
+
+    __global__ void example_kernel() {
+        rt_fl<32, 64> a, b, c;
+        __shared__ st_hf<32, 64> s;
+
+        kittens::mul(c, a, b);
+        kittens::store(s, c);
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="thunderkittens")
+    assert valid, f"Expected simple ThunderKittens kernel to pass, got errors: {errors}"
+
+
+def test_thunderkittens_valid_warpgroup():
+    """Test that ThunderKittens warpgroup MMA kernel passes"""
+    code = """
+    #include <kittens.cuh>
+
+    using namespace kittens;
+
+    __global__ void gemm_kernel(const bf16 *A, const bf16 *B, bf16 *C) {
+        using namespace kittens::warpgroup;
+
+        rt_fl<16, 16> A_tile;
+        rt_fl<16, 16> B_tile;
+        rt_fl<16, 16> C_tile;
+
+        zero(C_tile);
+        kittens::load(A_tile, A);
+        kittens::load(B_tile, B);
+        mma_AB(C_tile, A_tile, B_tile);
+        kittens::store(C, C_tile);
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="thunderkittens")
+    assert valid, f"Expected ThunderKittens warpgroup kernel to pass, got errors: {errors}"
+
+
+def test_thunderkittens_missing_compute():
+    """Test that ThunderKittens kernel without compute ops is rejected"""
+    code = """
+    #include <kittens.cuh>
+
+    using namespace kittens;
+
+    __global__ void incomplete_kernel() {
+        rt_fl<32, 64> tile;
+        kittens::load(tile, ptr);
+        kittens::store(output, tile);
+        // Has load/store but no compute!
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="thunderkittens")
+    assert not valid, "Expected ThunderKittens kernel without compute to fail"
+    assert any("compute" in err.lower() for err in errors)
+
+
+# -----------------------------------------------------------------------------
+# CUTLASS/CuTe Backend Tests
+# -----------------------------------------------------------------------------
+
+def test_cutlass_valid_gemm():
+    """Test that CUTLASS GEMM kernel passes"""
+    code = """
+    #include <cutlass/cutlass.h>
+    #include <cutlass/gemm/device/gemm.h>
+    #include <cutlass/layout/matrix.h>
+
+    using namespace cutlass;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        float, layout::RowMajor,
+        float, layout::ColumnMajor,
+        float, layout::RowMajor
+    >;
+
+    void run_gemm() {
+        Gemm gemm_op;
+        gemm_op();
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cutlass")
+    assert valid, f"Expected CUTLASS GEMM kernel to pass, got errors: {errors}"
+
+
+def test_cute_valid_tensor_ops():
+    """Test that CuTe tensor operations pass"""
+    code = """
+    #include <cute/tensor.hpp>
+
+    __global__ void cute_kernel() {
+        auto tensor = cute::make_tensor(ptr, cute::make_layout(cute::make_shape(16, 16)));
+        auto layout = cute::make_layout(cute::make_shape(16, 16), cute::make_stride(1, 16));
+        cute::copy(tensor_A, tensor_B);
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cute")
+    assert valid, f"Expected CuTe tensor ops kernel to pass, got errors: {errors}"
+
+
+def test_cutlass_missing_operations():
+    """Test that code with just namespace is rejected"""
+    code = """
+    #include <cutlass/cutlass.h>
+
+    using namespace cutlass;
+
+    void wrapper() {
+        // Just includes cutlass but no actual operations
+        float x = 1.0f;
+    }
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="cutlass")
+    assert not valid, "Expected CUTLASS code without operations to fail"
+    assert any("operations" in err.lower() or "namespace" in err.lower() for err in errors)
+
+
+# -----------------------------------------------------------------------------
+# TileLang Backend Tests
+# -----------------------------------------------------------------------------
+
+def test_tilelang_missing_iteration():
+    """Test that TileLang kernel without iteration constructs is rejected"""
+    code = """
+    import tvm
+    from tvm.script import tir as T
+
+    @T.prim_func
+    def fake_kernel():
+        # Has decorator but no T.Parallel, T.serial, etc.
+        x = 1.0
+    """
+    valid, errors, warnings = validate_kernel_static(code, backend="tilelang")
+    assert not valid, "Expected TileLang kernel without iteration to fail"
+    assert any("iteration" in err.lower() for err in errors)
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
-