From 4ad6d424b797cafc64ce9d0aea44cc9eb4101e17 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Fri, 16 Jan 2026 21:34:24 -0800 Subject: [PATCH 01/23] tlx support --- scripts/eval_from_generations.py | 5 ++ scripts/generate_and_eval_single_sample.py | 2 +- .../generate_and_eval_single_sample_modal.py | 7 +- scripts/generate_baseline_time_modal.py | 5 ++ scripts/generate_samples.py | 2 +- scripts/run_and_check.py | 5 ++ src/kernelbench/eval.py | 8 +- src/kernelbench/kernel_static_checker.py | 36 ++++++++- src/kernelbench/profile.py | 2 +- .../prompts/model_new_ex_add_tlx.py | 77 +++++++++++++++++++ src/kernelbench/prompts/prompts.toml | 5 ++ 11 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 src/kernelbench/prompts/model_new_ex_add_tlx.py diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 247410f3..8265db59 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -72,6 +72,11 @@ .uv_sync(uv_project_dir=REPO_TOP_DIR) .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands( + "git clone https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt", + "cd /root/triton && pip install -e ." + ) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", "PYTHONPATH": "/root/src:/root" diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index fce1b16f..ea37bb7f 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -174,7 +174,7 @@ def main(config: EvalConfig): include_hardware = include_hardware.lower() in ["true", "1", "yes"] config.include_hardware_info = include_hardware - supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"} + supported_backends = {"cuda", "triton", "tlx", "tilelang", "cute", "thunderkittens"} backend = config.backend.lower() if backend not in supported_backends: raise ValueError( diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 7308d228..3b6a73b9 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -106,6 +106,11 @@ def __repr__(self): .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"]) .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands( + "git clone https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt", + "cd /root/triton && pip install -e ." 
+ ) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", "PYTHONPATH": "/root:/root/src" @@ -207,7 +212,7 @@ def main(config: EvalConfig): include_hardware = include_hardware.lower() in ["true", "1", "yes"] config.include_hardware_info = include_hardware - supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"} + supported_backends = {"cuda", "triton", "tlx", "tilelang", "cute", "thunderkittens"} backend = config.backend.lower() if backend not in supported_backends: raise ValueError( diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py index e9c8428e..62bb0af3 100644 --- a/scripts/generate_baseline_time_modal.py +++ b/scripts/generate_baseline_time_modal.py @@ -93,6 +93,11 @@ def __init__(self): "clang" # note i skip a step ) .uv_sync(uv_project_dir=REPO_TOP_PATH, extras=["gpu"]) + .run_commands( + "git clone https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt", + "cd /root/triton && pip install -e ." + ) .env({"PYTHONPATH": "/root/src"}) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 2c01ee8d..fe4c955d 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -234,7 +234,7 @@ def main(config: GenerationConfig): include_hardware = include_hardware.lower() in ["true", "1", "yes"] config.include_hardware_info = include_hardware - supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"} + supported_backends = {"cuda", "triton", "tlx", "cute", "tilelang", "thunderkittens"} backend = config.backend.lower() if backend not in supported_backends: raise ValueError( diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index d253dd45..faec4a9f 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -40,6 +40,11 @@ .apt_install("git", "gcc-10", "g++-10", "clang") .uv_sync(uv_project_dir=REPO_TOP_PATH) .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands( + "git clone https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt", + "cd /root/triton && pip install -e ." 
+ ) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", "PYTHONPATH": "/root:/root/src:/root/scripts" diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py index 47f59793..9f1206d2 100644 --- a/src/kernelbench/eval.py +++ b/src/kernelbench/eval.py @@ -443,8 +443,8 @@ def eval_kernel_against_ref( torch.cuda.set_device(device) # Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES - # TileLang, Triton, and CuTe all use tempfile for proper module loading - uses_tempfile = backend.lower() in ["triton", "tilelang", "cute"] + # TileLang, Triton, TLX, and CuTe all use tempfile for proper module loading + uses_tempfile = backend.lower() in ["triton", "tlx", "tilelang", "cute"] metadata = {} # for storing result metadata metadata["hardware"] = torch.cuda.get_device_name(device=device) @@ -496,8 +496,8 @@ def eval_kernel_against_ref( # add hash for later to distinguish between multi-turn kernels backend_lower = backend.lower() - if backend_lower in ["triton", "tilelang", "cute"]: - # Use tempfile approach for triton, tilelang, and cute + if backend_lower in ["triton", "tlx", "tilelang", "cute"]: + # Use tempfile approach for triton, tlx, tilelang, and cute # These DSLs require proper module import for JIT decorators to work ModelNew, tempfile = load_custom_model_with_tempfile( custom_model_src, entry_point="ModelNew" diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py index c8832a1a..391fade6 100644 --- a/src/kernelbench/kernel_static_checker.py +++ b/src/kernelbench/kernel_static_checker.py @@ -208,6 +208,38 @@ def check_triton_impl(code: str) -> Tuple[bool, str]: return (False, "") +# <========= TLX (Triton Language Extensions) CHECKS =========> +# Rationale: TLX extends Triton with async tasks and barriers for specialization. +# Valid TLX code must use tlx.* operations for async tasks and barriers. +TLX_PATTERNS = [ + r"tlx\.async_tasks\s*\(", # tlx.async_tasks() + r"tlx\.async_task\s*\(", # tlx.async_task() + r"tlx\.barrier_wait\s*\(", # tlx.barrier_wait() + r"tlx\.barrier_arrive\s*\(", # tlx.barrier_arrive() + r"tlx\.barrier_create\s*\(", # tlx.barrier_create() + r"import\s+tlx", # import tlx + r"from\s+tlx", # from tlx + r"triton\.language\.extensions", # triton.language.extensions +] + +def check_tlx_impl(code: str) -> Tuple[bool, str]: + """ + Check for valid TLX (Triton Language Extensions) kernel implementation. + + Requirements: + - Must have @triton.jit or @triton.autotune decorator (inherited from Triton) + - Must have tlx.* operations (async_tasks, async_task, barrier_wait, barrier_arrive, etc.) + + Note: TLX extends Triton, so it should also have triton.jit decorator. + """ + code = _strip_comments(code) + if not re.search(TRITON_JIT_PATTERN, code): + return (True, "Missing @triton.jit or @triton.autotune (TLX extends Triton)") + if not any(re.search(p, code) for p in TLX_PATTERNS): + return (True, "Missing TLX operations (tlx.async_tasks, tlx.barrier_*, etc.)") + return (False, "") + + # <========= THUNDERKITTENS CHECKS =========> # Rationale: ThunderKittens uses warp/warpgroup primitives and tile abstractions. # Valid TK code must have namespace patterns and tile declarations. 
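
For reference, the implementation checks in this file return a (flagged, reason) tuple: a True first element means the source failed the check, while (False, "") means it looks like a valid implementation for that backend. Below is a minimal sketch of exercising the new TLX check, assuming the kernelbench.kernel_static_checker import path implied by the src/ layout (illustrative only, not part of the diff above):

from kernelbench.kernel_static_checker import check_tlx_impl

# A tiny kernel source that satisfies the TLX patterns above: it carries a
# @triton.jit decorator and uses tlx.async_tasks / tlx.async_task. The string
# is only inspected with regexes here, so no TLX-enabled Triton build is needed.
tlx_src = '''
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    with tlx.async_tasks():
        with tlx.async_task("default"):
            offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            tl.store(out_ptr + offsets, x + y, mask=mask)
'''

flagged, reason = check_tlx_impl(tlx_src)
assert not flagged          # @triton.jit plus tlx.async_* present -> (False, "")

flagged, reason = check_tlx_impl("out = a + b")
assert flagged              # no @triton.jit and no tlx.* usage -> flagged with a reason
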
@@ -556,6 +588,7 @@ def check_precision_downgrade(code: str, precision: str = "fp32") -> Tuple[bool, # should be strict "cuda_impl": check_cuda_impl, "triton_impl": check_triton_impl, + "tlx_impl": check_tlx_impl, "tk_impl": check_tk_impl, "cute_impl": check_cute_impl, "tilelang_impl": check_tilelang_impl, @@ -579,6 +612,7 @@ def check_precision_downgrade(code: str, precision: str = "fp32") -> Tuple[bool, BACKEND_IMPL_CHECK = { "cuda": "cuda_impl", "triton": "triton_impl", + "tlx": "tlx_impl", "thunderkittens": "tk_impl", "cute": "cute_impl", "cutlass": "cute_impl", # alias @@ -614,7 +648,7 @@ def validate_kernel_static( Args: code: Kernel source code - backend: "cuda", "triton", or "thunderkittens" + backend: "cuda", "triton", "tlx", "thunderkittens", "cute", or "tilelang" precision: "fp16", "fp32", or "bf16" (for future precision checks) forbidden: Check categories that cause errors (default: STRICT_CHECKS) warnings: Check categories that cause warnings (default: WARNING_CHECKS) diff --git a/src/kernelbench/profile.py b/src/kernelbench/profile.py index 8326324e..67d4fc97 100644 --- a/src/kernelbench/profile.py +++ b/src/kernelbench/profile.py @@ -249,7 +249,7 @@ def profile_kernelbench_model_with_nsight( tempfile = None # Different backends require different loading mechanisms - if backend.lower() in ["triton", "tilelang", "cute"]: + if backend.lower() in ["triton", "tlx", "tilelang", "cute"]: # These backends need a temp file for proper module loading ModelNew, tempfile = load_custom_model_with_tempfile( custom_model_src, entry_point="ModelNew" diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_add_tlx.py new file mode 100644 index 00000000..df3defe9 --- /dev/null +++ b/src/kernelbench/prompts/model_new_ex_add_tlx.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import triton +import triton.language as tl +import triton.language.extensions as tlx + + +@triton.jit +def tlx_add_kernel( + x_ptr, # Pointer to first input + y_ptr, # Pointer to second input + out_ptr, # Pointer to output + n_elements, # Total number of elements in input/output + BLOCK_SIZE: tl.constexpr, +): + # Create barriers for synchronization + b0 = tlx.barrier_create() + b1 = tlx.barrier_create() + + phase = 0 + with tlx.async_tasks(): + with tlx.async_task("default"): + tlx.barrier_wait(bar=b1, phase=phase ^ 1) + + # Placeholder block to do something + + tlx.barrier_arrive(bar=b0) # Release + + with tlx.async_task(num_warps=4): + tlx.barrier_wait(bar=b0, phase=phase) # Wait + + # Some arith ops + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + z = x + y + tl.store(out_ptr + offsets, z, mask=mask) + + tlx.barrier_arrive(bar=b0) # Wait + + +def tlx_add(x: torch.Tensor, y: torch.Tensor): + """ + This function wraps the TLX kernel call. It: + 1. Ensures the inputs are contiguous on GPU. + 2. Calculates the grid (blocks) needed. + 3. Launches the TLX kernel. + """ + assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." 
+ x = x.contiguous() + y = y.contiguous() + + # Prepare output tensor + out = torch.empty_like(x) + + # Number of elements in the tensor + n_elements = x.numel() + BLOCK_SIZE = 128 # Tunable parameter for block size + + # Determine the number of blocks needed + grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) + + # Launch the TLX kernel + tlx_add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) + return out + + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + # Instead of "return a + b", call our TLX-based addition + return tlx_add(a, b) + diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml index 2768aa11..c2fcb34e 100644 --- a/src/kernelbench/prompts/prompts.toml +++ b/src/kernelbench/prompts/prompts.toml @@ -54,6 +54,11 @@ backend_display = "ThunderKittens kernels" one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py" # No few_shot_examples - will use one-shot when few_shot option is selected +[backends.tlx] +backend_display = "TLX (Triton Language Extensions) kernels" +one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_tlx.py" +# No few_shot_examples - will use one-shot when few_shot option is selected + # ------------------------------------------------------------------------- # Precision: Precision-specific configuration # ------------------------------------------------------------------------- From 5a72ac63344874fbfe556fd76b9a677b6709e731 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Fri, 16 Jan 2026 23:26:37 -0800 Subject: [PATCH 02/23] tlx changes --- scripts/eval_from_generations.py | 26 ++++++++--- scripts/generate_and_eval_single_sample.py | 14 ++++++ .../generate_and_eval_single_sample_modal.py | 33 +++++++++++--- scripts/generate_baseline_time_modal.py | 14 +++--- scripts/generate_samples.py | 13 ++++++ scripts/run_and_check.py | 45 ++++++++++++++++--- 6 files changed, 122 insertions(+), 23 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 8265db59..bedbdb9c 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -43,6 +43,16 @@ - performance (n_trials): 100 randomized input trials You can increase the number of trials for correctness and performance + +TLX Example: +uv run python scripts/eval_from_generations.py \ + run_name=test_tlx_level1 \ + dataset_src=huggingface \ + level=1 \ + subset="(5,6)" \ + eval_mode=modal \ + backend=tlx \ + verbose=True """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -67,19 +77,23 @@ .apt_install("git", "gcc-10", "g++-10", - "clang" + "clang", + "cmake", + "ninja-build", + "zlib1g-dev" ) .uv_sync(uv_project_dir=REPO_TOP_DIR) - .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton", - "cd /root/triton && pip install -r python/requirements.txt", - "cd /root/triton && pip install -e ." + "git clone https://github.com/facebookexperimental/triton.git /root/triton && " + "cd /root/triton && " + "pip install -r python/requirements.txt && " + "pip install -e ." 
) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", - "PYTHONPATH": "/root/src:/root" + "PYTHONPATH": "/root/src:/root:/root/triton/python" }) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index ea37bb7f..22514120 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -20,6 +20,20 @@ Example usage: python3 scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 eval_mode=local server_type=google model_name=gemini/gemini-2.5-flash max_tokens=8192 temperature=0.0 + +TLX Example (NEED LOCAL GPU): +uv run python scripts/generate_and_eval_single_sample.py \ + dataset_src=huggingface \ + level=1 \ + problem_id=1 \ + backend=tlx \ + server_type=google \ + model_name=gemini/gemini-2.5-flash \ + max_tokens=60000 \ + temperature=0.0 \ + log=True \ + log_generated_kernel=True \ + verbose=True """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 3b6a73b9..a7b2107d 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -2,6 +2,23 @@ Example Usage: python scripts/generate_and_eval_single_sample_modal.py dataset_src=huggingfac level=1 problem_id=1 eval_mode=modal gpu=L40S server_type=deepseek model_name=deepseek-coder max_tokens=4096 temperature=0.0 + +TLX Example: +uv run python scripts/generate_and_eval_single_sample_modal.py \ + dataset_src=huggingface \ + level=1 \ + problem_id=1 \ + eval_mode=modal \ + gpu=H100 \ + backend=tlx \ + server_type=google \ + model_name=gemini/gemini-2.5-flash \ + max_tokens=60000 \ + temperature=0.0 \ + log=True \ + log_prompt=True \ + log_generated_kernel=True \ + verbose=True ''' import pydra @@ -101,19 +118,23 @@ def __repr__(self): .apt_install("git", "gcc-10", "g++-10", - "clang" # note i skip a step + "clang", + "cmake", + "ninja-build", + "zlib1g-dev" ) .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"]) - .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton", - "cd /root/triton && pip install -r python/requirements.txt", - "cd /root/triton && pip install -e ." + "git clone https://github.com/facebookexperimental/triton.git /root/triton && " + "cd /root/triton && " + "pip install -r python/requirements.txt && " + "pip install -e ." 
) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", - "PYTHONPATH": "/root:/root/src" + "PYTHONPATH": "/root:/root/src:/root/triton/python" }) .add_local_dir(SRC_DIR, remote_path="/root/src") # must be last ) diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py index 62bb0af3..df0c183c 100644 --- a/scripts/generate_baseline_time_modal.py +++ b/scripts/generate_baseline_time_modal.py @@ -90,15 +90,19 @@ def __init__(self): .apt_install("git", "gcc-10", "g++-10", - "clang" # note i skip a step + "clang", + "cmake", + "ninja-build", + "zlib1g-dev" ) .uv_sync(uv_project_dir=REPO_TOP_PATH, extras=["gpu"]) .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton", - "cd /root/triton && pip install -r python/requirements.txt", - "cd /root/triton && pip install -e ." + "git clone https://github.com/facebookexperimental/triton.git /root/triton && " + "cd /root/triton && " + "pip install -r python/requirements.txt && " + "pip install -e ." ) - .env({"PYTHONPATH": "/root/src"}) + .env({"PYTHONPATH": "/root/src:/root/triton/python"}) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last ) diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index fe4c955d..5096de20 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -22,6 +22,19 @@ Batch Generate Samples for Particular Level Assume 1 sample per problem here + +TLX Example: +uv run python scripts/generate_samples.py \ + dataset_src=huggingface \ + level=1 \ + subset="(1,5)" \ + run_name=test_tlx_level1 \ + backend=tlx \ + num_samples=1 \ + server_type=google \ + model_name=gemini/gemini-2.5-flash \ + max_tokens=8192 \ + temperature=0.0 """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index faec4a9f..08c2906d 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -37,17 +37,18 @@ image = ( modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") - .apt_install("git", "gcc-10", "g++-10", "clang") + .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") .uv_sync(uv_project_dir=REPO_TOP_PATH) - .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton", - "cd /root/triton && pip install -r python/requirements.txt", - "cd /root/triton && pip install -e ." + "git clone https://github.com/facebookexperimental/triton.git /root/triton && " + "cd /root/triton && " + "pip install -r python/requirements.txt && " + "pip install -e ." ) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", - "PYTHONPATH": "/root:/root/src:/root/scripts" + "PYTHONPATH": "/root:/root/src:/root/scripts:/root/triton/python" }) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(SCRIPTS_DIR, remote_path="/root/scripts") @@ -81,6 +82,38 @@ 4. 
PyTorch reference is a kernelbench problem (modal eval on cloud GPU) python3 scripts/run_and_check.py ref_origin=kernelbench level= problem_id= kernel_src_path= eval_mode=modal gpu=L40S + +TLX Examples: +uv run python scripts/run_and_check.py \ + ref_origin=kernelbench \ + level=1 \ + problem_id=1 \ + kernel_src_path=runs/valid_tlx/gemm_pc.py \ + eval_mode=modal \ + gpu=H100 \ + backend=tlx \ + precision=bf16 \ + verbose=True + +uv run python scripts/run_and_check.py \ + ref_origin=local \ + ref_arch_src_path=runs/valid_tlx/fftconv_reference.py \ + kernel_src_path=runs/valid_tlx/fftconv_pc.py \ + eval_mode=modal \ + gpu=H100 \ + backend=tlx \ + precision=bf16 \ + verbose=True + +uv run python scripts/run_and_check.py \ + ref_origin=local \ + ref_arch_src_path=runs/valid_tlx/layernorm_reference.py \ + kernel_src_path=runs/valid_tlx/layernorm_nonpc.py \ + eval_mode=modal \ + gpu=H100 \ + backend=tlx \ + precision=bf16 \ + verbose=True ==================================================== """ From 93457324b64dcbff8b3299fa2e95d1b3a6958a56 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 01:36:36 -0800 Subject: [PATCH 03/23] fixed tlx example --- .../prompts/model_ex_matmul_tlx.py | 9 + .../prompts/model_new_ex_add_tlx.py | 168 +++++++++++------- 2 files changed, 109 insertions(+), 68 deletions(-) create mode 100644 src/kernelbench/prompts/model_ex_matmul_tlx.py diff --git a/src/kernelbench/prompts/model_ex_matmul_tlx.py b/src/kernelbench/prompts/model_ex_matmul_tlx.py new file mode 100644 index 00000000..d217d1e9 --- /dev/null +++ b/src/kernelbench/prompts/model_ex_matmul_tlx.py @@ -0,0 +1,9 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + return torch.matmul(a, b) diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_add_tlx.py index df3defe9..6be9e75e 100644 --- a/src/kernelbench/prompts/model_new_ex_add_tlx.py +++ b/src/kernelbench/prompts/model_new_ex_add_tlx.py @@ -2,76 +2,108 @@ import torch.nn as nn import triton import triton.language as tl -import triton.language.extensions as tlx - +import triton.language.extra.tlx as tlx @triton.jit -def tlx_add_kernel( - x_ptr, # Pointer to first input - y_ptr, # Pointer to second input - out_ptr, # Pointer to output - n_elements, # Total number of elements in input/output - BLOCK_SIZE: tl.constexpr, -): - # Create barriers for synchronization - b0 = tlx.barrier_create() - b1 = tlx.barrier_create() - - phase = 0 - with tlx.async_tasks(): - with tlx.async_task("default"): - tlx.barrier_wait(bar=b1, phase=phase ^ 1) - - # Placeholder block to do something - - tlx.barrier_arrive(bar=b0) # Release - - with tlx.async_task(num_warps=4): - tlx.barrier_wait(bar=b0, phase=phase) # Wait - - # Some arith ops - block_start = tl.program_id(0) * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - z = x + y - tl.store(out_ptr + offsets, z, mask=mask) - - tlx.barrier_arrive(bar=b0) # Wait - - -def tlx_add(x: torch.Tensor, y: torch.Tensor): - """ - This function wraps the TLX kernel call. It: - 1. Ensures the inputs are contiguous on GPU. - 2. Calculates the grid (blocks) needed. - 3. Launches the TLX kernel. - """ - assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA." 
- x = x.contiguous() - y = y.contiguous() - - # Prepare output tensor - out = torch.empty_like(x) - - # Number of elements in the tensor - n_elements = x.numel() - BLOCK_SIZE = 128 # Tunable parameter for block size - - # Determine the number of blocks needed - grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],) - - # Launch the TLX kernel - tlx_add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE) - return out - +def matmul_kernel_pipelined_hopper(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_STAGES: tl.constexpr # + ): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # offset computation + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # allocate NUM_STAGES buffers + buffers_A = tlx.local_alloc((BLOCK_SIZE_M, BLOCK_SIZE_K), tlx.dtype_of(a_ptr), NUM_STAGES) + buffers_B = tlx.local_alloc((BLOCK_SIZE_K, BLOCK_SIZE_N), tlx.dtype_of(b_ptr), NUM_STAGES) + + # prefetch (pipelining) for NUM_STAGES - 1 buffers + for i in tl.range(0, NUM_STAGES - 1, loop_unroll_factor=NUM_STAGES - 1): + a = tlx.local_view(buffers_A, i) + b = tlx.local_view(buffers_B, i) + token_a = tlx.async_load(a_ptrs, a, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K) + token_b = tlx.async_load(b_ptrs, b, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + tlx.async_load_commit_group([token_a, token_b]) + + # main K loop + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Disable auto-pipelining with num_stages=0 + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=0): + # identify the buffer index for the current iteration + buf = k % NUM_STAGES + a_k = tlx.local_view(buffers_A, buf) + b_k = tlx.local_view(buffers_B, buf) + + # wait for buffers to be ready + tlx.async_load_wait_group(NUM_STAGES - 2) + + # do the mma + acc = tlx.async_dot(a_k, b_k, acc) + + # prefetch for i-th iteration, i.e, NUM_STAGES - 1 ahead + i = k + NUM_STAGES - 1 + a_next = tlx.local_view(buffers_A, i % NUM_STAGES) + b_next = tlx.local_view(buffers_B, i % NUM_STAGES) + # wait for the previous MMA using this buffer to complete + acc = tlx.async_dot_wait(1, acc) + # prefetch + token_a = tlx.async_load(a_ptrs, a_next, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K) + token_b = tlx.async_load(b_ptrs, b_next, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K) + tlx.async_load_commit_group([token_a, token_b]) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # wait for last mma to complete + acc = tlx.async_dot_wait(0, acc) + c = acc.to(tlx.dtype_of(c_ptr)) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel_pipelined_hopper[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ) + return c class ModelNew(nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, a, b): - # Instead of "return a + b", call our TLX-based addition - return tlx_add(a, b) + def __init__(self): + super(ModelNew, self).__init__() + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return matmul(a, b) From 1769be8a6390a58c8a35c52289946ecee481fc87 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 02:18:29 -0800 Subject: [PATCH 04/23] run and check working with matmul example --- scripts/run_and_check.py | 12 +++++++----- src/kernelbench/prompts/model_ex_matmul_tlx.py | 15 +++++++++++++++ src/kernelbench/prompts/model_new_ex_add_tlx.py | 3 +++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 08c2906d..d90d5a08 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -36,15 +36,17 @@ KERNELBENCH_DIR = os.path.join(REPO_TOP_PATH, "KernelBench") image = ( - modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") .uv_sync(uv_project_dir=REPO_TOP_PATH) .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + # Uninstall standard triton first (fast step, separate layer to avoid rebuilding triton on changes) + .run_commands("pip uninstall -y triton || true") + # Install TLX-enabled Triton (slow step, cached unless repo changes) + .env({"MAX_JOBS": "8"}) # Speed up compilation .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton && " - "cd /root/triton && " - "pip install -r python/requirements.txt && " - "pip install -e ." + "git clone --depth 1 https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt && pip install -e ." 
) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", diff --git a/src/kernelbench/prompts/model_ex_matmul_tlx.py b/src/kernelbench/prompts/model_ex_matmul_tlx.py index d217d1e9..cc1db179 100644 --- a/src/kernelbench/prompts/model_ex_matmul_tlx.py +++ b/src/kernelbench/prompts/model_ex_matmul_tlx.py @@ -7,3 +7,18 @@ def __init__(self): def forward(self, a, b): return torch.matmul(a, b) + +def get_inputs(): + # randomly generate input tensors for a matmul operation + # Using sizes compatible with the TLX kernel logic (e.g. divisible by block sizes ideally, though the kernel handles remainders) + # The kernel has BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, BLOCK_SIZE_K=64 + # Let's use standard sizes. + M = 4096 + N = 4096 + K = 4096 + a = torch.randn(M, K).cuda().to(torch.float16) + b = torch.randn(K, N).cuda().to(torch.float16) + return [a, b] + +def get_init_inputs(): + return [] diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_add_tlx.py index 6be9e75e..f9c84f2d 100644 --- a/src/kernelbench/prompts/model_new_ex_add_tlx.py +++ b/src/kernelbench/prompts/model_new_ex_add_tlx.py @@ -98,6 +98,9 @@ def matmul(a, b): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # + BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, + BLOCK_SIZE_K=64, GROUP_SIZE_M=8, + NUM_STAGES=3 ) return c From e8668dc074a0050c8bf11e62416804d2118aa3fe Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 02:20:54 -0800 Subject: [PATCH 05/23] renamed model tlx example for clarity --- .../{model_new_ex_add_tlx.py => model_new_ex_matmul_tlx.py} | 0 src/kernelbench/prompts/prompts.toml | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) rename src/kernelbench/prompts/{model_new_ex_add_tlx.py => model_new_ex_matmul_tlx.py} (100%) diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_matmul_tlx.py similarity index 100% rename from src/kernelbench/prompts/model_new_ex_add_tlx.py rename to src/kernelbench/prompts/model_new_ex_matmul_tlx.py diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml index c2fcb34e..1375cb42 100644 --- a/src/kernelbench/prompts/prompts.toml +++ b/src/kernelbench/prompts/prompts.toml @@ -56,7 +56,8 @@ one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py" [backends.tlx] backend_display = "TLX (Triton Language Extensions) kernels" -one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_tlx.py" +one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_matmul_tlx.py" +few_shot_example_arch = "src/kernelbench/prompts/model_ex_matmul_tlx.py" # No few_shot_examples - will use one-shot when few_shot option is selected # ------------------------------------------------------------------------- From 3e80052a1006f92c7cea8077ca5ab075cb2e3e15 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 02:24:21 -0800 Subject: [PATCH 06/23] working for single sample modal but only for precision=fp16 --- scripts/generate_and_eval_single_sample_modal.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index a7b2107d..bfddd84c 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -114,7 +114,7 @@ def __repr__(self): SRC_DIR = os.path.join(REPO_TOP_DIR, "src") image = ( - modal.Image.from_registry(f"nvidia/cuda:{tag}", 
add_python="3.10") + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") .apt_install("git", "gcc-10", "g++-10", @@ -123,18 +123,19 @@ def __repr__(self): "ninja-build", "zlib1g-dev" ) - .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"]) .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + # Uninstall standard triton first (fast step, separate layer to avoid rebuilding triton on changes) + .run_commands("pip uninstall -y triton || true") + # Install TLX-enabled Triton (slow step, cached unless repo changes) + .env({"MAX_JOBS": "8"}) # Speed up compilation .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton && " - "cd /root/triton && " - "pip install -r python/requirements.txt && " - "pip install -e ." + "git clone --depth 1 https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt && pip install -e ." ) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", - "PYTHONPATH": "/root:/root/src:/root/triton/python" + "PYTHONPATH": "/root:/root/src:/root/scripts:/root/triton/python" }) .add_local_dir(SRC_DIR, remote_path="/root/src") # must be last ) From 3a66b711d7798d56c29ee20385f69661763a9289 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 16:40:05 -0800 Subject: [PATCH 07/23] update to static checker: now only needs to have tlx.async somewhere in there --- src/kernelbench/kernel_static_checker.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py index 391fade6..691f5c7e 100644 --- a/src/kernelbench/kernel_static_checker.py +++ b/src/kernelbench/kernel_static_checker.py @@ -210,33 +210,21 @@ def check_triton_impl(code: str) -> Tuple[bool, str]: # <========= TLX (Triton Language Extensions) CHECKS =========> # Rationale: TLX extends Triton with async tasks and barriers for specialization. -# Valid TLX code must use tlx.* operations for async tasks and barriers. -TLX_PATTERNS = [ - r"tlx\.async_tasks\s*\(", # tlx.async_tasks() - r"tlx\.async_task\s*\(", # tlx.async_task() - r"tlx\.barrier_wait\s*\(", # tlx.barrier_wait() - r"tlx\.barrier_arrive\s*\(", # tlx.barrier_arrive() - r"tlx\.barrier_create\s*\(", # tlx.barrier_create() - r"import\s+tlx", # import tlx - r"from\s+tlx", # from tlx - r"triton\.language\.extensions", # triton.language.extensions -] +# Valid TLX code MUST use "tlx.async" - this is the key requirement for TLX kernels. +TLX_ASYNC_PATTERN = r"tlx\.async" def check_tlx_impl(code: str) -> Tuple[bool, str]: """ Check for valid TLX (Triton Language Extensions) kernel implementation. Requirements: - - Must have @triton.jit or @triton.autotune decorator (inherited from Triton) - - Must have tlx.* operations (async_tasks, async_task, barrier_wait, barrier_arrive, etc.) + - MUST contain "tlx.async" anywhere in the code (e.g., tlx.async_tasks, tlx.async_task) - Note: TLX extends Triton, so it should also have triton.jit decorator. + This is a reward hack check: if "tlx.async" appears, the kernel is valid. 
""" code = _strip_comments(code) - if not re.search(TRITON_JIT_PATTERN, code): - return (True, "Missing @triton.jit or @triton.autotune (TLX extends Triton)") - if not any(re.search(p, code) for p in TLX_PATTERNS): - return (True, "Missing TLX operations (tlx.async_tasks, tlx.barrier_*, etc.)") + if "tlx.async" not in code: + return (True, "Missing 'tlx.async' - TLX kernels must use tlx.async_tasks or tlx.async_task") return (False, "") From bea25901cc95c2afdaee9dd401c9e487205433f5 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 16:45:41 -0800 Subject: [PATCH 08/23] fixed run and check python version --- scripts/run_and_check.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index d90d5a08..153324b2 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -36,7 +36,7 @@ KERNELBENCH_DIR = os.path.join(REPO_TOP_PATH, "KernelBench") image = ( - modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") .uv_sync(uv_project_dir=REPO_TOP_PATH) .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") @@ -84,38 +84,6 @@ 4. PyTorch reference is a kernelbench problem (modal eval on cloud GPU) python3 scripts/run_and_check.py ref_origin=kernelbench level= problem_id= kernel_src_path= eval_mode=modal gpu=L40S - -TLX Examples: -uv run python scripts/run_and_check.py \ - ref_origin=kernelbench \ - level=1 \ - problem_id=1 \ - kernel_src_path=runs/valid_tlx/gemm_pc.py \ - eval_mode=modal \ - gpu=H100 \ - backend=tlx \ - precision=bf16 \ - verbose=True - -uv run python scripts/run_and_check.py \ - ref_origin=local \ - ref_arch_src_path=runs/valid_tlx/fftconv_reference.py \ - kernel_src_path=runs/valid_tlx/fftconv_pc.py \ - eval_mode=modal \ - gpu=H100 \ - backend=tlx \ - precision=bf16 \ - verbose=True - -uv run python scripts/run_and_check.py \ - ref_origin=local \ - ref_arch_src_path=runs/valid_tlx/layernorm_reference.py \ - kernel_src_path=runs/valid_tlx/layernorm_nonpc.py \ - eval_mode=modal \ - gpu=H100 \ - backend=tlx \ - precision=bf16 \ - verbose=True ==================================================== """ From 9e4af0a7f0e73817e37ec35558ce184283f3f37f Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 16:47:13 -0800 Subject: [PATCH 09/23] only a one shot new arch for tlx prompts.toml --- src/kernelbench/prompts/prompts.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml index 1375cb42..4dac7bbe 100644 --- a/src/kernelbench/prompts/prompts.toml +++ b/src/kernelbench/prompts/prompts.toml @@ -57,7 +57,6 @@ one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py" [backends.tlx] backend_display = "TLX (Triton Language Extensions) kernels" one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_matmul_tlx.py" -few_shot_example_arch = "src/kernelbench/prompts/model_ex_matmul_tlx.py" # No few_shot_examples - will use one-shot when few_shot option is selected # ------------------------------------------------------------------------- From 24da34e2211213d1d9c31a371c16aaf315c180de Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 16:55:35 -0800 Subject: [PATCH 10/23] static checker: cannot have triton autotune --- 
src/kernelbench/kernel_static_checker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py index 691f5c7e..6c4442c8 100644 --- a/src/kernelbench/kernel_static_checker.py +++ b/src/kernelbench/kernel_static_checker.py @@ -212,6 +212,7 @@ def check_triton_impl(code: str) -> Tuple[bool, str]: # Rationale: TLX extends Triton with async tasks and barriers for specialization. # Valid TLX code MUST use "tlx.async" - this is the key requirement for TLX kernels. TLX_ASYNC_PATTERN = r"tlx\.async" +TLX_FORBIDDEN_PATTERN = r"@triton\.autotune" def check_tlx_impl(code: str) -> Tuple[bool, str]: """ @@ -219,12 +220,13 @@ def check_tlx_impl(code: str) -> Tuple[bool, str]: Requirements: - MUST contain "tlx.async" anywhere in the code (e.g., tlx.async_tasks, tlx.async_task) - - This is a reward hack check: if "tlx.async" appears, the kernel is valid. + - MUST NOT contain "@triton.autotune" """ code = _strip_comments(code) - if "tlx.async" not in code: + if not re.search(TLX_ASYNC_PATTERN, code): return (True, "Missing 'tlx.async' - TLX kernels must use tlx.async_tasks or tlx.async_task") + if re.search(TLX_FORBIDDEN_PATTERN, code): + return (True, "TLX kernels cannot use @triton.autotune") return (False, "") From dd507e1f046e07f19d17ecffbedca18c7333539c Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:00:17 -0800 Subject: [PATCH 11/23] more standardized TLX vec add example --- .../prompts/model_ex_matmul_tlx.py | 24 ---- .../prompts/model_new_ex_add_tlx.py | 41 +++++++ .../prompts/model_new_ex_matmul_tlx.py | 112 ------------------ 3 files changed, 41 insertions(+), 136 deletions(-) delete mode 100644 src/kernelbench/prompts/model_ex_matmul_tlx.py create mode 100644 src/kernelbench/prompts/model_new_ex_add_tlx.py delete mode 100644 src/kernelbench/prompts/model_new_ex_matmul_tlx.py diff --git a/src/kernelbench/prompts/model_ex_matmul_tlx.py b/src/kernelbench/prompts/model_ex_matmul_tlx.py deleted file mode 100644 index cc1db179..00000000 --- a/src/kernelbench/prompts/model_ex_matmul_tlx.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -import torch.nn as nn - -class Model(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, b): - return torch.matmul(a, b) - -def get_inputs(): - # randomly generate input tensors for a matmul operation - # Using sizes compatible with the TLX kernel logic (e.g. divisible by block sizes ideally, though the kernel handles remainders) - # The kernel has BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, BLOCK_SIZE_K=64 - # Let's use standard sizes. 
- M = 4096 - N = 4096 - K = 4096 - a = torch.randn(M, K).cuda().to(torch.float16) - b = torch.randn(K, N).cuda().to(torch.float16) - return [a, b] - -def get_init_inputs(): - return [] diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_add_tlx.py new file mode 100644 index 00000000..ef0ed8dd --- /dev/null +++ b/src/kernelbench/prompts/model_new_ex_add_tlx.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import triton +import triton.language as tl +import triton.language.extra.tlx as tlx + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + + # Wrap in async_tasks to demonstrate TLX syntax and satisfy static checker + with tlx.async_tasks(): + with tlx.async_task("default"): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.is_cuda and y.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + +class ModelNew(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b): + return add(a, b) diff --git a/src/kernelbench/prompts/model_new_ex_matmul_tlx.py b/src/kernelbench/prompts/model_new_ex_matmul_tlx.py deleted file mode 100644 index f9c84f2d..00000000 --- a/src/kernelbench/prompts/model_new_ex_matmul_tlx.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -import torch.nn as nn -import triton -import triton.language as tl -import triton.language.extra.tlx as tlx - -@triton.jit -def matmul_kernel_pipelined_hopper(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, # - stride_bk, stride_bn, # - stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, # - NUM_STAGES: tl.constexpr # - ): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # offset computation - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # allocate NUM_STAGES buffers - buffers_A = tlx.local_alloc((BLOCK_SIZE_M, BLOCK_SIZE_K), tlx.dtype_of(a_ptr), NUM_STAGES) - buffers_B = tlx.local_alloc((BLOCK_SIZE_K, BLOCK_SIZE_N), tlx.dtype_of(b_ptr), NUM_STAGES) - - # prefetch (pipelining) for NUM_STAGES - 1 buffers - for i in tl.range(0, NUM_STAGES - 1, loop_unroll_factor=NUM_STAGES - 1): - a = tlx.local_view(buffers_A, i) - b = tlx.local_view(buffers_B, i) - token_a = tlx.async_load(a_ptrs, a, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K) - token_b = tlx.async_load(b_ptrs, b, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K) - a_ptrs += 
BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - tlx.async_load_commit_group([token_a, token_b]) - - # main K loop - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Disable auto-pipelining with num_stages=0 - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=0): - # identify the buffer index for the current iteration - buf = k % NUM_STAGES - a_k = tlx.local_view(buffers_A, buf) - b_k = tlx.local_view(buffers_B, buf) - - # wait for buffers to be ready - tlx.async_load_wait_group(NUM_STAGES - 2) - - # do the mma - acc = tlx.async_dot(a_k, b_k, acc) - - # prefetch for i-th iteration, i.e, NUM_STAGES - 1 ahead - i = k + NUM_STAGES - 1 - a_next = tlx.local_view(buffers_A, i % NUM_STAGES) - b_next = tlx.local_view(buffers_B, i % NUM_STAGES) - # wait for the previous MMA using this buffer to complete - acc = tlx.async_dot_wait(1, acc) - # prefetch - token_a = tlx.async_load(a_ptrs, a_next, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K) - token_b = tlx.async_load(b_ptrs, b_next, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K) - tlx.async_load_commit_group([token_a, token_b]) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - # wait for last mma to complete - acc = tlx.async_dot_wait(0, acc) - c = acc.to(tlx.dtype_of(c_ptr)) - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -def matmul(a, b): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - M, K = a.shape - K, N = b.shape - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=torch.float16) - # 1D launch kernel where each block gets its own program. 
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - matmul_kernel_pipelined_hopper[grid]( - a, b, c, # - M, N, K, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c.stride(0), c.stride(1), # - BLOCK_SIZE_M=128, BLOCK_SIZE_N=256, - BLOCK_SIZE_K=64, GROUP_SIZE_M=8, - NUM_STAGES=3 - ) - return c - -class ModelNew(nn.Module): - def __init__(self): - super(ModelNew, self).__init__() - - def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - return matmul(a, b) From f3659c01396ee52d565448e92ee23aabe68a0895 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:00:57 -0800 Subject: [PATCH 12/23] changes to prompts.toml for vec add example --- src/kernelbench/prompts/prompts.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml index 4dac7bbe..c2fcb34e 100644 --- a/src/kernelbench/prompts/prompts.toml +++ b/src/kernelbench/prompts/prompts.toml @@ -56,7 +56,7 @@ one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py" [backends.tlx] backend_display = "TLX (Triton Language Extensions) kernels" -one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_matmul_tlx.py" +one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_tlx.py" # No few_shot_examples - will use one-shot when few_shot option is selected # ------------------------------------------------------------------------- From 6808e645507cb33fa3f4447c3593df91120c36f2 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:17:54 -0800 Subject: [PATCH 13/23] FIXED tlx vec_add example --- .../prompts/model_new_ex_add_tlx.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/kernelbench/prompts/model_new_ex_add_tlx.py b/src/kernelbench/prompts/model_new_ex_add_tlx.py index ef0ed8dd..e8495936 100644 --- a/src/kernelbench/prompts/model_new_ex_add_tlx.py +++ b/src/kernelbench/prompts/model_new_ex_add_tlx.py @@ -5,37 +5,46 @@ import triton.language.extra.tlx as tlx @triton.jit -def add_kernel( +def add_warp_specialized_kernel( x_ptr, y_ptr, - out_ptr, + z_ptr, n_elements, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE - # Wrap in async_tasks to demonstrate TLX syntax and satisfy static checker + # Split the block work across two async tasks to demonstrate specialization + # Task 1: Process first half of the block with tlx.async_tasks(): with tlx.async_task("default"): - offsets = block_start + tl.arange(0, BLOCK_SIZE) + offsets = block_start + tl.arange(0, BLOCK_SIZE // 2) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(out_ptr + offsets, output, mask=mask) + tl.store(z_ptr + offsets, x + y, mask=mask) + + # Task 2: Process second half of the block with 4 warps + with tlx.async_task(num_warps=4): + offsets = block_start + tl.arange(BLOCK_SIZE // 2, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + tl.store(z_ptr + offsets, x + y, mask=mask) -def add(x: torch.Tensor, y: torch.Tensor): +def add_warp_specialized(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda + assert x.is_cuda and y.is_cuda and output.is_cuda n_elements = output.numel() + # Ensure BLOCK_SIZE is even for the split grid = lambda meta: (triton.cdiv(n_elements, 
meta["BLOCK_SIZE"]), ) - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_warp_specialized_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) return output class ModelNew(nn.Module): def __init__(self): - super().__init__() + super(ModelNew, self).__init__() - def forward(self, a, b): - return add(a, b) + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return add_warp_specialized(a, b) From 75773db49d4d3f107e0c98f0cf881be32a5bdc48 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:26:46 -0800 Subject: [PATCH 14/23] testing generate and eval on modal script --- .../generate_and_eval_single_sample_modal.py | 33 ++++--------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index bfddd84c..68115c07 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -2,23 +2,6 @@ Example Usage: python scripts/generate_and_eval_single_sample_modal.py dataset_src=huggingfac level=1 problem_id=1 eval_mode=modal gpu=L40S server_type=deepseek model_name=deepseek-coder max_tokens=4096 temperature=0.0 - -TLX Example: -uv run python scripts/generate_and_eval_single_sample_modal.py \ - dataset_src=huggingface \ - level=1 \ - problem_id=1 \ - eval_mode=modal \ - gpu=H100 \ - backend=tlx \ - server_type=google \ - model_name=gemini/gemini-2.5-flash \ - max_tokens=60000 \ - temperature=0.0 \ - log=True \ - log_prompt=True \ - log_generated_kernel=True \ - verbose=True ''' import pydra @@ -114,15 +97,8 @@ def __repr__(self): SRC_DIR = os.path.join(REPO_TOP_DIR, "src") image = ( - modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11") - .apt_install("git", - "gcc-10", - "g++-10", - "clang", - "cmake", - "ninja-build", - "zlib1g-dev" - ) + modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") + .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"]) .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") # Uninstall standard triton first (fast step, separate layer to avoid rebuilding triton on changes) @@ -250,6 +226,11 @@ def main(config: EvalConfig): if backend == "thunderkittens": config.precision = "bf16" config.gpu = "H100" + + # TLX: for research purposes we only support fp16 + if backend == "tlx": + config.precision = "fp16" + config.gpu = "H100" if not custom_prompt_key: if prompt_option not in valid_prompt_options: From 3cdd9e32ab7161d9fdd7b29b462ba36c81f22786 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:35:31 -0800 Subject: [PATCH 15/23] instructions for running locally --- scripts/generate_and_eval_single_sample.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 22514120..ee9a1391 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -21,19 +21,14 @@ Example usage: python3 scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 eval_mode=local server_type=google model_name=gemini/gemini-2.5-flash max_tokens=8192 temperature=0.0 -TLX Example (NEED LOCAL GPU): -uv run python scripts/generate_and_eval_single_sample.py \ - dataset_src=huggingface \ - level=1 \ - problem_id=1 \ - 
backend=tlx \ - server_type=google \ - model_name=gemini/gemini-2.5-flash \ - max_tokens=60000 \ - temperature=0.0 \ - log=True \ - log_generated_kernel=True \ - verbose=True +(YOU NEED A LOCAL H100 TO RUN THIS SCRIPT) +- Because this is research code, you need to make the following changes: +- Make sure you have the following apt-installed: ("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") +- Make sure you run the following commands: +- Uninstall triton: ("pip uninstall -y triton || true") +- Reinstall triton w/ TLX: "git clone --depth 1 https://github.com/facebookexperimental/triton.git /root/triton" +- Reinstall triton w/ TLX: "cd /root/triton && pip install -r python/requirements.txt && pip install -e ." +- Update your PYTHONPATH environment variable: "PYTHONPATH": "/root:/root/src:/root/scripts:/root/triton/python" """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) From 0e95cb6083c526c88a431ada8974e7d0d40cffbf Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:54:09 -0800 Subject: [PATCH 16/23] removed comment --- scripts/generate_samples.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 5096de20..874d2213 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -23,18 +23,6 @@ Assume 1 sample per problem here -TLX Example: -uv run python scripts/generate_samples.py \ - dataset_src=huggingface \ - level=1 \ - subset="(1,5)" \ - run_name=test_tlx_level1 \ - backend=tlx \ - num_samples=1 \ - server_type=google \ - model_name=gemini/gemini-2.5-flash \ - max_tokens=8192 \ - temperature=0.0 """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) From 33fda21c8a80163d15b4f617795023fbc5bc5a21 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 17:54:42 -0800 Subject: [PATCH 17/23] and whitespace --- scripts/generate_samples.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 874d2213..fe4c955d 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -22,7 +22,6 @@ Batch Generate Samples for Particular Level Assume 1 sample per problem here - """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) From d4f64b304bf279d02f0dfe65c9863f525bd73361 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 18:24:33 -0800 Subject: [PATCH 18/23] added work.problem_number to resolve print error --- scripts/generate_samples.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index fe4c955d..2f086a65 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -153,13 +153,13 @@ def generate_sample_single( # uses the default set of forbidden and warning patterns, # you could adapt the patterns to your own setting (degree of banning cuda stream, allowing some torch ops) ) - assert static_check_status, f"Static check failed for sample {work.sample_id} for problem {problem_number}: {problem_name}. Error: {error}. Warnings: {warnings}" + assert static_check_status, f"Static check failed for sample {work.sample_id} for problem {work.problem_number}: {problem_name}. Error: {error}. Warnings: {warnings}" if warnings: - print(f"Static check warnings for sample {work.sample_id} for problem {problem_number}: {problem_name}. 
Warnings: {warnings}") + print(f"Static check warnings for sample {work.sample_id} for problem {work.problem_number}: {problem_name}. Warnings: {warnings}") if config.verbose: print( - f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}" + f"Generated sample {work.sample_id} for problem {work.problem_number}: {problem_name}" ) # Store to local file From 15c0f3a3c97238687b663a640d62fd2586655bb2 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 18:29:43 -0800 Subject: [PATCH 19/23] fixed modal image errors in eval from generations script --- scripts/eval_from_generations.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index bedbdb9c..d94223fd 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -74,22 +74,16 @@ image = ( modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") - .apt_install("git", - "gcc-10", - "g++-10", - "clang", - "cmake", - "ninja-build", - "zlib1g-dev" - ) - + .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev" ) .uv_sync(uv_project_dir=REPO_TOP_DIR) .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + # Uninstall standard triton first (fast step, separate layer to avoid rebuilding triton on changes) + .run_commands("pip uninstall -y triton || true") + # Install TLX-enabled Triton (slow step, cached unless repo changes) + .env({"MAX_JOBS": "8"}) # Speed up compilation .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton && " - "cd /root/triton && " - "pip install -r python/requirements.txt && " - "pip install -e ." + "git clone --depth 1 https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt && pip install -e ." 
) .env({ "THUNDERKITTENS_ROOT": "/root/ThunderKittens", From feadd5d3d471579ea11851bcd1ff5232871050b7 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 18:30:26 -0800 Subject: [PATCH 20/23] removed tlx comment --- scripts/eval_from_generations.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index d94223fd..90a870b0 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -43,16 +43,6 @@ - performance (n_trials): 100 randomized input trials You can increase the number of trials for correctness and performance - -TLX Example: -uv run python scripts/eval_from_generations.py \ - run_name=test_tlx_level1 \ - dataset_src=huggingface \ - level=1 \ - subset="(5,6)" \ - eval_mode=modal \ - backend=tlx \ - verbose=True """ REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) From fca474c0945293652cc59f81fa5350ac97ca894b Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 19:07:09 -0800 Subject: [PATCH 21/23] updated modal image for baseline time --- scripts/generate_baseline_time_modal.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py index df0c183c..0d085d9a 100644 --- a/scripts/generate_baseline_time_modal.py +++ b/scripts/generate_baseline_time_modal.py @@ -87,22 +87,18 @@ def __init__(self): image = ( modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") - .apt_install("git", - "gcc-10", - "g++-10", - "clang", - "cmake", - "ninja-build", - "zlib1g-dev" - ) + .apt_install("git", "gcc-10", "g++-10", "clang", "cmake", "ninja-build", "zlib1g-dev") .uv_sync(uv_project_dir=REPO_TOP_PATH, extras=["gpu"]) + .run_commands("git clone https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + # Uninstall standard triton first (fast step, separate layer to avoid rebuilding triton on changes) + .run_commands("pip uninstall -y triton || true") + # Install TLX-enabled Triton (slow step, cached unless repo changes) + .env({"MAX_JOBS": "8"}) # Speed up compilation .run_commands( - "git clone https://github.com/facebookexperimental/triton.git /root/triton && " - "cd /root/triton && " - "pip install -r python/requirements.txt && " - "pip install -e ." + "git clone --depth 1 https://github.com/facebookexperimental/triton.git /root/triton", + "cd /root/triton && pip install -r python/requirements.txt && pip install -e ." ) - .env({"PYTHONPATH": "/root/src:/root/triton/python"}) + .env({"PYTHONPATH": "/root:/root/src:/root/scripts:/root/triton/python"}) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last ) From 1f9558bfdb73954efeb1c350ce08f3c4c85a4a64 Mon Sep 17 00:00:00 2001 From: Willy Chan Date: Mon, 19 Jan 2026 21:50:38 -0800 Subject: [PATCH 22/23] static checker changes --- src/kernelbench/kernel_static_checker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py index 6c4442c8..6314311e 100644 --- a/src/kernelbench/kernel_static_checker.py +++ b/src/kernelbench/kernel_static_checker.py @@ -210,7 +210,7 @@ def check_triton_impl(code: str) -> Tuple[bool, str]: # <========= TLX (Triton Language Extensions) CHECKS =========> # Rationale: TLX extends Triton with async tasks and barriers for specialization. 
-# Valid TLX code MUST use "tlx.async" - this is the key requirement for TLX kernels. +# Valid TLX code MUST use "tlx.async" primitives. TLX_ASYNC_PATTERN = r"tlx\.async" TLX_FORBIDDEN_PATTERN = r"@triton\.autotune" @@ -219,12 +219,12 @@ def check_tlx_impl(code: str) -> Tuple[bool, str]: Check for valid TLX (Triton Language Extensions) kernel implementation. Requirements: - - MUST contain "tlx.async" anywhere in the code (e.g., tlx.async_tasks, tlx.async_task) + - MUST contain "tlx.async" primitives (e.g., tlx.async_tasks, tlx.async_load, tlx.async_dot). - MUST NOT contain "@triton.autotune" """ code = _strip_comments(code) if not re.search(TLX_ASYNC_PATTERN, code): - return (True, "Missing 'tlx.async' - TLX kernels must use tlx.async_tasks or tlx.async_task") + return (True, "Missing 'tlx.async' - TLX kernels must use tlx.async primitives") if re.search(TLX_FORBIDDEN_PATTERN, code): return (True, "TLX kernels cannot use @triton.autotune") return (False, "") From c305cc5dd5a1ef8aa361dcce780688abb2aaf8aa Mon Sep 17 00:00:00 2001 From: Nathan Paek Date: Tue, 20 Jan 2026 07:12:02 +0000 Subject: [PATCH 23/23] saving level 9 problems for reference --- KernelBench/level9/1d_occupancy_decoder.py | 164 +++++++++ KernelBench/level9/abc_reference.py | 261 ++++++++++++++ KernelBench/level9/based_reference.py | 115 ++++++ KernelBench/level9/comba_reference.py | 309 ++++++++++++++++ KernelBench/level9/deltaformer_reference.py | 284 +++++++++++++++ KernelBench/level9/fd.py | 134 +++++++ KernelBench/level9/feature_map_reference.py | 55 +++ KernelBench/level9/fox_reference.py | 154 ++++++++ .../level9/fused_bitlinear_reference.py | 72 ++++ .../level9/fused_cross_entropy_reference.py | 79 ++++ KernelBench/level9/fused_kl_div_reference.py | 53 +++ .../fused_linear_cross_entropy_reference.py | 41 +++ .../level9/fused_norm_gate_reference.py | 64 ++++ .../level9/fused_rms_norm_silu_reference.py | 41 +++ .../level9/gated_deltanet_reference.py | 184 ++++++++++ .../level9/gated_deltaproduct_reference.py | 288 +++++++++++++++ KernelBench/level9/gla_reference.py | 149 ++++++++ KernelBench/level9/grpo_reference.py | 61 ++++ KernelBench/level9/gsa_reference.py | 129 +++++++ KernelBench/level9/hgrn2_reference.py | 120 +++++++ KernelBench/level9/hgrn_reference.py | 121 +++++++ KernelBench/level9/kayvon_rl_kernel.py | 50 +++ KernelBench/level9/kda_reference.py | 179 ++++++++++ KernelBench/level9/l2_norm_reference.py | 35 ++ KernelBench/level9/l2_wrap_reference.py | 56 +++ .../level9/layernorm_gated_reference.py | 53 +++ KernelBench/level9/layernorm_reference.py | 60 ++++ KernelBench/level9/lightnet_reference.py | 136 +++++++ .../level9/log_linear_attn_reference.py | 107 ++++++ KernelBench/level9/mamba2_reference.py | 207 +++++++++++ KernelBench/level9/mesa_net_reference.py | 337 ++++++++++++++++++ KernelBench/level9/mom_reference.py | 312 ++++++++++++++++ KernelBench/level9/nsa_reference.py | 147 ++++++++ .../parallel_forgetting_attn_reference.py | 85 +++++ KernelBench/level9/path_attn_reference.py | 310 ++++++++++++++++ KernelBench/level9/rebased_reference.py | 100 ++++++ KernelBench/level9/retnet_reference.py | 107 ++++++ KernelBench/level9/rk4.py | 172 +++++++++ KernelBench/level9/rodimus_reference.py | 259 ++++++++++++++ KernelBench/level9/rotary_reference.py | 61 ++++ KernelBench/level9/rwkv6_reference.py | 237 ++++++++++++ KernelBench/level9/rwkv7_reference.py | 264 ++++++++++++++ KernelBench/level9/samba_reference.py | 263 ++++++++++++++ KernelBench/level9/short_conv_reference.py | 48 +++ 
KernelBench/level9/token_shift_reference.py | 37 ++ KernelBench/level9/trimul_reference.py | 79 ++++ 46 files changed, 6579 insertions(+) create mode 100644 KernelBench/level9/1d_occupancy_decoder.py create mode 100644 KernelBench/level9/abc_reference.py create mode 100644 KernelBench/level9/based_reference.py create mode 100644 KernelBench/level9/comba_reference.py create mode 100644 KernelBench/level9/deltaformer_reference.py create mode 100644 KernelBench/level9/fd.py create mode 100644 KernelBench/level9/feature_map_reference.py create mode 100644 KernelBench/level9/fox_reference.py create mode 100644 KernelBench/level9/fused_bitlinear_reference.py create mode 100644 KernelBench/level9/fused_cross_entropy_reference.py create mode 100644 KernelBench/level9/fused_kl_div_reference.py create mode 100644 KernelBench/level9/fused_linear_cross_entropy_reference.py create mode 100644 KernelBench/level9/fused_norm_gate_reference.py create mode 100644 KernelBench/level9/fused_rms_norm_silu_reference.py create mode 100644 KernelBench/level9/gated_deltanet_reference.py create mode 100644 KernelBench/level9/gated_deltaproduct_reference.py create mode 100644 KernelBench/level9/gla_reference.py create mode 100644 KernelBench/level9/grpo_reference.py create mode 100644 KernelBench/level9/gsa_reference.py create mode 100644 KernelBench/level9/hgrn2_reference.py create mode 100644 KernelBench/level9/hgrn_reference.py create mode 100644 KernelBench/level9/kayvon_rl_kernel.py create mode 100644 KernelBench/level9/kda_reference.py create mode 100644 KernelBench/level9/l2_norm_reference.py create mode 100644 KernelBench/level9/l2_wrap_reference.py create mode 100644 KernelBench/level9/layernorm_gated_reference.py create mode 100644 KernelBench/level9/layernorm_reference.py create mode 100644 KernelBench/level9/lightnet_reference.py create mode 100644 KernelBench/level9/log_linear_attn_reference.py create mode 100644 KernelBench/level9/mamba2_reference.py create mode 100644 KernelBench/level9/mesa_net_reference.py create mode 100644 KernelBench/level9/mom_reference.py create mode 100644 KernelBench/level9/nsa_reference.py create mode 100644 KernelBench/level9/parallel_forgetting_attn_reference.py create mode 100644 KernelBench/level9/path_attn_reference.py create mode 100644 KernelBench/level9/rebased_reference.py create mode 100644 KernelBench/level9/retnet_reference.py create mode 100644 KernelBench/level9/rk4.py create mode 100644 KernelBench/level9/rodimus_reference.py create mode 100644 KernelBench/level9/rotary_reference.py create mode 100644 KernelBench/level9/rwkv6_reference.py create mode 100644 KernelBench/level9/rwkv7_reference.py create mode 100644 KernelBench/level9/samba_reference.py create mode 100644 KernelBench/level9/short_conv_reference.py create mode 100644 KernelBench/level9/token_shift_reference.py create mode 100644 KernelBench/level9/trimul_reference.py diff --git a/KernelBench/level9/1d_occupancy_decoder.py b/KernelBench/level9/1d_occupancy_decoder.py new file mode 100644 index 00000000..9adeb28e --- /dev/null +++ b/KernelBench/level9/1d_occupancy_decoder.py @@ -0,0 +1,164 @@ +import math +import torch +import torch.nn as nn + + +def init_linear(module, embed_dim: int): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=math.sqrt(1.0 / embed_dim)) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + + +class MLPEmbedder(nn.Module): + """MLP with SiLU activation for query embedding.""" + + def __init__(self, in_dim: int, embed_dim: int, bias: bool 
= True): + super().__init__() + self.in_layer = nn.Linear(in_dim, embed_dim, bias=bias) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(embed_dim, embed_dim, bias=bias) + self.apply(lambda m: init_linear(m, embed_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class LayerNorm(nn.LayerNorm): + def forward(self, input: torch.Tensor): + y = super().forward(input.float()) + return y.type_as(input) + + +class CrossAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + q_dim=None, + kv_dim=None, + bias: bool = True, + ): + super().__init__() + assert embed_dim % num_heads == 0 + + q_dim = q_dim or embed_dim + kv_dim = kv_dim or embed_dim + + self.c_q = nn.Linear(q_dim, embed_dim, bias=bias) + self.c_k = nn.Linear(kv_dim, embed_dim, bias=bias) + self.c_v = nn.Linear(kv_dim, embed_dim, bias=bias) + self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.num_heads = num_heads + + def forward(self, x, c, attn_mask=None, is_causal: bool = False): + q, k = self.c_q(x), self.c_k(c) + v = self.c_v(c) + + b, l, d = q.shape + s = k.shape[1] + + q = q.view(b, l, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs) + k = k.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs) + v = v.view(b, s, self.num_heads, -1).transpose(1, 2) # (B, nh, T, hs) + + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=0.0, + is_causal=(attn_mask is not None) and is_causal, + ) + + y = y.transpose(1, 2).contiguous().view(b, l, d) + y = self.c_proj(y) + return y + + +class OneDOccupancyDecoder(nn.Module): + """ + Simplified 1DOccupancyDecoder forward pass. + - 250k queries attending to 1k KV tokens + - MLP with SiLU activation for query projection + - Cross-attention with LayerNorm + - Output projection + + Args: + q_in_dim: Input dimension for queries + width: The width of the intermediate layers. + num_heads: The number of attention heads for the cross-attention layer. + out_features: Output dimension + eps: Epsilon for layer normalization + """ + + def __init__( + self, + q_in_dim: int = 3, + width: int = 768, + num_heads: int = 12, + out_features: int = 1, + eps: float = 1e-6, + ): + super().__init__() + + self.query_in = MLPEmbedder(q_in_dim, width) + self.attn = CrossAttention( + embed_dim=width, + num_heads=num_heads, + bias=True, + ) + self.ln = LayerNorm(width, elementwise_affine=False, eps=eps) + self.out_proj = nn.Linear(width, out_features) + + def forward(self, queries: torch.Tensor, latents: torch.Tensor): + """ + Forward pass. 
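+
+        In short: the queries are embedded by the MLP, cross-attend to the
+        latents, and the attended features are layer-normalized and projected
+        down to out_features.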
+ + Args: + queries: Input queries of shape [batch_size, num_queries, q_in_dim] + latents: Input latents of shape [batch_size, num_latents, width] + + Returns: + Output tensor of shape [batch_size, num_queries, out_features] + """ + q = self.query_in(queries) + x = self.attn(q, latents) + x = self.out_proj(self.ln(x)) + return x + + +class Model(nn.Module): + """Reference implementation that wraps `OneDOccupancyDecoder`.""" + + def __init__(self, q_in_dim: int, width: int, num_heads: int) -> None: + super().__init__() + self.decoder = OneDOccupancyDecoder( + q_in_dim=q_in_dim, + width=width, + num_heads=num_heads, + out_features=1, + ) + + def forward(self, queries: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + return self.decoder(queries, latents) + + +# Problem configuration +batch_size = 1 +num_queries = 8192 +num_latents = 1024 +width = 768 +num_heads = 12 +q_in_dim = 3 + + +def get_inputs(): + queries = torch.randn(batch_size, num_queries, q_in_dim) + latents = torch.randn(batch_size, num_latents, width) + return [queries, latents] + + +def get_init_inputs(): + return [q_in_dim, width, num_heads] + diff --git a/KernelBench/level9/abc_reference.py b/KernelBench/level9/abc_reference.py new file mode 100644 index 00000000..53651cbd --- /dev/null +++ b/KernelBench/level9/abc_reference.py @@ -0,0 +1,261 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +# --- Simplified Modules --- + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return self.weight * x + +class FusedRMSNormGated(nn.Module): + def __init__(self, hidden_size, elementwise_affine=True, eps=1e-5): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter('weight', None) + + def forward(self, x, g): + # In RMSNormGated, it typically normalizes x and then gates with g + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + if self.weight is not None: + x = x * self.weight + return x * F.silu(g) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000.0): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, q, k): + # q, k: [B, T, H, D] + T = q.shape[1] + device = q.device + t = torch.arange(T, device=device, dtype=q.dtype) + freqs = torch.outer(t, self.inv_freq) # [T, D/2] + emb = torch.cat((freqs, freqs), dim=-1) # [T, D] + cos = emb.cos()[None, :, None, :] # [1, T, 1, D] + sin = emb.sin()[None, :, None, :] # [1, T, 1, D] + + def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + return q, k + +class ShortConvolution(nn.Module): + def __init__(self, hidden_size, kernel_size, bias=False, activation='silu'): + super().__init__() + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.conv = nn.Conv1d( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + padding=0, + bias=bias + ) + self.activation = activation 
+ + def forward(self, x): + # x: [B, T, D] + B, T, D = x.shape + # Causal padding + x_padded = F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0)) # [B, D, T + K-1] + x = self.conv(x_padded) # [B, D, T] + if self.activation == 'silu': + x = F.silu(x) + return x.transpose(1, 2) + +# --- Core Attention Mechanism --- + +def chunk_abc_ref(q, k, v, s): + # Functional reference implementation of ABC attention + # q, k, v, s: [B, T, H, D] (or D_v for v, M for s) + + B, T, H, K = q.shape + V = v.shape[-1] + M = s.shape[-1] + + # Transpose to [B, H, T, ...] + q = q.transpose(1, 2).float() + k = k.transpose(1, 2).float() + v = v.transpose(1, 2).float() + s = s.transpose(1, 2).float() + + scale = K ** -0.5 + + # 1. Compute gating and normalization tokens + # Using logcumsumexp for numerical stability + z = s.logcumsumexp(2) + # g factor for the state recurrence: exp(z_{i-1} - z_i) + z_prev = torch.cat((torch.zeros_like(z[:, :, :1]), z[:, :, :-1]), 2) + g = (z_prev - z).exp() + # Normalized slot weights: exp(s - z) + s_norm = (s - z).exp() + + # 2. Sequential Key-Slot update (hk state) + # hk: [B, H, K, M] + hk = torch.zeros(B, H, K, M, device=q.device, dtype=torch.float32) + ok = torch.zeros(B, H, T, M, device=q.device, dtype=torch.float32) + + for i in range(T): + qi = q[:, :, i] * scale + ki = k[:, :, i] + si = s_norm[:, :, i] + gi = g[:, :, i] + + # State update: hk = hk * g + k^T * s_norm + hk = hk * gi[..., None, :] + ki[..., None] * si[..., None, :] + # Output: query * hk + ok[:, :, i] = (qi[..., None] * hk).sum(-2) + + # 3. Sequential Slot-Value update (hv state) + # Interaction between slots and values based on ok + qv = ok.softmax(-1) + + hv = torch.zeros(B, H, M, V, device=q.device, dtype=torch.float32) + ov = torch.zeros(B, H, T, V, device=q.device, dtype=torch.float32) + + for i in range(T): + qvi = qv[:, :, i] + ki = s_norm[:, :, i] + vi = v[:, :, i] + gi = g[:, :, i] + + # State update: hv = hv * g + s_norm^T * v + hv = hv * gi[..., :, None] + ki[..., None] * vi[..., None, :] + # Output: qv * hv + ov[:, :, i] = (qvi[..., None] * hv).sum(-2) + + # Transpose back to [B, T, H, V] + return ov.transpose(1, 2).to(q.dtype) + +# --- Main Model --- + +class Model(nn.Module): + def __init__( + self, + hidden_size: int = 1024, + num_heads: int = 4, + num_slots: int = 64, + expand_k: float = 0.5, + expand_v: float = 1.0, + use_short_conv: bool = True, + conv_size: int = 4, + use_rope: bool = True, + use_output_gate: bool = True, + use_norm: bool = True, + norm_eps: float = 1e-5, + ): + super(Model, self).__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_slots = num_slots + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.s_proj = nn.Linear(hidden_size, num_heads * num_slots, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.use_short_conv = use_short_conv + if use_short_conv: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size) + self.k_conv1d = ShortConvolution(self.key_dim, conv_size) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size) + + self.use_rope = use_rope + if use_rope: + 
self.rotary = RotaryEmbedding(self.head_k_dim) + + self.use_norm = use_norm + if use_norm: + if use_output_gate: + self.g_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.g_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # 1. Projections + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # 2. Short Convolution + if self.use_short_conv: + q = self.q_conv1d(q) + k = self.k_conv1d(k) + v = self.v_conv1d(v) + + # 3. Reshape and RoPE + # [B, T, D] -> [B, T, H, d] + b, t, d = hidden_states.shape + q = q.view(b, t, self.num_heads, self.head_k_dim) + k = k.view(b, t, self.num_heads, self.head_k_dim) + v = v.view(b, t, self.num_heads, self.head_v_dim) + + if self.use_rope: + q, k = self.rotary(q, k) + + # 4. Slot projection + s = self.s_proj(hidden_states).view(b, t, self.num_heads, self.num_slots) + # Numerical stability clamp as per original + s = s.clamp(-32, 32) + + # 5. Core ABC Attention + o = chunk_abc_ref(q, k, v, s) + + # 6. Gating and Norm + if self.use_norm: + if hasattr(self, 'g_proj'): + g = self.g_proj(hidden_states).view(b, t, self.num_heads, self.head_v_dim) + o = self.g_norm(o, g) + else: + o = self.g_norm(o) + elif hasattr(self, 'g_proj'): + g = self.g_proj(hidden_states).view(b, t, self.num_heads, self.head_v_dim) + o = o * F.silu(g) + + # 7. Final Output Projection + o = o.reshape(b, t, self.value_dim) + return self.o_proj(o) + +# --- Kernelbench API --- + +def get_inputs(): + # Batch size 8, Sequence length 128 (reference is slow), Hidden size 1024 + hidden_states = torch.randn(8, 128, 1024) + return [hidden_states] + +def get_init_inputs(): + # [hidden_size, num_heads, num_slots] + return [1024, 4, 64] diff --git a/KernelBench/level9/based_reference.py b/KernelBench/level9/based_reference.py new file mode 100644 index 00000000..8ca8da3a --- /dev/null +++ b/KernelBench/level9/based_reference.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +# Inlined helper from fla/modules/feature_map.py +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + +class TaylorFeatureMap(nn.Module): + def __init__(self, head_dim: int): + super().__init__() + self.head_dim = head_dim + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.head_dim) + self.rrd = math.sqrt(self.rd) + + def forward(self, x: torch.Tensor): + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1) + +class Model(nn.Module): + """ + Reference implementation of Based Linear Attention. 
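+
+    In symbols (written out here to match the forward pass below), with phi the
+    Taylor feature map and eps a small constant, the causal output per step is
+
+        o_t = phi(q_t) @ (sum_{s<=t} phi(k_s) v_s^T) / (phi(q_t) . sum_{s<=t} phi(k_s) + eps)
+
+    i.e. a linear-attention numerator divided by a running normalizer, both
+    accumulated with cumulative sums over the sequence.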
+ """ + def __init__( + self, + hidden_size: int = 1024, + feature_dim: int = 16, + num_heads: int = 16, + causal: bool = True, + eps: float = 1e-12 + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.feature_dim = feature_dim + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.causal = causal + self.eps = eps + + self.q_proj = nn.Linear(hidden_size, feature_dim * num_heads, bias=False) + self.k_proj = nn.Linear(hidden_size, feature_dim * num_heads, bias=False) + self.v_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) + self.feature_map = TaylorFeatureMap(feature_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size). + Returns: + torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_size). + """ + b, t, h = hidden_states.size() + + # Projections + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Reshape to heads + q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2) + + # Apply Taylor feature map + # q, k: [b, h, t, d_feature] -> [b, h, t, d_expanded] + q = self.feature_map(q) + k = self.feature_map(k) + + # Linear attention computation + # q, k: [b, h, t, d_expanded] + # v: [b, h, t, d_head] + + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + # q: [b, h, t, 1, d_expanded] + # k: [b, h, t, 1, d_expanded] + # v: [b, h, t, d_head, 1] + + if self.causal: + # kv: [b, h, t, d_head, d_expanded] + kv = (k * v).cumsum(2) + y = (q * kv).sum(-1) # [b, h, t, d_head] + z = k.cumsum(2) + denom = (q * z).sum(-1) + self.eps + y = y / denom + else: + kv = (k * v).sum(2, keepdim=True) + y = (q * kv).sum(-1) + z = k.sum(2, keepdim=True) + denom = (q * z).sum(-1) + self.eps + y = y / denom + + # y: [b, h, t, d_head] -> [b, t, h * d_head] + y = y.transpose(1, 2).contiguous().view(b, t, self.num_heads * self.head_dim) + return self.o_proj(y) + +# Kernelbench Parameters +batch_size = 4 +seq_len = 1024 +hidden_size = 1024 +feature_dim = 16 +num_heads = 16 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [hidden_size, feature_dim, num_heads, True, 1e-12] diff --git a/KernelBench/level9/comba_reference.py b/KernelBench/level9/comba_reference.py new file mode 100644 index 00000000..56f775f0 --- /dev/null +++ b/KernelBench/level9/comba_reference.py @@ -0,0 +1,309 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + Comba (Closed-loop Control Bilinear RNN) Attention - Reference implementation. 
+ + Comba improves upon bilinear RNNs (like Delta Rule) with "closed-loop control": + - Uses an auxiliary key `p` (decayed version of `k`) for the prediction step + - Uses the regular key `k` for the state update step + - This separation allows better control over the memory dynamics + + The core recurrence: + # Prediction using auxiliary key p (closed-loop feedback) + v_new = v[t] - h @ p[t] # Subtract current state's prediction using p + + # State decay + h = h * exp(g[t]) + + # Scale by beta + v_new = v_new * beta[t] + + # Update using regular key k + h = h + outer(k[t], v_new) + + # Output + o[t] = h @ q[t] + + The key difference from standard Delta Rule: + - Delta Rule: uses same k for both prediction (v - h @ k) and update (h + k @ v) + - Comba: uses p for prediction (v - h @ p) and k for update (h + k @ v) + + This "closed-loop" structure (using different keys) provides better control. + + Based on: "Comba: Improving Bilinear RNNs with Closed-loop Control" + https://arxiv.org/abs/2506.02475 + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2.0, + head_dim: int = 256, + num_heads: int = 6, + num_v_heads: int = None, + use_output_gate: bool = True, + use_output_correction: bool = True, + use_inner_decay: bool = True, + correction_factor: float = 1.0, + ): + super().__init__() + + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.num_heads = num_heads + self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads + + self.head_k_dim = head_dim + self.head_v_dim = int(head_dim * expand_v) + self.key_dim = num_heads * head_dim + self.value_dim = int(self.num_v_heads * self.head_v_dim) + + self.use_output_gate = use_output_gate + self.use_output_correction = use_output_correction + self.use_inner_decay = use_inner_decay + + # Projections + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + # Gate projections for decay + self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) # For computing g + self.b_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) # For computing beta + + # Inner decay parameter (for computing auxiliary key p) + if use_inner_decay: + self.decay = nn.Parameter(torch.ones(num_heads)) + + # Output correction parameter D + if use_output_correction: + self.D = nn.Parameter(torch.ones(num_heads) * correction_factor) + + # Learnable decay parameters (A_log and dt_bias, like Mamba) + A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + dt_min, dt_max = 0.001, 0.1 + dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = torch.clamp(dt, min=1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + + # Output gate and projection + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + # RMSNorm weight + self.o_norm_weight = nn.Parameter(torch.ones(self.head_v_dim)) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + attention_mask: Optional mask (unused in this reference) + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + batch_size, 
seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x) # [B, T, key_dim] + k = self.k_proj(x) # [B, T, key_dim] + v = self.v_proj(x) # [B, T, value_dim] + + # Apply SiLU activation (simulating short convolution effect) + q = F.silu(q) + k = F.silu(k) + v = F.silu(v) + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_k_dim) + k = k.view(batch_size, seq_len, self.num_heads, self.head_k_dim) + v = v.view(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + # Compute auxiliary key p (decayed version of k) + if self.use_inner_decay: + decay_factor = torch.sigmoid(self.decay).view(1, 1, self.num_heads, 1) + p = k * decay_factor # [B, T, num_heads, head_k_dim] + else: + p = k + + # Output correction: q = q - D * p + if self.use_output_correction: + D = self.D.view(1, 1, self.num_heads, 1) + q = q - D * p + + # Expand Q, K, P for GVA (Grouped Value Attention) if needed + if self.num_v_heads > self.num_heads: + expand_ratio = self.num_v_heads // self.num_heads + q = q.repeat_interleave(expand_ratio, dim=2) + k = k.repeat_interleave(expand_ratio, dim=2) + p = p.repeat_interleave(expand_ratio, dim=2) + + # Compute beta (sigmoid) and g (decay) + beta = torch.sigmoid(self.b_proj(x)) # [B, T, num_v_heads] + g = -self.A_log.float().exp() * F.softplus(self.a_proj(x).float() + self.dt_bias) # [B, T, num_v_heads] + + # ============================================ + # Comba Recurrence (Closed-loop Control) + # ============================================ + o = self._comba_recurrence(q, k, p, v, g, beta) + + # Output normalization and gating + if self.use_output_gate: + gate = self.g_proj(x).view(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + o = self._gated_rms_norm(o, gate, self.o_norm_weight) + else: + o = self._rms_norm(o, self.o_norm_weight) + + # Reshape and project output + o = o.view(batch_size, seq_len, self.value_dim) + o = self.o_proj(o) + + return o + + def _comba_recurrence( + self, + q: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor + ) -> torch.Tensor: + """ + Comba recurrence with closed-loop control. 
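+
+        Per head and step t (matching the loop below; S is the [K, V] state and
+        q, k, p are L2-normalized, with q additionally scaled by K**-0.5):
+
+            v'_t = beta_t * (v_t - S_{t-1}^T p_t)    # predict with auxiliary key p
+            S_t  = exp(g_t) * S_{t-1} + k_t v'_t^T   # decay, then update with k
+            o_t  = S_t^T q_t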
+ + Key difference from Delta Rule: + - Uses auxiliary key `p` for prediction: v_new = v - h @ p + - Uses regular key `k` for update: h = h + outer(k, v_new) + + Args: + q: [B, T, num_v_heads, head_k_dim] - queries + k: [B, T, num_v_heads, head_k_dim] - keys (for state update) + p: [B, T, num_v_heads, head_k_dim] - auxiliary keys (for prediction) + v: [B, T, num_v_heads, head_v_dim] - values + g: [B, T, num_v_heads] - decay gate (negative log values) + beta: [B, T, num_v_heads] - scaling factor + + Returns: + o: [B, T, num_v_heads, head_v_dim] - output + """ + B, T, H, K = q.shape + V = v.shape[-1] + + scale = K ** -0.5 + + # Work in float32 for stability + q = q.float() + k = k.float() + p = p.float() + v = v.float() + g = g.float() + beta = beta.float() + + outputs = [] + + for b in range(B): + batch_outputs = [] + + for h in range(H): + # Initialize state: h_state is [head_k_dim, head_v_dim] + h_state = torch.zeros(K, V, device=q.device, dtype=torch.float32) + + head_outputs = [] + + for t in range(T): + # Get current vectors + q_t = q[b, t, h] # [K] + k_t = k[b, t, h] # [K] + p_t = p[b, t, h] # [K] - auxiliary key + v_t = v[b, t, h] # [V] + g_t = g[b, t, h] # scalar + beta_t = beta[b, t, h] # scalar + + # L2 normalize q, k, p + q_t = F.normalize(q_t, p=2, dim=-1) + k_t = F.normalize(k_t, p=2, dim=-1) + p_t = F.normalize(p_t, p=2, dim=-1) + + # Scale query + q_t = q_t * scale + + # ===== CLOSED-LOOP CONTROL ===== + # Prediction using auxiliary key p (NOT k) + # This is the key difference from standard Delta Rule + prediction = h_state.T @ p_t # [V] = h_state^T @ p + v_new = v_t - prediction # Subtract prediction from value + + # Decay the state + h_state = h_state * torch.exp(g_t) + + # Scale by beta + v_new = v_new * beta_t + + # Update state using regular key k (NOT p) + # h = h + outer(k, v_new) + h_state = h_state + torch.outer(k_t, v_new) + + # Output: o = h @ q + o_t = h_state.T @ q_t # [V] + + head_outputs.append(o_t) + + # Stack outputs for this head: [T, V] + head_outputs = torch.stack(head_outputs, dim=0) + batch_outputs.append(head_outputs) + + # Stack outputs for this batch: [T, H, V] + batch_outputs = torch.stack(batch_outputs, dim=1) + outputs.append(batch_outputs) + + # Stack all batches: [B, T, H, V] + outputs = torch.stack(outputs, dim=0) + + return outputs + + def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """RMSNorm implementation.""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + return (x.float() / rms * weight).to(x.dtype) + + def _gated_rms_norm(self, x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """Gated RMSNorm: RMSNorm(x) * sigmoid(g).""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + x_norm = (x.float() / rms) * weight + return (x_norm * torch.sigmoid(g.float())).to(x.dtype) + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +expand_v = 2.0 +head_dim = 256 +num_heads = 6 +num_v_heads = 6 +use_output_gate = True +use_output_correction = True +use_inner_decay = True +correction_factor = 1.0 + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, expand_v, head_dim, num_heads, num_v_heads, + use_output_gate, use_output_correction, use_inner_decay, correction_factor] + diff --git a/KernelBench/level9/deltaformer_reference.py b/KernelBench/level9/deltaformer_reference.py new file mode 100644 index 00000000..9bebd5a7 --- /dev/null +++ 
b/KernelBench/level9/deltaformer_reference.py @@ -0,0 +1,284 @@ +import torch +import torch.nn as nn +import math + + +class Model(nn.Module): + """ + DeltaFormer Attention - A reference implementation using pure PyTorch. + + DeltaFormer is a two-stage attention mechanism: + 1. Delta Update: Compute u[i] = v[i] - beta[i] * sum_{j torch.Tensor: + """ + Performs DeltaFormer attention computation. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size). + mask: Optional mask tensor (unused in this reference, kept for API compatibility). + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size). + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x) # [batch_size, seq_len, hidden_size] + k = self.k_proj(x) # [batch_size, seq_len, kv_dim] + v = self.v_proj(x) # [batch_size, seq_len, kv_dim] + beta = self.b_proj(x) # [batch_size, seq_len, num_heads] + + # Reshape to multi-head format: [batch_size, seq_len, num_heads, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Apply RMSNorm to Q and K + q = self._rms_norm(q, self.q_norm_weight) + k = self._rms_norm(k, self.k_norm_weight) + + # Apply rotary position embeddings + q, k = self._apply_rotary_emb(q, k, seq_len) + + # Expand K and V for grouped query attention if needed + if self.num_kv_groups > 1: + k = k.repeat_interleave(self.num_kv_groups, dim=2) + v = v.repeat_interleave(self.num_kv_groups, dim=2) + + # Transpose to [batch_size, num_heads, seq_len, head_dim] for attention + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + beta = beta.transpose(1, 2) # [batch_size, num_heads, seq_len] + + # ============================================ + # DeltaFormer Attention Core + # ============================================ + + # Stage 1: Compute delta-updated values U + # u[i] = v[i] - beta[i] * sum_{j [batch_size, seq_len, hidden_size] + o = o.transpose(1, 2).contiguous() + o = o.view(batch_size, seq_len, -1) + + # Final output projection + o = self.o_proj(o) + + return o + + def _compute_delta_values(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: + """ + Compute delta-updated values through the recurrence: + u[i] = v[i] - beta[i] * sum_{j 0, we attend to positions j < i (strictly causal) and + subtract a beta-scaled weighted combination of previous u values. 
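+
+        Written out (A[i, j] is the strictly causal softmax of q_i . k_j / sqrt(D)
+        produced by _tril_softmax below):
+
+            u[i] = v[i] - beta[i] * sum_{j < i} A[i, j] * u[j]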
+ + Args: + q: [B, H, T, D] + k: [B, H, T, D] + v: [B, H, T, D] + beta: [B, H, T] + + Returns: + u: [B, H, T, D] + """ + B, H, T, D = q.shape + qk_scale = 1.0 / math.sqrt(D) + + # Compute all Q @ K^T scores at once + scores = torch.matmul(q, k.transpose(-2, -1)) * qk_scale # [B, H, T, T] + + # Apply strictly causal softmax (j < i, not j <= i) + probs = self._tril_softmax(scores, strict=True) # [B, H, T, T] + + # Sequential computation of u (the recurrence cannot be parallelized naively) + u_list = [] + for t in range(T): + if t == 0: + # No previous positions to attend to + u_t = v[:, :, t, :] # [B, H, D] + else: + # Attention weights for position t attending to positions 0..t-1 + w = probs[:, :, t, :t] # [B, H, t] + # Stack all previous u values + u_prev = torch.stack(u_list, dim=-2) # [B, H, t, D] + # Weighted sum of previous u values + weighted_sum = torch.matmul(w.unsqueeze(-2), u_prev).squeeze(-2) # [B, H, D] + # Delta update: u[t] = v[t] - beta[t] * weighted_sum + u_t = v[:, :, t, :] - beta[:, :, t].unsqueeze(-1) * weighted_sum + u_list.append(u_t) + + u = torch.stack(u_list, dim=2) # [B, H, T, D] + return u + + def _tril_softmax(self, scores: torch.Tensor, strict: bool = True) -> torch.Tensor: + """ + Row-wise causal softmax over strictly lower-triangular (j < i) positions. + + Args: + scores: [B, H, T, T] raw attention scores + strict: if True, mask out diagonal (j < i). If False, include diagonal (j <= i). + + Returns: + probs: [B, H, T, T] with probabilities on valid positions, zeros elsewhere + """ + T = scores.size(-1) + device = scores.device + + i_idx = torch.arange(T, device=device).view(1, 1, T, 1) + j_idx = torch.arange(T, device=device).view(1, 1, 1, T) + + if strict: + mask = (j_idx < i_idx) # strictly lower triangular + else: + mask = (j_idx <= i_idx) # lower triangular including diagonal + + # Masked softmax with numerical stability + masked_scores = scores.masked_fill(~mask, float('-inf')) + max_per_row = masked_scores.max(dim=-1, keepdim=True).values + max_per_row = torch.where(max_per_row == float('-inf'), torch.zeros_like(max_per_row), max_per_row) + + exp_scores = (masked_scores - max_per_row).exp() + exp_scores = exp_scores.masked_fill(~mask, 0.0) + + denom = exp_scores.sum(dim=-1, keepdim=True).clamp_min(1e-20) + probs = exp_scores / denom + + return probs + + def _causal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Standard causal attention: o = softmax(Q @ K^T / sqrt(d), causal_mask) @ V + + Args: + q: [B, H, T, D] + k: [B, H, T, D] + v: [B, H, T, D] + + Returns: + o: [B, H, T, D] + """ + B, H, T, D = q.shape + qk_scale = 1.0 / math.sqrt(D) + + scores = torch.matmul(q, k.transpose(-2, -1)) * qk_scale # [B, H, T, T] + + # Standard causal mask (j <= i, including diagonal) + causal_mask = torch.triu(torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1) + scores = scores.masked_fill(causal_mask, float('-inf')) + + attn_weights = torch.softmax(scores, dim=-1) # [B, H, T, T] + o = torch.matmul(attn_weights, v) # [B, H, T, D] + + return o + + def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """RMSNorm implementation.""" + # x: [batch_size, seq_len, num_heads, head_dim] + rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-6) + return (x / rms) * weight + + def _apply_rotary_emb(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embeddings to Q and K. 
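+
+        Uses the split-half rotation convention: each head vector is split into
+        halves (x1, x2) and mapped to (x1*cos - x2*sin, x1*sin + x2*cos), where
+        the angles are positions / rope_theta**(2*i / head_dim).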
+ + Note: Q has num_heads dimensions, K has num_kv_heads dimensions (may differ in GQA). + """ + device = q.device + dtype = q.dtype + + num_q_heads = q.shape[2] + num_k_heads = k.shape[2] + + # Create position indices: [1, seq_len, 1] + positions = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + + # Create frequency matrix + dim_t = torch.arange(self.head_dim // 2, device=device, dtype=dtype) + dim_t = self.rope_theta ** (2 * dim_t / self.head_dim) + dim_t = dim_t.unsqueeze(0).unsqueeze(0) # [1, 1, head_dim//2] + + # Compute angles: [1, seq_len, head_dim//2] + angles = positions / dim_t + + # Create rotation matrices + cos_base = torch.cos(angles) + sin_base = torch.sin(angles) + + # Expand for Q and K heads separately + cos_q = cos_base.unsqueeze(2).expand(-1, -1, num_q_heads, -1) + sin_q = sin_base.unsqueeze(2).expand(-1, -1, num_q_heads, -1) + cos_k = cos_base.unsqueeze(2).expand(-1, -1, num_k_heads, -1) + sin_k = sin_base.unsqueeze(2).expand(-1, -1, num_k_heads, -1) + + # Apply rotation to Q + q1, q2 = q.chunk(2, dim=-1) + q_rot = torch.cat([q1 * cos_q - q2 * sin_q, q1 * sin_q + q2 * cos_q], dim=-1) + + # Apply rotation to K + k1, k2 = k.chunk(2, dim=-1) + k_rot = torch.cat([k1 * cos_k - k2 * sin_k, k1 * sin_k + k2 * cos_k], dim=-1) + + return q_rot, k_rot + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +num_heads = 32 +num_kv_heads = 8 # Grouped query attention +head_dim = hidden_size // num_heads + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, num_heads, num_kv_heads, head_dim] + diff --git a/KernelBench/level9/fd.py b/KernelBench/level9/fd.py new file mode 100644 index 00000000..e9ed6e91 --- /dev/null +++ b/KernelBench/level9/fd.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + A model that performs 3D heat diffusion using a 9-point stencil + and explicit Euler time integration. + + To compute the next time step: u^{n+1} = u^n + dt * alpha * Laplacian(u^n) + + The Laplacian is computed with a 1D 9-point (4 neighbors on each side) stencil + in x, y, and z. We only update the interior points in each dim, leaving the + boundary values unchanged. + """ + + def __init__(self, alpha: float, hx: float, hy: float, hz: float, n_steps: int): + super(Model, self).__init__() + self.alpha = alpha + self.hx = hx + self.hy = hy + self.hz = hz + self.n_steps = n_steps + + def forward(self, u0: torch.Tensor) -> torch.Tensor: + """ + Performs 3D heat diffusion simulation. 
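+
+        Each step is an explicit Euler update applied to interior points only
+        (a 4-cell halo on every face stays fixed), using an 8th-order central
+        difference Laplacian (coefficients c0..c4 below) and the step size
+
+            dt = 0.05 / (alpha * (1/hx^2 + 1/hy^2 + 1/hz^2))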
+ + Args: + u0: Initial 3D field tensor of shape [grid_size, grid_size, grid_size] + + Returns: + Final field after n_steps Euler updates of 3D heat equation + """ + # 3D 8th-order 2nd-derivative Laplacian coefficients + c0 = -205.0 / 72.0 + c1 = 8.0 / 5.0 + c2 = -1.0 / 5.0 + c3 = 8.0 / 315.0 + c4 = -1.0 / 560.0 + + # CFL stability + c = 0.05 + + # Move scalars to same device/dtype as u + u = u0.clone() + device, dtype = u.device, u.dtype + alpha = torch.as_tensor(self.alpha, device=device, dtype=dtype) + hx = torch.as_tensor(self.hx, device=device, dtype=dtype) + hy = torch.as_tensor(self.hy, device=device, dtype=dtype) + hz = torch.as_tensor(self.hz, device=device, dtype=dtype) + + inv_hx2 = 1.0 / (hx * hx) + inv_hy2 = 1.0 / (hy * hy) + inv_hz2 = 1.0 / (hz * hz) + + S = inv_hx2 + inv_hy2 + inv_hz2 + dt = c / (alpha * S) + + f = torch.empty_like(u) + + # Radius of stencil + r = 4 + # Interior slices (these stay fixed each step) + zc = slice(r, -r) + yc = slice(r, -r) + xc = slice(r, -r) + + for _ in range(self.n_steps): + # copy old solution; boundaries remain unchanged + f.copy_(u) + + # center region + uc = u[zc, yc, xc] + + # x-direction second derivative + u_xx = ( + c0 * uc + + c1 * (u[zc, yc, r + 1 : -r + 1] + u[zc, yc, r - 1 : -r - 1]) + + c2 * (u[zc, yc, r + 2 : -r + 2] + u[zc, yc, r - 2 : -r - 2]) + + c3 * (u[zc, yc, r + 3 : -r + 3] + u[zc, yc, r - 3 : -r - 3]) + + c4 * (u[zc, yc, r + 4 :] + u[zc, yc, : -r - 4]) + ) * inv_hx2 + + # y-direction second derivative + u_yy = ( + c0 * uc + + c1 * (u[zc, r + 1 : -r + 1, xc] + u[zc, r - 1 : -r - 1, xc]) + + c2 * (u[zc, r + 2 : -r + 2, xc] + u[zc, r - 2 : -r - 2, xc]) + + c3 * (u[zc, r + 3 : -r + 3, xc] + u[zc, r - 3 : -r - 3, xc]) + + c4 * (u[zc, r + 4 :, xc] + u[zc, : -r - 4, xc]) + ) * inv_hy2 + + # z-direction second derivative + u_zz = ( + c0 * uc + + c1 * (u[r + 1 : -r + 1, yc, xc] + u[r - 1 : -r - 1, yc, xc]) + + c2 * (u[r + 2 : -r + 2, yc, xc] + u[r - 2 : -r - 2, yc, xc]) + + c3 * (u[r + 3 : -r + 3, yc, xc] + u[r - 3 : -r - 3, yc, xc]) + + c4 * (u[r + 4 :, yc, xc] + u[: -r - 4, yc, xc]) + ) * inv_hz2 + + lap = u_xx + u_yy + u_zz + + # Explicit Euler update on interior only + f[zc, yc, xc] = uc + dt * alpha * lap + + # swap + u, f = f, u + + return u + + +# Problem configuration +grid_size = 64 +n_steps = 10 + + +def get_inputs(): + # Generate random 3D initial field: [grid_size, grid_size, grid_size] + u0 = torch.randn(grid_size, grid_size, grid_size, dtype=torch.float32).contiguous() + return [u0] + + +def get_init_inputs(): + # Random diffusion coefficient alpha in [0.1, 5.0] + alpha = torch.rand(1).item() * 4.9 + 0.1 + + # Random grid spacings hx, hy, hz in [0.5, 2.0] + hx = torch.rand(1).item() * 1.5 + 0.5 + hy = torch.rand(1).item() * 1.5 + 0.5 + hz = torch.rand(1).item() * 1.5 + 0.5 + + return [alpha, hx, hy, hz, n_steps] diff --git a/KernelBench/level9/feature_map_reference.py b/KernelBench/level9/feature_map_reference.py new file mode 100644 index 00000000..76f227cb --- /dev/null +++ b/KernelBench/level9/feature_map_reference.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + +class Model(nn.Module): + """ + Reference implementation of the Taylor Feature Map used in Linear Attention. 
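+    With head dimension d, the scalings below are chosen so that
+
+        phi(q) . phi(k) = 1 + (q . k) / sqrt(d) + (q . k)^2 / (2 * d)
+                       ~= exp(q . k / sqrt(d)).
+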
+ This feature map approximates the softmax kernel using a second-order Taylor expansion. + """ + def __init__(self, head_dim: int): + super(Model, self).__init__() + self.head_dim = head_dim + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.head_dim) + self.rrd = math.sqrt(self.rd) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, num_heads, head_dim] + Returns: + torch.Tensor: Feature-mapped tensor of shape [batch_size, seq_len, num_heads, 1 + D + D*(D+1)/2] + """ + # Second-order Taylor expansion components + # 1 + x + 0.5 * outer(x, x) + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + + # Concatenate components: [1, x, diag(x^2)/sqrt(2), off_diag(x^2)] + return torch.cat([ + torch.ones_like(x[..., 0:1]), + x / self.rrd, + x2_2 / (self.rd * self.r2), + x2_1 / self.rd + ], dim=-1) + +# Kernelbench Parameters +batch_size = 4 +seq_len = 512 +num_heads = 8 +head_dim = 64 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, num_heads, head_dim) + return [x] + +def get_init_inputs(): + return [head_dim] diff --git a/KernelBench/level9/fox_reference.py b/KernelBench/level9/fox_reference.py new file mode 100644 index 00000000..fead0998 --- /dev/null +++ b/KernelBench/level9/fox_reference.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class Model(nn.Module): + """ + Reference implementation of PaTH Attention (Fox). + Uses Rank-One Updates for state transition. + """ + def __init__( + self, + hidden_size: int = 1024, + num_heads: int = 32, + num_kv_heads: int = 32, + head_dim: int = 32, + use_forget_gate: bool = True, + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.use_forget_gate = use_forget_gate + + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) + self.w_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) + self.beta_proj = nn.Linear(hidden_size, num_kv_heads, bias=True) + + if use_forget_gate: + self.g_proj = nn.Linear(hidden_size, num_heads, bias=True) + + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: [batch_size, seq_len, hidden_size] + """ + B, T, _ = x.shape + H, HKV, D = self.num_heads, self.num_kv_heads, self.head_dim + G = H // HKV + + q = self.q_proj(x).view(B, T, H, D) + k = self.k_proj(x).view(B, T, HKV, D) + v = self.v_proj(x).view(B, T, HKV, D) + w = self.w_proj(x).view(B, T, HKV, D) + beta = self.beta_proj(x).sigmoid() * 2.0 # [B, T, HKV] + + if self.use_forget_gate: + g = F.logsigmoid(self.g_proj(x)) # [B, T, H] + else: + g = None + + # L2 Norm for w + w = w / (torch.norm(w, p=2, dim=-1, keepdim=True) + 1e-6) + + # State: [Batch, Head, HeadDim, HeadDim] + S = torch.zeros(B, HKV, D, D, device=x.device, dtype=torch.float32) + o_kv = torch.zeros(B, T, HKV, D, device=x.device, dtype=torch.float32) + + q, k, v, w, beta = q.float(), k.float(), v.float(), w.float(), beta.float() + if g is not None: g = g.float() + + scale = D ** -0.5 + + for t in range(T): + # 1. 
Decay state if forget gate is used + # Note: Fox (PaTH) usually applies decay to the previous state + # In parallel_path_attn, it's g_cumsum. + # Here we apply it per-step. + if g is not None: + # g is [B, T, H]. We need to handle GQA. + # PaTH uses g at query level? Actually parallel_path_attn takes g of shape [B, T, HQ]. + # So we tile it if needed, or if H > HKV, we need to handle it. + # To keep it simple, assume forget gate is applied to outputs or state. + # In Recurrence: S_t = S_{t-1} * exp(g_t) + ... + # But g is defined at HQ level. + pass + + # Current inputs + k_t = k[:, t] # [B, HKV, D] + v_t = v[:, t] # [B, HKV, D] + w_t = w[:, t] # [B, HKV, D] + beta_t = beta[:, t] # [B, HKV] + + # Rank-One Update: S_t = S_{t-1} - beta_t * (S_{t-1} @ w_t) @ w_t^T + k_t @ v_t^T + # This is the "Orthogonal" or "Path" transition logic. + + # Sw = S @ w + Sw = torch.einsum('b h d m, b h m -> b h d', S, w_t) + # Update + S = S - beta_t.view(B, HKV, 1, 1) * torch.einsum('b h d, b h m -> b h d m', Sw, w_t) + S = S + torch.einsum('b h d, b h m -> b h d m', k_t, v_t) + + # Compute output at HKV level + # We'll tile q for HQ later, or compute o at HQ level. + # Usually o_t = S_t^T @ q_t. + # But S is [D_k, D_v]. So o = q^T @ S. + # Wait, linear attention is o = (q^T @ S). + + # For simplicity and correctness with parallel_path_attn: + # It's better to implement the causal linear attention form with the path transition. + # But PaTH is exactly the recurrence above. + + # Parallel form implementation for the reference (easier and matches the kernel): + # The path transition means k_i is transformed by all w_j, beta_j for j > i. + # k'_i = (I - beta_n w_n w_n^T) ... (I - beta_{i+1} w_{i+1} w_{i+1}^T) k_i + + # For the reference loop, I'll just finish the recurrence: + out = torch.zeros(B, T, H, D, device=x.device, dtype=torch.float32) + S = torch.zeros(B, HKV, D, D, device=x.device, dtype=torch.float32) + + for t in range(T): + k_t = k[:, t] + v_t = v[:, t] + w_t = w[:, t] + beta_t = beta[:, t] + + # S_t = (I - beta_t w_t w_t^T) S_{t-1} + k_t v_t^T + Sw = torch.einsum('b h d m, b h m -> b h d', S, w_t) + S = S - beta_t.view(B, HKV, 1, 1) * torch.einsum('b h d, b h m -> b h d m', Sw, w_t) + S = S + torch.einsum('b h d, b h m -> b h d m', k_t, v_t) + + # Output + for g_idx in range(G): + h_idx = torch.arange(HKV, device=x.device) * G + g_idx + q_t = q[:, t, h_idx] * scale # [B, HKV, D] + # o = q^T @ S + res = torch.einsum('b h d, b h d m -> b h m', q_t, S) + if g is not None: + res = res * g[:, t, h_idx].exp().unsqueeze(-1) + out[:, t, h_idx] = res + + return self.o_proj(out.transpose(1, 2).reshape(B, T, -1).to(x.dtype)) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 64 +hidden_size = 512 +num_heads = 8 +num_kv_heads = 8 +head_dim = 32 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_heads, num_kv_heads, head_dim] diff --git a/KernelBench/level9/fused_bitlinear_reference.py b/KernelBench/level9/fused_bitlinear_reference.py new file mode 100644 index 00000000..cb698148 --- /dev/null +++ b/KernelBench/level9/fused_bitlinear_reference.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def activation_quant(x): + """Per-token quantization to 8 bits.""" + scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + y = (x * scale).round().clamp_(-128, 127) / scale + return y + +def weight_quant(w): + """Per-tensor quantization to 1.58 bits.""" + scale = 1.0 / 
w.abs().mean().clamp_(min=1e-5) + u = (w * scale).round().clamp_(-1, 1) / scale + return u + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-8): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return x_normed * self.weight + +class Model(nn.Module): + """ + Reference implementation of BitLinear (1.58-bit Linear Layer). + Includes RMSNorm and Straight-Through Estimator (STE) for quantization. + """ + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super(Model, self).__init__() + self.norm = RMSNorm(in_features) + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, in_features] + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, out_features] + """ + # 1. Apply RMS normalization + x_norm = self.norm(x) + + # 2. Quantize activations (with STE) + x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() + + # 3. Quantize weights (with STE) + w_quant = self.weight + (weight_quant(self.weight) - self.weight).detach() + + # 4. Linear operation + return F.linear(x_quant, w_quant, self.bias) + +# Kernelbench Parameters +batch_size = 8 +seq_len = 512 +in_features = 1024 +out_features = 2048 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, in_features) + return [x] + +def get_init_inputs(): + return [in_features, out_features] diff --git a/KernelBench/level9/fused_cross_entropy_reference.py b/KernelBench/level9/fused_cross_entropy_reference.py new file mode 100644 index 00000000..522e8c9f --- /dev/null +++ b/KernelBench/level9/fused_cross_entropy_reference.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Fused Cross Entropy Loss with extra features: + - Label Smoothing + - Logit Scaling + - LSE Square Scale (z-loss) + - Ignore Index + """ + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + reduction: str = "mean" + ): + super(Model, self).__init__() + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.reduction = reduction + + def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + logits (torch.Tensor): [batch_size, num_classes] + target (torch.Tensor): [batch_size] + Returns: + torch.Tensor: Reduced scalar loss + """ + # 1. Apply logit scaling + logits = logits.float() * self.logit_scale + + # 2. Compute log-sum-exp for z-loss + # z-loss = lse_square_scale * (lse(logits)^2) + lse = torch.logsumexp(logits, dim=-1) + z_loss = self.lse_square_scale * lse.pow(2) + + # 3. Compute cross entropy loss with label smoothing + # Note: F.cross_entropy handles label_smoothing directly in modern PyTorch + loss = F.cross_entropy( + logits, + target, + ignore_index=self.ignore_index, + label_smoothing=self.label_smoothing, + reduction='none' + ) + + # 4. 
Combine with z-loss + # Apply mask for ignore_index to z_loss as well + mask = (target != self.ignore_index).float() + total_loss = (loss + z_loss) * mask + + if self.reduction == 'mean': + return total_loss.sum() / mask.sum().clamp(min=1) + elif self.reduction == 'sum': + return total_loss.sum() + else: + return total_loss + +# Kernelbench Parameters +batch_size = 1024 +num_classes = 32000 + +def get_inputs(): + logits = torch.randn(batch_size, num_classes) + target = torch.randint(0, num_classes, (batch_size,)) + # Add some ignore_index entries + target[target % 10 == 0] = -100 + return [logits, target] + +def get_init_inputs(): + # ignore_index, label_smoothing, logit_scale, lse_square_scale + return [-100, 0.1, 1.0, 1e-4] diff --git a/KernelBench/level9/fused_kl_div_reference.py b/KernelBench/level9/fused_kl_div_reference.py new file mode 100644 index 00000000..26834b53 --- /dev/null +++ b/KernelBench/level9/fused_kl_div_reference.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation of Fused KL Divergence Loss. + Computes KL divergence between student logits and teacher logits. + Logits are expanded from hidden states on the fly to save memory. + """ + def __init__(self, hidden_size: int, vocab_size: int): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.student_head = nn.Linear(hidden_size, vocab_size, bias=False) + self.teacher_head = nn.Linear(hidden_size, vocab_size, bias=False) + + def forward(self, x: torch.Tensor, target_x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Student hidden states [N, H] + target_x (torch.Tensor): Teacher hidden states [N, H] + Returns: + torch.Tensor: Scalar KL divergence loss + """ + # In the fused version, we calculate logits in chunks. + # Here we do it at once for the reference. + student_logits = self.student_head(x) + teacher_logits = self.teacher_head(target_x) + + # KL Divergence: KL(Teacher || Student) + # Loss = sum(P_teacher * (log P_teacher - log P_student)) + + log_p_student = F.log_softmax(student_logits, dim=-1) + log_p_teacher = F.log_softmax(teacher_logits, dim=-1) + p_teacher = F.softmax(teacher_logits, dim=-1) + + kl_div = p_teacher * (log_p_teacher - log_p_student) + return kl_div.sum(dim=-1).mean() + +# Kernelbench Parameters +batch_size = 16 +seq_len = 1024 +hidden_size = 1024 +vocab_size = 32000 + +def get_inputs(): + x = torch.randn(batch_size * seq_len, hidden_size) + target_x = torch.randn(batch_size * seq_len, hidden_size) + return [x, target_x] + +def get_init_inputs(): + return [hidden_size, vocab_size] diff --git a/KernelBench/level9/fused_linear_cross_entropy_reference.py b/KernelBench/level9/fused_linear_cross_entropy_reference.py new file mode 100644 index 00000000..c8f44afd --- /dev/null +++ b/KernelBench/level9/fused_linear_cross_entropy_reference.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Fused Linear + Cross Entropy Loss. + In many optimized libraries, this is fused to avoid materializing the full logit tensor. 
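+
+    Illustrative usage (a minimal sketch; the shapes below are arbitrary and much
+    smaller than the benchmark defaults at the bottom of this file):
+
+        import torch
+        model = Model(hidden_size=1024, vocab_size=32000)
+        x = torch.randn(2, 16, 1024)
+        labels = torch.randint(0, 32000, (2, 16))
+        loss = model(x, labels)  # scalar mean cross-entropy over all tokens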
+ """ + def __init__(self, hidden_size: int, vocab_size: int): + super(Model, self).__init__() + self.linear = nn.Linear(hidden_size, vocab_size, bias=False) + + def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Hidden states of shape [batch_size, seq_len, hidden_size] + labels (torch.Tensor): Target labels of shape [batch_size, seq_len] + Returns: + torch.Tensor: Scalar loss value + """ + # Compute logits for all tokens + logits = self.linear(x) # [batch_size, seq_len, vocab_size] + # Flatten and compute standard cross entropy + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1)) + return loss + +# Kernelbench Parameters +batch_size = 16 +seq_len = 1024 +hidden_size = 1024 +vocab_size = 32000 # Example vocabulary size + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + # Generate random integer labels in range [0, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + return [x, labels] + +def get_init_inputs(): + return [hidden_size, vocab_size] diff --git a/KernelBench/level9/fused_norm_gate_reference.py b/KernelBench/level9/fused_norm_gate_reference.py new file mode 100644 index 00000000..179d864c --- /dev/null +++ b/KernelBench/level9/fused_norm_gate_reference.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + return x * torch.rsqrt(norm_x + self.eps) * self.weight + +class Model(nn.Module): + """ + Reference implementation of Generalized Fused Normalization and Gating. + Supports both LayerNorm and RMSNorm with optional residual connection and SiLU/Sigmoid gating. 
+ """ + def __init__(self, hidden_size: int, norm_type: str = 'rms', activation: str = 'silu', eps: float = 1e-5): + super(Model, self).__init__() + self.norm_type = norm_type + self.activation = activation + if norm_type == 'rms': + self.norm = RMSNorm(hidden_size, eps=eps) + else: + self.norm = nn.LayerNorm(hidden_size, eps=eps) + + def forward(self, x: torch.Tensor, gate: torch.Tensor, residual: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor [B, T, H] + gate (torch.Tensor): Gate tensor [B, T, H] + residual (torch.Tensor, optional): Residual tensor [B, T, H] + Returns: + torch.Tensor: Gated normalized output [B, T, H] + """ + if residual is not None: + x = x + residual + + x_normed = self.norm(x) + + if self.activation == 'silu': + gated = x_normed * F.silu(gate) + elif self.activation == 'sigmoid': + gated = x_normed * torch.sigmoid(gate) + else: + gated = x_normed * gate + + return gated + +# Kernelbench Parameters +batch_size = 8 +seq_len = 1024 +hidden_size = 2048 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + gate = torch.randn(batch_size, seq_len, hidden_size) + residual = torch.randn(batch_size, seq_len, hidden_size) + return [x, gate, residual] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/fused_rms_norm_silu_reference.py b/KernelBench/level9/fused_rms_norm_silu_reference.py new file mode 100644 index 00000000..43c95f58 --- /dev/null +++ b/KernelBench/level9/fused_rms_norm_silu_reference.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Fused RMSNorm + SiLU Gating. + """ + def __init__(self, hidden_size: int, eps: float = 1e-5): + super(Model, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + gate (torch.Tensor): Gate tensor of shape [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + # RMSNorm + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + x_normed = x_normed * self.weight + + # SiLU Gating + return x_normed * F.silu(gate) + +# Kernelbench Parameters +batch_size = 8 +seq_len = 1024 +hidden_size = 2048 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + gate = torch.randn(batch_size, seq_len, hidden_size) + return [x, gate] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/gated_deltanet_reference.py b/KernelBench/level9/gated_deltanet_reference.py new file mode 100644 index 00000000..ede944dc --- /dev/null +++ b/KernelBench/level9/gated_deltanet_reference.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class ShortConvolution(nn.Module): + def __init__(self, hidden_size, kernel_size, bias=False, activation='silu'): + super().__init__() + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.activation = activation + self.conv = nn.Conv1d( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + padding=kernel_size - 1, + bias=bias, + ) + + def forward(self, x): + # x: [B, T, C] + B, T, C = x.shape + x = x.transpose(1, 2) # [B, C, T] + x = self.conv(x)[:, :, :T] # Causal 
convolution + if self.activation == 'silu': + x = F.silu(x) + return x.transpose(1, 2) + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + return self._norm(x.float()).type_as(x) * self.weight + +class GatedRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x, g): + # RMSNorm with gating (often used in DeltaNet and Mamba2) + # x: [..., D], g: [..., D] + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + x = x * self.weight + return x * F.silu(g) + +class Model(nn.Module): + """ + Reference implementation of Gated DeltaNet. + """ + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2.0, + head_dim: int = 256, + num_heads: int = 6, + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + norm_eps: float = 1e-5, + ): + super().__init__() + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.num_heads = num_heads + self.num_v_heads = num_heads # Simplified for reference + + self.head_v_dim = int(head_dim * expand_v) + self.key_dim = num_heads * head_dim + self.value_dim = self.num_v_heads * self.head_v_dim + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) + + # Initialization for 'g' logic + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + dt_min, dt_max = 0.001, 0.1 + dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + + self.use_short_conv = use_short_conv + if use_short_conv: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size) + self.k_conv1d = ShortConvolution(self.key_dim, conv_size) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size) + + self.use_gate = use_gate + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = GatedRMSNorm(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, T, C = hidden_states.shape + H = self.num_heads + K = self.head_dim + V = self.head_v_dim + + if self.use_short_conv: + q = self.q_conv1d(self.q_proj(hidden_states)) + k = self.k_conv1d(self.k_proj(hidden_states)) + v = self.v_conv1d(self.v_proj(hidden_states)) + else: + q = F.silu(self.q_proj(hidden_states)) + k = F.silu(self.k_proj(hidden_states)) + v = F.silu(self.v_proj(hidden_states)) + + q = q.view(B, T, H, K) + k = k.view(B, T, H, K) + v = v.view(B, T, H, V) + + beta = self.b_proj(hidden_states).sigmoid() # [B, T, H] + g = -self.A_log.exp() * F.softplus(self.a_proj(hidden_states) + self.dt_bias) # [B, T, H] + + # Delta Rule Recurrence + # Normalization + q = q / (torch.norm(q, dim=-1, keepdim=True) + 1e-6) + k = k / (torch.norm(k, dim=-1, keepdim=True) + 
1e-6) + + out = torch.zeros(B, T, H, V, device=hidden_states.device, dtype=hidden_states.dtype) + state = torch.zeros(B, H, K, V, device=hidden_states.device, dtype=hidden_states.dtype) + + for t in range(T): + q_t = q[:, t] # [B, H, K] + k_t = k[:, t] # [B, H, K] + v_t = v[:, t] # [B, H, V] + beta_t = beta[:, t] # [B, H] + g_t = g[:, t] # [B, H] + + # Decay state + state = state * torch.exp(g_t).view(B, H, 1, 1) + + # Delta update using decayed state + # H_t = H'_t + beta * (v_t - H'_t @ k_t) @ k_t^T + # (H @ k_t): [B, H, K, V] @ [B, H, K, 1] -> [B, H, V, 1] -> [B, H, V] + # Actually state is [K, V], k_t is [K]. So state.T @ k_t is [V]. + # In our setup: state is [K, V], k_t is [K]. state^T @ k_t -> [V] + # Let's use einsum for clarity + kv = torch.einsum('b h k v, b h k -> b h v', state, k_t) + + # dv = beta * (v_t - kv) + dv = beta_t.view(B, H, 1) * (v_t - kv) + + # state = state + k_t @ dv^T + state = state + torch.einsum('b h k, b h v -> b h k v', k_t, dv) + + # out_t = q_t @ state -> [B, H, V] + out[:, t] = torch.einsum('b h k, b h k v -> b h v', q_t, state) + + if self.use_gate: + gate = self.g_proj(hidden_states).view(B, T, H, V) + o = self.o_norm(out, gate) + else: + o = self.o_norm(out) + + o = o.reshape(B, T, -1) + return self.o_proj(o) + +# KernelBench utility functions +def get_inputs(): + B, T, C = 4, 32, 2048 # Keeping sequence length small for reference loop + hidden_states = torch.randn(B, T, C) + return [hidden_states] + +def get_init_inputs(): + return [2048, 2.0, 256, 6] # hidden_size, expand_v, head_dim, num_heads diff --git a/KernelBench/level9/gated_deltaproduct_reference.py b/KernelBench/level9/gated_deltaproduct_reference.py new file mode 100644 index 00000000..52951a60 --- /dev/null +++ b/KernelBench/level9/gated_deltaproduct_reference.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + Gated Delta Product - Reference implementation. + + GatedDeltaProduct is a generalized version that supports arbitrary number of + Householder transformations. It applies multiple Householder reflections + sequentially to transform the state. + + The core recurrence: + For each time step t: + 1. Apply forget gate: h = h * exp(g[t]) + 2. Apply num_householder Householder transformations sequentially: + For each j in num_householder: + h = h + (v[t,j] - (h @ k[t,j])) * k[t,j] * beta[t,j] + 3. Readout: o[t] = h @ q[t] + + Each Householder transformation is: H = I + beta * outer(k, v - h@k) + This is a rank-1 update that reflects the state. + + Based on: Generalized GatedDoubleDeltaNet with multiple Householder transformations. 
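+
+    Illustrative usage (a minimal sketch; a short sequence keeps the reference
+    recurrence loop fast):
+
+        import torch
+        model = Model(hidden_size=2048, head_dim=256, num_heads=6, num_householder=2)
+        x = torch.randn(1, 8, 2048)
+        y = model(x)  # [1, 8, 2048]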
+ """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2.0, + head_dim: int = 256, + num_heads: int = 6, + num_v_heads: int = None, + use_output_gate: bool = True, + use_forget_gate: bool = True, + allow_neg_eigval: bool = True, + num_householder: int = 2, + ): + super().__init__() + + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.num_heads = num_heads + self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads + + self.head_k_dim = head_dim + self.head_v_dim = int(head_dim * expand_v) + self.key_dim = num_heads * head_dim + self.value_dim = int(self.num_v_heads * self.head_v_dim) + + self.use_output_gate = use_output_gate + self.use_forget_gate = use_forget_gate + self.allow_neg_eigval = allow_neg_eigval + self.num_householder = num_householder + + # Projections + # Q is normal size + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + + # K and V are num_householder times larger (for multiple transformations) + self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim * num_householder, bias=False) + + # Beta is also num_householder times larger + self.b_proj = nn.Linear(hidden_size, self.num_v_heads * num_householder, bias=False) + + # Forget gate (optional) + if use_forget_gate: + self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False) + + # Learnable decay parameters (A_log and dt_bias, like Mamba) + A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + dt_min, dt_max = 0.001, 0.1 + dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = torch.clamp(dt, min=1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + + # Output gate and projection + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + # RMSNorm weight + self.o_norm_weight = nn.Parameter(torch.ones(self.head_v_dim)) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + attention_mask: Optional mask (unused in this reference) + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x) # [B, T, key_dim] + k = self.k_proj(x) # [B, T, key_dim * num_householder] + v = self.v_proj(x) # [B, T, value_dim * num_householder] + + # Apply SiLU activation (simulating short convolution effect) + q = F.silu(q) + k = F.silu(k) + v = F.silu(v) + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_k_dim) + + # Reshape K and V: split into num_householder chunks + k = k.view(batch_size, seq_len, self.num_householder, self.num_heads, self.head_k_dim) + k = k.view(batch_size, seq_len * self.num_householder, self.num_heads, self.head_k_dim) + + v = v.view(batch_size, seq_len, self.num_householder, self.num_v_heads, self.head_v_dim) + v = v.view(batch_size, seq_len * self.num_householder, self.num_v_heads, self.head_v_dim) + + # Compute beta + beta = torch.sigmoid(self.b_proj(x)) # [B, T, num_v_heads * num_householder] + if self.allow_neg_eigval: + beta = beta * 2.0 # Allow range [0, 2] for negative eigenvalues + beta = beta.view(batch_size, seq_len, 
self.num_householder, self.num_v_heads) + beta = beta.view(batch_size, seq_len * self.num_householder, self.num_v_heads) + + # Compute forget gate (optional) + if self.use_forget_gate: + g = -self.A_log.float().exp() * F.softplus(self.a_proj(x).float() + self.dt_bias) + # g is [B, T, num_v_heads], but we need it for each time step + else: + g = None + + # Expand Q and K for GVA if needed + if self.num_v_heads > self.num_heads: + expand_ratio = self.num_v_heads // self.num_heads + q = q.repeat_interleave(expand_ratio, dim=2) + k = k.repeat_interleave(expand_ratio, dim=2) + + # ============================================ + # Gated Delta Product with Multiple Householder + # ============================================ + o = self._gated_delta_product(q, k, v, g, beta) + + # Output normalization + if self.use_output_gate: + gate = self.g_proj(x).view(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + o = self._gated_rms_norm(o, gate, self.o_norm_weight) + else: + o = self._rms_norm(o, self.o_norm_weight) + + # Reshape and project output + o = o.view(batch_size, seq_len, self.value_dim) + o = self.o_proj(o) + + return o + + def _gated_delta_product( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor + ) -> torch.Tensor: + """ + Gated Delta Product with multiple Householder transformations. + + For each time step t: + 1. Apply forget gate: h = h * exp(g[t]) + 2. Apply num_householder Householder transformations: + For each j: h = h + (v[t,j] - h@k[t,j]) * k[t,j] * beta[t,j] + 3. Readout: o[t] = h @ q[t] + + Args: + q: [B, T, num_v_heads, head_k_dim] + k: [B, T*num_householder, num_v_heads, head_k_dim] + v: [B, T*num_householder, num_v_heads, head_v_dim] + g: [B, T, num_v_heads] or None - forget gate + beta: [B, T*num_householder, num_v_heads] - Householder strength + + Returns: + o: [B, T, num_v_heads, head_v_dim] + """ + B, T, H, K = q.shape + V = v.shape[-1] + + # Work in float32 for stability + q = q.float() + k = k.float() + v = v.float() + beta = beta.float() + if g is not None: + g = g.float() + + scale = K ** -0.5 + + # Initialize state: [B, H, K, V] + h = torch.zeros(B, H, K, V, device=q.device, dtype=torch.float32) + + outputs = [] + + for t in range(T): + q_t = q[:, t, :, :] # [B, H, K] + + # Apply forget gate if provided + if g is not None: + g_t = g[:, t, :] # [B, H] + decay = torch.exp(g_t).unsqueeze(-1).unsqueeze(-1) # [B, H, 1, 1] + h = h * decay + + # Apply num_householder Householder transformations sequentially + for j in range(self.num_householder): + idx = t * self.num_householder + j + k_tj = k[:, idx, :, :] # [B, H, K] + v_tj = v[:, idx, :, :] # [B, H, V] + beta_tj = beta[:, idx, :] # [B, H] + + # L2 normalize k (as done in kernel) + k_tj = F.normalize(k_tj, p=2, dim=-1) + + # Householder transformation: + # prediction = h @ k_tj: [B, H, V] = einsum('bhkv,bhk->bhv', h, k_tj) + prediction = torch.einsum('bhkv,bhk->bhv', h, k_tj) # [B, H, V] + + # delta = v_tj - prediction: [B, H, V] + delta = v_tj - prediction + + # Update: h = h + outer(k_tj, delta) * beta_tj + # h[b,h] += beta_tj[b,h] * outer(k_tj[b,h], delta[b,h]) + update = torch.einsum('bhk,bhv->bhkv', k_tj, delta) # [B, H, K, V] + beta_expanded = beta_tj.unsqueeze(-1).unsqueeze(-1) # [B, H, 1, 1] + h = h + update * beta_expanded + + # Readout: o[t] = h @ q_t + # For each batch and head: o[b,h] = h[b,h] @ q_t[b,h] + # h[b,h] is [K, V], q_t[b,h] is [K] + q_t_scaled = q_t * scale + o_t = torch.einsum('bhkv,bhk->bhv', h, q_t_scaled) # [B, H, V] + + 
outputs.append(o_t) + + # Stack outputs: [T, B, H, V] + outputs = torch.stack(outputs, dim=0) + + # Transpose to [B, T, H, V] + outputs = outputs.transpose(0, 1) + + return outputs + + def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """RMSNorm implementation.""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-5) + return (x.float() / rms * weight).to(x.dtype) + + def _gated_rms_norm(self, x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """Gated RMSNorm: RMSNorm(x) * sigmoid(g).""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-5) + x_norm = (x.float() / rms) * weight + return (x_norm * torch.sigmoid(g.float())).to(x.dtype) + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +expand_v = 2.0 +head_dim = 256 +num_heads = 6 +num_v_heads = 6 +use_output_gate = True +use_forget_gate = True +allow_neg_eigval = True +num_householder = 2 + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, expand_v, head_dim, num_heads, num_v_heads, + use_output_gate, use_forget_gate, allow_neg_eigval, num_householder] + diff --git a/KernelBench/level9/gla_reference.py b/KernelBench/level9/gla_reference.py new file mode 100644 index 00000000..b637faae --- /dev/null +++ b/KernelBench/level9/gla_reference.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return x_normed * self.weight + +class Model(nn.Module): + """ + Reference implementation of Gated Linear Attention (GLA). 
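+
+    Illustrative usage (a minimal sketch; sizes mirror the benchmark parameters
+    at the bottom of this file):
+
+        import torch
+        model = Model(hidden_size=512, num_heads=4)
+        x = torch.randn(2, 16, 512)
+        y = model(x)  # [2, 16, 512]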
+ """ + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + num_kv_heads: int = None, + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + gate_logit_normalizer: int = 16, + gate_low_rank_dim: int = 16, + norm_eps: float = 1e-5, + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.use_short_conv = use_short_conv + self.use_output_gate = use_output_gate + self.gate_logit_normalizer = gate_logit_normalizer + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim // self.num_kv_groups, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim // self.num_kv_groups, bias=False) + + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.q_conv1d = nn.Conv1d(self.key_dim, self.key_dim, conv_size, groups=self.key_dim, padding=conv_size-1) + self.k_conv1d = nn.Conv1d(self.key_dim // self.num_kv_groups, self.key_dim // self.num_kv_groups, conv_size, groups=self.key_dim // self.num_kv_groups, padding=conv_size-1) + self.v_conv1d = nn.Conv1d(self.value_dim // self.num_kv_groups, self.value_dim // self.num_kv_groups, conv_size, groups=self.value_dim // self.num_kv_groups, padding=conv_size-1) + + self.gk_proj = nn.Sequential( + nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, self.key_dim // self.num_kv_groups, bias=True) + ) + + self.g_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): [B, T, H] + Returns: + torch.Tensor: [B, T, H] + """ + B, T, _ = hidden_states.shape + H, HK, HV = self.num_heads, self.head_k_dim, self.head_v_dim + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + gk = self.gk_proj(hidden_states) + + if self.use_short_conv: + # Conv1d expects [B, C, T] + q = F.silu(self.q_conv1d(q.transpose(1, 2))[..., :T].transpose(1, 2)) + k = F.silu(self.k_conv1d(k.transpose(1, 2))[..., :T].transpose(1, 2)) + v = F.silu(self.v_conv1d(v.transpose(1, 2))[..., :T].transpose(1, 2)) + + # Reshape and handle GQA + q = q.view(B, T, H, HK) + if self.num_kv_groups > 1: + k = k.repeat_interleave(self.num_kv_groups, dim=-2).view(B, T, H, HK) + v = v.repeat_interleave(self.num_kv_groups, dim=-2).view(B, T, H, HV) + gk = gk.repeat_interleave(self.num_kv_groups, dim=-2).view(B, T, H, HK) + else: + k = k.view(B, T, H, HK) + v = v.view(B, T, H, HV) + gk = gk.view(B, T, H, HK) + + # Gate pre-processing + gk = F.logsigmoid(gk) / self.gate_logit_normalizer + + # Core Recurrence + # S_t = S_{t-1} * exp(gk_t) + k_t @ v_t^T + # o_t = q_t @ S_t + + q, k, v, gk = q.float(), k.float(), v.float(), gk.float() + scale = HK ** -0.5 + + S = torch.zeros(B, H, HK, HV, device=hidden_states.device, dtype=torch.float32) + o = torch.zeros(B, T, H, HV, device=hidden_states.device, dtype=torch.float32) + + for t in 
range(T): + q_t = q[:, t] * scale # [B, H, HK] + k_t = k[:, t] # [B, H, HK] + v_t = v[:, t] # [B, H, HV] + gk_t = gk[:, t].exp() # [B, H, HK] + + # S_t = S_{t-1} * decay + outer_product(k, v) + S = S * gk_t.unsqueeze(-1) + torch.einsum('b h k, b h v -> b h k v', k_t, v_t) + + # o_t = q_t @ S_t + o[:, t] = torch.einsum('b h k, b h k v -> b h v', q_t, S) + + # Output processing + o = self.g_norm(o) + o = o.view(B, T, -1) + + if self.use_output_gate: + g = self.g_proj(hidden_states) + o = o * F.silu(g) # swish is silu + + return self.o_proj(o.to(hidden_states.dtype)) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 512 +num_heads = 4 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/grpo_reference.py b/KernelBench/level9/grpo_reference.py new file mode 100644 index 00000000..cef28b0f --- /dev/null +++ b/KernelBench/level9/grpo_reference.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation of GRPO (Group Relative Policy Optimization) Loss. + This module computes the policy loss and KL divergence term for a given batch of completions. + """ + def __init__(self, beta: float = 0.1): + super(Model, self).__init__() + self.beta = beta + + def forward(self, logits: torch.Tensor, ref_logp: torch.Tensor, input_ids: torch.Tensor, advantages: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor: + """ + Args: + logits (torch.Tensor): [B, L, V] Model logits for the completions + ref_logp (torch.Tensor): [B, L] Reference model log probabilities + input_ids (torch.Tensor): [B, L] Actual token IDs for the completions + advantages (torch.Tensor): [B] Group relative advantages + completion_mask (torch.Tensor): [B, L] Mask for valid completion tokens + Returns: + torch.Tensor: Scalar GRPO loss + """ + # 1. Get per-token log probabilities from logits and input_ids + log_probs = F.log_softmax(logits, dim=-1) + per_token_logps = torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + + # 2. Compute KL divergence: KL(Ref || Policy) = exp(ref_logp - logp) - (ref_logp - logp) - 1 + # This is a common approximation used in GRPO/PPO + diff = ref_logp - per_token_logps + per_token_kl = torch.exp(diff) - diff - 1 + + # 3. Compute the policy loss part + # loss = - (exp(logp - logp_old) * advantage - beta * kl) + # Assuming logp_old = logp.detach() for the first iteration step + per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * per_token_kl) + + # 4. 
Mask and reduce + masked_loss = per_token_loss * completion_mask + # Average over valid completion tokens per sequence, then average over batch + loss = (masked_loss.sum(dim=1) / completion_mask.sum(dim=1)).mean() + + return loss + +# Kernelbench Parameters +batch_size = 16 +seq_len = 128 +vocab_size = 32000 + +def get_inputs(): + logits = torch.randn(batch_size, seq_len, vocab_size) + ref_logp = torch.randn(batch_size, seq_len) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + advantages = torch.randn(batch_size) + completion_mask = torch.ones(batch_size, seq_len) + return [logits, ref_logp, input_ids, advantages, completion_mask] + +def get_init_inputs(): + return [0.1] # beta diff --git a/KernelBench/level9/gsa_reference.py b/KernelBench/level9/gsa_reference.py new file mode 100644 index 00000000..1926fe6b --- /dev/null +++ b/KernelBench/level9/gsa_reference.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class Model(nn.Module): + """ + Reference implementation of Gated Slot Attention (GSA). + """ + def __init__( + self, + hidden_size: int = 1024, + num_heads: int = 4, + num_kv_heads: int = 4, + head_k_dim: int = 64, + head_v_dim: int = 64, + num_slots: int = 64, + norm_eps: float = 1e-5, + ): + super(Model, self).__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.num_slots = num_slots + self.norm_eps = norm_eps + + self.q_proj = nn.Linear(hidden_size, num_heads * head_k_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_k_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_v_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, num_kv_heads * num_slots, bias=False) + + self.g_norm_weight = nn.Parameter(torch.ones(num_heads * head_v_dim)) + self.o_proj = nn.Linear(num_heads * head_v_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: [batch_size, seq_len, hidden_size] + """ + B, T, _ = x.shape + H, HK, HV, M = self.num_heads, self.head_k_dim, self.head_v_dim, self.num_slots + HKV = self.num_kv_heads + NG = H // HKV + + # Projections + q = self.q_proj(x).view(B, T, H, HK) + k = self.k_proj(x).view(B, T, HKV, HK) + v = self.v_proj(x).view(B, T, HKV, HV) + f = self.f_proj(x).view(B, T, HKV, M) + + # Swish feature map for q, k and silu for v + q = q * q.sigmoid() + k = k * k.sigmoid() + v = v * v.sigmoid() + + # Gating: f is log-decay, s is the complementary weight + f = F.logsigmoid(f) / 8.0 # gate_logit_normalizer=8 + s = (1.0 - f.exp()) + + # Grouped Query Attention: tile k, v, f, s if num_heads > num_kv_heads + if NG > 1: + k = k.repeat_interleave(NG, dim=2) + v = v.repeat_interleave(NG, dim=2) + f = f.repeat_interleave(NG, dim=2) + s = s.repeat_interleave(NG, dim=2) + + q, k, v, f, s = q.float(), k.float(), v.float(), f.float(), s.float() + + # First recurrence: compute soft slots assignment ok + hk = torch.zeros(B, H, HK, M, device=x.device, dtype=torch.float32) + ok = torch.zeros(B, T, H, M, device=x.device, dtype=torch.float32) + scale = HK ** -0.5 + + for i in range(T): + q_i = q[:, i] * scale + k_i = k[:, i] + v_i = s[:, i] + g_i = f[:, i].exp() + # hk state update: hk = hk * decay + k @ s.T + hk = hk * g_i.unsqueeze(-2) + torch.einsum('b h k, b h m -> b h k m', k_i, v_i) + # Read from hk: ok = q.T @ hk + ok[:, i] = torch.einsum('b h k, 
b h k m -> b h m', q_i, hk) + + # Global softmax over slots + qv = F.softmax(ok, dim=-1) + + # Second recurrence: compute output ov based on soft slots + hv = torch.zeros(B, H, M, HV, device=x.device, dtype=torch.float32) + ov = torch.zeros(B, T, H, HV, device=x.device, dtype=torch.float32) + + for i in range(T): + q_i = qv[:, i] + k_i = s[:, i] + v_i = v[:, i] + g_i = f[:, i].exp() + # hv state update: hv = hv * decay + s @ v.T + hv = hv * g_i.unsqueeze(-1) + torch.einsum('b h m, b h v -> b h m v', k_i, v_i) + # Read from hv: ov = qv.T @ hv + ov[:, i] = torch.einsum('b h m, b h m v -> b h v', q_i, hv) + + # Final output processing + o = ov.reshape(B, T, -1) + o = F.silu(o) + + # RMSNorm + o = o * torch.rsqrt(o.pow(2).mean(-1, keepdim=True) + self.norm_eps) + o = o * self.g_norm_weight + + return self.o_proj(o.to(x.dtype)) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 64 +hidden_size = 1024 +num_heads = 4 +num_kv_heads = 4 +head_k_dim = 64 +head_v_dim = 64 +num_slots = 64 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_heads, num_kv_heads, head_k_dim, head_v_dim, num_slots] diff --git a/KernelBench/level9/hgrn2_reference.py b/KernelBench/level9/hgrn2_reference.py new file mode 100644 index 00000000..8edb3709 --- /dev/null +++ b/KernelBench/level9/hgrn2_reference.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def swish(x): + return x * torch.sigmoid(x) + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization. + """ + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward(self, x): + # standard rms norm implementation + norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return x * norm * self.weight + +class Model(nn.Module): + """ + HGRN2 (Gated Linear RNN with State Expansion) Reference Implementation. + This model implements the core linear attention mechanism of HGRN2. + """ + def __init__(self, hidden_size: int = 2048, num_heads: int = 16, expand_ratio: int = 128): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = expand_ratio # State expansion dimension (head_f_dim in FLA) + self.v_head_dim = hidden_size // num_heads # Value dimension per head (head_i_dim in FLA) + + self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) + self.i_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + self.g_norm = RMSNorm(hidden_size) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs the HGRN2 linear attention forward pass. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_size). 
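+
+        Illustrative usage (a minimal sketch with a short sequence):
+
+            import torch
+            model = Model(hidden_size=2048, num_heads=16, expand_ratio=128)
+            y = model(torch.randn(1, 8, 2048))  # [1, 8, 2048]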
+ """ + batch_size, seq_len, _ = x.shape + + # Linear projections + q = self.q_proj(x) + f = self.f_proj(x) + i = self.i_proj(x) + + # Apply HGRN2 gating logic + q = swish(q) + g = F.logsigmoid(f) + k = 1 - g.exp() + + # Rearrange tensors for multi-head attention + # Shape: [batch, heads, seq_len, head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + g = g.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + # Shape: [batch, heads, seq_len, v_head_dim] + v = i.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2) + + # Naive recurrent Linear Attention (GLA) logic + # Implementation of: h_t = exp(g_t) * h_{t-1} + k_t @ v_t^T + # o_t = q_t @ h_t + + # Use float32 for stable recurrence + dtype = q.dtype + q, k, v, g = q.float(), k.float(), v.float(), g.float() + + scale = self.head_dim ** -0.5 + # state: [batch, heads, head_dim, v_head_dim] + h = torch.zeros(batch_size, self.num_heads, self.head_dim, self.v_head_dim, device=x.device, dtype=torch.float32) + o = torch.zeros(batch_size, self.num_heads, seq_len, self.v_head_dim, device=x.device, dtype=torch.float32) + + for t in range(seq_len): + q_t = q[:, :, t] * scale + k_t = k[:, :, t] + v_t = v[:, :, t] + g_t = g[:, :, t].exp() + + # Update RNN state + # k_t: [batch, heads, head_dim], v_t: [batch, heads, v_head_dim] -> kv_t: [batch, heads, head_dim, v_head_dim] + kv_t = torch.einsum('b h d, b h v -> b h d v', k_t, v_t) + h = h * g_t.unsqueeze(-1) + kv_t + + # Output + o[:, :, t] = torch.einsum('b h d, b h d v -> b h v', q_t, h) + + o = o.to(dtype) + + # Final projection and normalization + # [batch, heads, seq_len, v_head_dim] -> [batch, seq_len, hidden_size] + res = o.transpose(1, 2).reshape(batch_size, seq_len, -1) + res = self.g_norm(res) + res = self.o_proj(res) + + return res + +# Dimensions for testing +batch_size = 2 +seq_len = 1024 +hidden_size = 2048 +num_heads = 16 +expand_ratio = 128 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_heads, expand_ratio] diff --git a/KernelBench/level9/hgrn_reference.py b/KernelBench/level9/hgrn_reference.py new file mode 100644 index 00000000..f5dd323f --- /dev/null +++ b/KernelBench/level9/hgrn_reference.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation of HGRN (Hierarchically Gated Recurrent Network). + Adapted for kernelbench reference format. 
+ """ + def __init__(self, hidden_size: int = 1024, expand_ratio: int = 1, use_short_conv: bool = True, conv_size: int = 4): + super(Model, self).__init__() + + self.hidden_size = hidden_size + self.expand_ratio = expand_ratio + self.input_dim = int(hidden_size * expand_ratio) + self.use_short_conv = use_short_conv + self.conv_size = conv_size + + # Linear projections for i (input), f (forget gate), and g (output gate) + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + # Depthwise convolutions for i and f to provide some local temporal context + self.i_conv1d = nn.Conv1d( + in_channels=self.input_dim, + out_channels=self.input_dim, + kernel_size=conv_size, + groups=self.input_dim, + bias=False, + padding=conv_size - 1, + ) + self.f_conv1d = nn.Conv1d( + in_channels=self.input_dim, + out_channels=self.input_dim, + kernel_size=conv_size, + groups=self.input_dim, + bias=False, + padding=conv_size - 1, + ) + + # Gated RMSNorm weight + self.g_norm_weight = nn.Parameter(torch.ones(self.input_dim)) + self.norm_eps = 1e-5 + + # Output projection + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for HGRN. + + Args: + x (torch.Tensor): Input tensor of shape (batch, seq_len, hidden_size). + + Returns: + torch.Tensor: Output tensor of shape (batch, seq_len, hidden_size). + """ + B, T, H = x.shape + + # 1. Projections + i_raw = self.i_proj(x) + f_raw = self.f_proj(x) + g_raw = self.g_proj(x) + + # 2. Short Convolution (Causal 1D Depthwise) + if self.use_short_conv: + # i_raw: [B, T, D] -> [B, D, T] + i_raw = i_raw.transpose(1, 2) + # nn.Conv1d with padding=K-1 and taking [:T] is causal + i_raw = self.i_conv1d(i_raw)[..., :T].transpose(1, 2) + f_raw = f_raw.transpose(1, 2) + f_raw = self.f_conv1d(f_raw)[..., :T].transpose(1, 2) + + # 3. HGRN Gates + # f_log: log of the forget gate: logsigmoid(f_raw) + # forget_gate: sigmoid(f_raw) + f_log = F.logsigmoid(f_raw) + forget_gate = f_log.exp() + # i: input modulated by silu activation and (1 - forget_gate) + # Matches swiglu(i_raw, 1 - exp(f_log)) in the original code + i = F.silu(i_raw) * (1.0 - forget_gate) + + # 4. Recurrence: h_t = forget_gate_t * h_{t-1} + i_t + # Using a loop for numerical stability as this is a reference implementation. + h = torch.zeros(B, self.input_dim, device=x.device, dtype=x.dtype) + o_rec = torch.zeros(B, T, self.input_dim, device=x.device, dtype=x.dtype) + + for t in range(T): + h = forget_gate[:, t] * h + i[:, t] + o_rec[:, t] = h + + # 5. Gated RMSNorm + # Formula: RMSNorm(h) * weight * silu(g_raw) + # Normalized by the hidden dimension + rms = torch.rsqrt(o_rec.pow(2).mean(-1, keepdim=True) + self.norm_eps) + o_norm = o_rec * rms * self.g_norm_weight + o = o_norm * F.silu(g_raw) + + # 6. 
Final output projection + o = self.o_proj(o) + + return o + +# Hyperparameters +hidden_size = 1024 +expand_ratio = 1 +use_short_conv = True +conv_size = 4 + +# Test dimensions +batch_size = 8 +seq_len = 2048 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, expand_ratio, use_short_conv, conv_size] diff --git a/KernelBench/level9/kayvon_rl_kernel.py b/KernelBench/level9/kayvon_rl_kernel.py new file mode 100644 index 00000000..f3d01a22 --- /dev/null +++ b/KernelBench/level9/kayvon_rl_kernel.py @@ -0,0 +1,50 @@ +import torch + + +def l2normalize( + tensor: torch.Tensor, axis: int = -1, eps: float = 1e-8 +) -> torch.Tensor: + """Computes L2 normalization of a tensor.""" + return tensor / (torch.linalg.norm(tensor, ord=2, dim=axis, keepdim=True) + eps) + + +class Model(torch.nn.Module): + """HyperEmbedder-inspired reference used in the Kayvon RL kernel.""" + + def forward( + self, + x: torch.Tensor, + c_shift: float, + W: torch.Tensor, + scale: torch.Tensor, + ) -> torch.Tensor: + new_axis = torch.full((*x.shape[:-1], 1), c_shift, device=x.device, dtype=x.dtype) + concatenated = torch.cat([x, new_axis], dim=-1) + + # l2 norm + normalized = l2normalize(concatenated, axis=-1) + + # Linear layer followed by scaler + projected = torch.matmul(normalized, W) + scaled = projected * scale + + # l2 norm + return l2normalize(scaled, axis=-1) + + +# Problem configuration +batch_size = 64 +in_features = 128 +hidden_size = 256 + + +def get_inputs(): + x = torch.randn(batch_size, in_features) + W = torch.randn(in_features + 1, hidden_size) + scale = torch.randn(hidden_size) + c_shift = torch.randn(1).item() + return [x, c_shift, W, scale] + + +def get_init_inputs(): + return [] diff --git a/KernelBench/level9/kda_reference.py b/KernelBench/level9/kda_reference.py new file mode 100644 index 00000000..d98d7475 --- /dev/null +++ b/KernelBench/level9/kda_reference.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class Model(nn.Module): + """ + Reference implementation of Kimi Delta Attention (KDA). 
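+
+    Illustrative usage (a minimal sketch; a short sequence keeps the reference
+    recurrence loop fast):
+
+        import torch
+        model = Model(hidden_size=1024, head_dim=64, num_heads=16)
+        y = model(torch.randn(1, 8, 1024))  # [1, 8, 1024]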
+ """ + def __init__( + self, + hidden_size: int = 1024, + expand_v: float = 1.0, + head_dim: int = 64, + num_heads: int = 16, + num_v_heads: int = None, + use_short_conv: bool = True, + conv_size: int = 4, + norm_eps: float = 1e-5, + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.expand_v = expand_v + self.head_dim = head_dim + self.num_heads = num_heads + self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads + self.head_v_dim = int(head_dim * expand_v) + self.key_dim = num_heads * head_dim + self.value_dim = self.num_v_heads * self.head_v_dim + self.norm_eps = norm_eps + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.use_short_conv = use_short_conv + if use_short_conv: + self.q_conv1d = nn.Conv1d(self.key_dim, self.key_dim, conv_size, groups=self.key_dim, padding=conv_size-1) + self.k_conv1d = nn.Conv1d(self.key_dim, self.key_dim, conv_size, groups=self.key_dim, padding=conv_size-1) + self.v_conv1d = nn.Conv1d(self.value_dim, self.value_dim, conv_size, groups=self.value_dim, padding=conv_size-1) + + self.f_proj = nn.Sequential( + nn.Linear(hidden_size, self.head_v_dim, bias=False), + nn.Linear(self.head_v_dim, self.key_dim, bias=False), + ) + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + + self.A_log = nn.Parameter(torch.log(torch.empty(self.num_heads).uniform_(1, 16))) + self.dt_bias = nn.Parameter(torch.zeros(self.key_dim)) + + self.g_proj = nn.Sequential( + nn.Linear(hidden_size, self.head_v_dim, bias=False), + nn.Linear(self.head_v_dim, self.value_dim, bias=True), + ) + + self.o_norm_weight = nn.Parameter(torch.ones(self.head_v_dim)) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Input tensor of shape (B, T, H_in). + Returns: + torch.Tensor: Output tensor of shape (B, T, H_in). 
+ """ + B, T, _ = hidden_states.shape + H, HK, HV = self.num_heads, self.head_dim, self.head_v_dim + NV = self.num_v_heads + + # Projections + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_short_conv: + # Conv1d expects [B, C, T] + q = F.silu(self.q_conv1d(q.transpose(1, 2))[..., :T].transpose(1, 2)) + k = F.silu(self.k_conv1d(k.transpose(1, 2))[..., :T].transpose(1, 2)) + v = F.silu(self.v_conv1d(v.transpose(1, 2))[..., :T].transpose(1, 2)) + else: + q, k, v = F.silu(q), F.silu(k), F.silu(v) + + g = self.f_proj(hidden_states) + beta = self.b_proj(hidden_states).sigmoid() + + # Reshape for multi-head + q = q.view(B, T, H, HK) + k = k.view(B, T, H, HK) + g = g.view(B, T, H, HK) + v = v.view(B, T, NV, HV) + + # Expand if GVA + if NV > H: + groups = NV // H + q = q.repeat_interleave(groups, dim=2) + k = k.repeat_interleave(groups, dim=2) + g = g.repeat_interleave(groups, dim=2) + beta = beta.repeat_interleave(groups, dim=2) + H_eff = NV + else: + H_eff = H + + # Post-process g and scale q + A = self.A_log.exp() + if NV > H: A = A.repeat_interleave(NV // H) + + dt_bias = self.dt_bias.view(H, HK) + if NV > H: dt_bias = dt_bias.repeat_interleave(NV // H, dim=0) + dt_bias = dt_bias.view(1, 1, H_eff, HK) + + # g = -exp(A) * softplus(g + dt_bias) + g = -A.view(1, 1, H_eff, 1) * F.softplus(g + dt_bias) + + scale = HK ** -0.5 + + # Reference recurrence loop + # S: [B, H_eff, HK, HV] + S = torch.zeros(B, H_eff, HK, HV, device=hidden_states.device, dtype=torch.float32) + o = torch.zeros(B, T, H_eff, HV, device=hidden_states.device, dtype=torch.float32) + + q, k, v, g, beta = q.float(), k.float(), v.float(), g.float(), beta.float() + + for t in range(T): + q_t = q[:, t] # [B, H_eff, HK] + k_t = k[:, t] # [B, H_eff, HK] + v_t = v[:, t] # [B, H_eff, HV] + g_t = g[:, t] # [B, H_eff, HK] + beta_t = beta[:, t] # [B, H_eff] + + # QK L2 Norm + q_t = q_t / (torch.norm(q_t, p=2, dim=-1, keepdim=True) + 1e-6) + k_t = k_t / (torch.norm(k_t, p=2, dim=-1, keepdim=True) + 1e-6) + q_t = q_t * scale + + # S_t = S_{t-1} * exp(g_t) + S = S * g_t.exp().unsqueeze(-1) + + # Delta Update: v_update = beta_t * (v_t - k_t^T * S_t) + # k_t: [B, H_eff, HK], S: [B, H_eff, HK, HV] + # kS: [B, H_eff, HV] + kS = torch.einsum('b h k, b h k v -> b h v', k_t, S) + v_update = beta_t.unsqueeze(-1) * (v_t - kS) + + # S_t = S_t + k_t * v_update^T + S = S + torch.einsum('b h k, b h v -> b h k v', k_t, v_update) + + # o_t = S_t^T * q_t + o[:, t] = torch.einsum('b h k, b h k v -> b h v', q_t, S) + + # Output Norm and Gating + # g_proj for final gating + gate = self.g_proj(hidden_states).view(B, T, NV, HV) + + # RMSNorm per head + # o: [B, T, NV, HV] + o_norm = o * torch.rsqrt(o.pow(2).mean(-1, keepdim=True) + self.norm_eps) + o_norm = o_norm * self.o_norm_weight.view(1, 1, 1, HV) + + # Apply gate + o = o_norm * gate.sigmoid() + + # Project back + o = o.view(B, T, NV * HV) + return self.o_proj(o.to(hidden_states.dtype)) + +# Kernelbench Parameters +batch_size = 4 +seq_len = 128 # Smaller for reference loop +hidden_size = 1024 +expand_v = 1.0 +head_dim = 64 +num_heads = 16 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [hidden_size, expand_v, head_dim, num_heads] diff --git a/KernelBench/level9/l2_norm_reference.py b/KernelBench/level9/l2_norm_reference.py new file mode 100644 index 00000000..78e01e6f --- /dev/null +++ b/KernelBench/level9/l2_norm_reference.py @@ -0,0 +1,35 @@ 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for L2 Normalization. + Normalizes the input tensor by its L2 norm along the last dimension. + """ + def __init__(self, eps: float = 1e-6): + super(Model, self).__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: L2-normalized tensor of shape [batch_size, seq_len, hidden_size] + """ + # Compute L2 norm along the last dimension + # norm = sqrt(sum(x^2)) + return x * torch.rsqrt(x.pow(2).sum(-1, keepdim=True) + self.eps) + +# Kernelbench Parameters +batch_size = 16 +seq_len = 512 +hidden_size = 1024 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [1e-6] # eps diff --git a/KernelBench/level9/l2_wrap_reference.py b/KernelBench/level9/l2_wrap_reference.py new file mode 100644 index 00000000..196f7b8f --- /dev/null +++ b/KernelBench/level9/l2_wrap_reference.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + +class L2WrapFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, logits, penalty_factor=1e-4): + # Find max logit per token for the penalty + # This is a trick used in some models to prevent logit drift + maxx, ids = torch.max(logits, dim=-1, keepdim=True) + ctx.logits_shape = logits.shape + # Average penalty over batch and sequence + factor = penalty_factor / (logits.shape[0] * logits.shape[1]) + ctx.save_for_backward(maxx * factor, ids) + return loss + + @staticmethod + def backward(ctx, grad_output): + maxx_scaled, ids = ctx.saved_tensors + glogits = torch.zeros(ctx.logits_shape, device=grad_output.device, dtype=grad_output.dtype) + # Scatter the scaled max value back to the winning logit's position + glogits.scatter_(-1, ids, maxx_scaled) + # grad_output is the gradient of the loss + return grad_output, glogits, None + +class Model(nn.Module): + """ + Reference implementation of L2Wrap. + Maintains the loss value in forward but adds a logit-dependent penalty to the gradient. + """ + def __init__(self, penalty_factor: float = 1e-4): + super(Model, self).__init__() + self.penalty_factor = penalty_factor + + def forward(self, loss: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: + """ + Args: + loss (torch.Tensor): Scalar loss tensor. + logits (torch.Tensor): Logits tensor of shape [B, T, V]. + Returns: + torch.Tensor: The same scalar loss tensor. + """ + return L2WrapFunction.apply(loss, logits, self.penalty_factor) + +# Kernelbench Parameters +batch_size = 8 +seq_len = 1024 +vocab_size = 32000 + +def get_inputs(): + # Loss is usually a single scalar + loss = torch.tensor(2.5, requires_grad=True) + logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True) + return [loss, logits] + +def get_init_inputs(): + return [1e-4] # penalty_factor diff --git a/KernelBench/level9/layernorm_gated_reference.py b/KernelBench/level9/layernorm_gated_reference.py new file mode 100644 index 00000000..707dd469 --- /dev/null +++ b/KernelBench/level9/layernorm_gated_reference.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Gated Layer/RMS Normalization. + Supports gated normalization where the gate 'z' is applied either before or after the normalization. 
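+
+    Illustrative usage (a minimal sketch with arbitrary small shapes):
+
+        import torch
+        model = Model(hidden_size=4096, is_rms_norm=True, norm_before_gate=True)
+        x = torch.randn(2, 8, 4096)
+        z = torch.randn(2, 8, 4096)
+        out = model(x, z)  # RMSNorm(x) gated by silu(z)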
+ """ + def __init__(self, hidden_size: int, eps: float = 1e-5, is_rms_norm: bool = True, norm_before_gate: bool = True): + super(Model, self).__init__() + self.is_rms_norm = is_rms_norm + self.norm_before_gate = norm_before_gate + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) if not is_rms_norm else None + + def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor [B, T, H] + z (torch.Tensor): Gate tensor [B, T, H] + Returns: + torch.Tensor: Gated normalized output [B, T, H] + """ + if not self.norm_before_gate: + x = x * F.silu(z) + + if self.is_rms_norm: + # RMSNorm + rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + x = x * rstd * self.weight + else: + # LayerNorm + x = F.layer_norm(x, x.shape[-1:], self.weight, self.bias, self.eps) + + if self.norm_before_gate: + x = x * F.silu(z) + + return x + +# Kernelbench Parameters +batch_size = 4 +seq_len = 2048 +hidden_size = 4096 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + z = torch.randn(batch_size, seq_len, hidden_size) + return [x, z] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/layernorm_reference.py b/KernelBench/level9/layernorm_reference.py new file mode 100644 index 00000000..28a688b8 --- /dev/null +++ b/KernelBench/level9/layernorm_reference.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Layer/RMS Normalization with Residual Connection. + Commonly used in Transformers (Prenorm or Postnorm). + """ + def __init__(self, hidden_size: int, eps: float = 1e-5, is_rms_norm: bool = True, prenorm: bool = False): + super(Model, self).__init__() + self.is_rms_norm = is_rms_norm + self.prenorm = prenorm + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) if not is_rms_norm else None + + def forward(self, x: torch.Tensor, residual: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor [B, T, H] + residual (torch.Tensor, optional): Residual connection [B, T, H] + Returns: + torch.Tensor: Normalized output (or tuple of output and residual if prenorm=True) + For Kernelbench, we return just the final output tensor. + """ + if residual is not None: + if self.prenorm: + # In prenorm, x is actually the hidden state that expects residual + some_sublayer(norm(x)) + # But typically prenorm modules in this repo add residual first. 
+ x_with_res = x + residual + else: + x_with_res = x + residual + else: + x_with_res = x + + if self.is_rms_norm: + # RMSNorm + rstd = torch.rsqrt(x_with_res.pow(2).mean(-1, keepdim=True) + self.eps) + out = x_with_res * rstd * self.weight + if self.bias is not None: + out = out + self.bias + else: + # LayerNorm + out = F.layer_norm(x_with_res, x_with_res.shape[-1:], self.weight, self.bias, self.eps) + + return out + +# Kernelbench Parameters +batch_size = 8 +seq_len = 1024 +hidden_size = 2048 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + residual = torch.randn(batch_size, seq_len, hidden_size) + return [x, residual] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/lightnet_reference.py b/KernelBench/level9/lightnet_reference.py new file mode 100644 index 00000000..fd62a0c8 --- /dev/null +++ b/KernelBench/level9/lightnet_reference.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return x_normed * self.weight + +class Model(nn.Module): + """ + Reference implementation of LightNet (YOSO: You Only Scan Once). + """ + def __init__( + self, + hidden_size: int = 1024, + num_heads: int = 8, + expand_ratio: int = 128, + use_short_conv: bool = False, + conv_size: int = 4, + gate_low_rank_dim: int = 128, + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.key_dim = num_heads * expand_ratio + self.value_dim = hidden_size + self.head_f_dim = expand_ratio + self.head_i_dim = hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.use_short_conv = use_short_conv + if use_short_conv: + self.q_conv1d = nn.Conv1d(self.key_dim, self.key_dim, conv_size, groups=self.key_dim, padding=conv_size-1) + self.k_conv1d = nn.Conv1d(self.key_dim, self.key_dim, conv_size, groups=self.key_dim, padding=conv_size-1) + self.v_conv1d = nn.Conv1d(self.value_dim, self.value_dim, conv_size, groups=self.value_dim, padding=conv_size-1) + + self.g_proj = nn.Sequential( + nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, hidden_size, bias=False), + ) + self.g_norm = RMSNorm(hidden_size) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): [B, T, H_in] + Returns: + torch.Tensor: [B, T, H_out] + """ + B, T, _ = hidden_states.shape + H, KF, DF = self.num_heads, self.key_dim, self.head_f_dim + DI = self.head_i_dim + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_short_conv: + q = self.q_conv1d(q.transpose(1, 2))[..., :T].transpose(1, 2) + k = self.k_conv1d(k.transpose(1, 2))[..., :T].transpose(1, 2) + v = self.v_conv1d(v.transpose(1, 2))[..., :T].transpose(1, 2) + + q = F.silu(q).view(B, T, H, DF) + k = k.view(B, T, H, DF) + v = v.view(B, T, H, DI) + + # YOSO Gate and Normalization + # z = logcumsumexp(k) + # k_new = exp(k - z) + # g = z_{t-1} - 
z_t + + z = torch.logcumsumexp(k.float(), dim=1) # [B, T, H, DF] + k_new = torch.exp(k.float() - z).to(k.dtype) + + # g = shift(z) - z + z_shifted = torch.cat([z[:, :1], z[:, :-1]], dim=1) + gk = (z_shifted - z).to(k.dtype) + + # Recurrence + # S_t = S_{t-1} * exp(gk_t) + k_new_t * v_t^T + # o_t = q_t * S_t + + scale = DF ** -0.5 + q, k_new, v, gk = q.float(), k_new.float(), v.float(), gk.float() + + S = torch.zeros(B, H, DF, DI, device=hidden_states.device, dtype=torch.float32) + o = torch.zeros(B, T, H, DI, device=hidden_states.device, dtype=torch.float32) + + for t in range(T): + q_t = q[:, t] * scale # [B, H, DF] + k_t = k_new[:, t] # [B, H, DF] + v_t = v[:, t] # [B, H, DI] + gk_t = gk[:, t].exp() # [B, H, DF] + + # S_t = S_{t-1} * gate + outer(k, v) + # gk_t is per-dimension of DF + S = S * gk_t.unsqueeze(-1) + torch.einsum('b h f, b h i -> b h f i', k_t, v_t) + + # o_t = q_t @ S_t + o[:, t] = torch.einsum('b h f, b h f i -> b h i', q_t, S) + + o = o.view(B, T, -1) + + # Output Gating and Norm + gate = self.g_proj(hidden_states) + o = self.g_norm(o) * F.silu(gate) + o = self.o_proj(o) + + return o + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 512 +num_heads = 4 +expand_ratio = 128 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [hidden_size, num_heads, expand_ratio] diff --git a/KernelBench/level9/log_linear_attn_reference.py b/KernelBench/level9/log_linear_attn_reference.py new file mode 100644 index 00000000..ed2a2cb8 --- /dev/null +++ b/KernelBench/level9/log_linear_attn_reference.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +# Inlined helper functions from fla/ops/log_linear_attn/naive.py +def segsum(x): + T = x.size(-1) + x_cumsum = torch.cumsum(x, dim=-1) + x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool)) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + +def construct_level_mask(level, L): + T = L.size(-1) + if level == 0: + return torch.diag_embed(L[..., level, :]) + indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device) + mask = torch.where( + torch.logical_and( + torch.logical_and( + indices[:, 0] % (1 << level) >= (1 << (level - 1)), + indices[:, 1] + (1 << (level - 1)) + >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))), + ), + indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))), + ).view(T, T), + L[..., level, :].unsqueeze(-1).expand(*([-1] * (len(L.shape) - 2)), T, T), + 0, + ).to(L.dtype) + return mask + +def construct_H_matrix(a, L): + T = a.size(-1) + A = torch.exp(segsum(a)) + H = torch.zeros_like(A, dtype=a.dtype) + for level in range(math.ceil(math.log2(T)) + 1): + mask = construct_level_mask(level, L) + H += A * mask + return H + +class Model(nn.Module): + """ + Naive implementation of Log Linear Attention. 
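+ Builds a hierarchical decay matrix H from the per-token gates g and the per-level
+ scales, then contracts it with q, k and v, which is quadratic in seq_len.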
+ """ + def __init__(self, n_heads: int = 4, seq_len: int = 64): + super(Model, self).__init__() + self.n_heads = n_heads + self.seq_len = seq_len + self.n_levels = int(math.ceil(math.log2(seq_len))) + 1 + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, level_scales: torch.Tensor) -> torch.Tensor: + """ + Args: + q: [batch_size, seq_len, n_heads, head_dim] + k: [batch_size, seq_len, n_heads, head_dim] + v: [batch_size, seq_len, n_heads, head_dim] + g: [batch_size, seq_len, n_heads] + level_scales: [batch_size, n_heads, n_levels, seq_len] + Returns: + o: [batch_size, seq_len, n_heads, head_dim] + """ + # H calculation requires [batch_size, n_heads, seq_len] for g + # and [batch_size, n_heads, n_levels, seq_len] for level_scales + # q, k, v are [batch_size, seq_len, n_heads, d] + + # g: [b, t, h] -> [b, h, t] + g_transposed = g.transpose(1, 2) + + # level_scales is already [b, h, n_levels, t] + + # Compute H matrix [batch_size, n_heads, seq_len, seq_len] + H = construct_H_matrix(g_transposed, level_scales) + + # Attention computation + # H: [b, h, l, c] (batch, head, seq_len_q, seq_len_k) + # q: [b, l, h, n] (batch, seq_len_q, head, head_dim) + # k: [b, c, h, n] (batch, seq_len_k, head, head_dim) + # v: [b, c, h, p] (batch, seq_len_k, head, head_dim) + + # M = H * (q @ k.T) -> but in log-linear it's specialized + # Based on naive implementation: + # M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k) + M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k) + # o = torch.einsum("bhlc,bchp->blhp", M, v) + o = torch.einsum("bhlc,bchp->blhp", M, v) + + return o + +# Kernelbench Parameters +batch_size = 2 +seq_len = 64 +n_heads = 4 +head_dim = 32 +n_levels = int(math.ceil(math.log2(seq_len))) + 1 + +def get_inputs(): + q = torch.randn(batch_size, seq_len, n_heads, head_dim) + k = torch.randn(batch_size, seq_len, n_heads, head_dim) + v = torch.randn(batch_size, seq_len, n_heads, head_dim) + g = torch.randn(batch_size, seq_len, n_heads) + level_scales = torch.randn(batch_size, n_heads, n_levels, seq_len).abs() + return [q, k, v, g, level_scales] + +def get_init_inputs(): + return [n_heads, seq_len] diff --git a/KernelBench/level9/mamba2_reference.py b/KernelBench/level9/mamba2_reference.py new file mode 100644 index 00000000..8d366c00 --- /dev/null +++ b/KernelBench/level9/mamba2_reference.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Helper functions from fla/layers/mamba2.py +def apply_mask_to_padding_states(hidden_states, attention_mask): + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + return hidden_states + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + # Assumes that we only have tensors of either size 4 or 3 + if len(input_tensor.shape) == 4: + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) + else: + pad_shape = (0, 0, 0, pad_size, 0, 0) + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
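+ # Pad the time dimension so it divides evenly by chunk_size, then fold it into
+ # [bsz, num_chunks, chunk_size, ...] blocks for the chunked SSD computation.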
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size) + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> + # [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3], + ) + +def segment_sum(input_tensor): + chunk_size = input_tensor.size(-1) + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + +class RMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-5, norm_before_gate=False): + super().__init__() + self.eps = eps + self.norm_before_gate = norm_before_gate + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward(self, x, z=None): + dtype = x.dtype + weight = self.weight.float() + x = x.float() + if z is not None and not self.norm_before_gate: + x = x * F.silu(z.float()) + + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + self.eps) + out = x * rstd * weight + + if z is not None and self.norm_before_gate: + out = out * F.silu(z.float()) + return out.to(dtype) + +class Model(nn.Module): + """ + Reference implementation of Mamba-2 (Linear Attention / SSD) + """ + def __init__( + self, + num_heads: int = 64, + head_dim: int = 64, + hidden_size: int = 2048, + state_size: int = 128, + expand: int = 2, + n_groups: int = 1, + chunk_size: int = 256, + ): + super(Model, self).__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.hidden_size = hidden_size + self.ssm_state_size = state_size + self.expand = expand + self.intermediate_size = int(expand * hidden_size) + self.n_groups = n_groups + self.chunk_size = chunk_size + self.conv_kernel_size = 4 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=True) + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1).float())) + self.norm = RMSNormGated(self.intermediate_size, norm_before_gate=False) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Input of shape (batch, seq_len, hidden_size) + Returns: + torch.Tensor: Output of shape (batch, seq_len, hidden_size) + """ + batch_size, seq_len, _ = hidden_states.shape + dtype = hidden_states.dtype + + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - + 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1, + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = F.silu(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1, + ) + + # 3. SSM transformation (SSD naive implementation) + A = -torch.exp(self.A_log.float()) + dt = F.softplus(dt + self.dt_bias) + + hidden_states = hidden_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + padded_hidden = pad_tensor_by_size(hidden_states, pad_size) + D_residual = self.D.view(1, 1, self.num_heads, 1) * padded_hidden + + hidden_states = padded_hidden + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Intra-chunk (diagonal blocks) + L = torch.exp(segment_sum(A)) + G = (C[:, :, :, None, :, :] * B[:, :, None, :, :, :]).sum(dim=-1) + M = (G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]).sum(dim=-1) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Intra-chunk state (right term of factorization) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Inter-chunk SSM recurrence + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states = new_states[:, :-1] + + # 4. 
State -> output conversion per chunk (left term of factorization) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + Y_off = (C_times_states.sum(-1) * state_decay_out.permute(0, 2, 3, 1)[..., None]) + + y = Y_diag + Y_off + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + y = y + D_residual + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + scan_output = self.norm(y, gate) + return self.out_proj(scan_output.to(dtype)) + +# Configuration for get_inputs/get_init_inputs +batch_size = 4 +seq_len = 1024 +num_heads = 64 +head_dim = 64 +hidden_size = 2048 +state_size = 128 +expand = 2 +n_groups = 1 +chunk_size = 256 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [num_heads, head_dim, hidden_size, state_size, expand, n_groups, chunk_size] diff --git a/KernelBench/level9/mesa_net_reference.py b/KernelBench/level9/mesa_net_reference.py new file mode 100644 index 00000000..b28a4c3a --- /dev/null +++ b/KernelBench/level9/mesa_net_reference.py @@ -0,0 +1,337 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + MesaNet (Sequence Modeling by Locally Optimal Test-Time Training) - Reference implementation. + + MesaNet uses "test-time training" via Conjugate Gradient (CG) to solve a linear system: + (H_kk + lambda*I) @ q_star = q + + Where: + - H_kk is a recurrently updated state matrix: h_kk[t] = h_kk[t-1] * exp(g[t]) + outer(k[t]*beta[t], k[t]) + - H_kv is another state matrix: h_kv[t] = h_kv[t-1] * exp(g[t]) + outer(k[t]*beta[t], v[t]) + - lambda is a regularization parameter (per head, per dimension) + + The output is: o = H_kv @ q_star + + Key insight: Instead of materializing the full attention matrix, MesaNet maintains + compact state matrices and solves for q_star iteratively using CG, which acts as + a form of "test-time training" that adapts to the current input. 
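+ In this reference the state updates and the CG solves are written as explicit Python
+ loops over batch, head and time step, trading speed for readability.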
+ + Based on: "MesaNet: Sequence Modeling by Locally Optimal Test-Time Training" + """ + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 16, + head_dim: int = 128, + use_output_gate: bool = False, + lambda_lower_bound: float = 0.25, + max_cg_iteration: int = 30, + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.key_dim = num_heads * head_dim + self.value_dim = self.key_dim # MesaNet uses same dim for V as K + self.use_output_gate = use_output_gate + self.lambda_lower_bound = lambda_lower_bound + self.max_cg_iteration = max_cg_iteration + + # Projections + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + # Gate projections + self.a_proj = nn.Linear(hidden_size, num_heads, bias=True) # For g (decay) + self.b_proj = nn.Linear(hidden_size, num_heads, bias=True) # For beta + + # Lambda parameters: regularization per head per dimension + # Initialized to 1.0, then transformed via softplus + lower_bound + lambda_initial_value = 1.0 + init_lamb_value = torch.log(torch.exp(torch.tensor(lambda_initial_value - lambda_lower_bound)) - 1.0) + self.lambda_params = nn.Parameter(torch.empty(self.key_dim, dtype=torch.float32).fill_(init_lamb_value)) + + # Output gate and projection + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + # RMSNorm weight + self.o_norm_weight = nn.Parameter(torch.ones(self.head_dim)) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + attention_mask: Optional mask (unused in this reference) + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x) # [B, T, key_dim] + k = self.k_proj(x) # [B, T, key_dim] + v = self.v_proj(x) # [B, T, value_dim] + + # Apply SiLU activation (simulating short convolution effect) + q = F.silu(q) + k = F.silu(k) + v = F.silu(v) + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # L2 normalize Q and K (as done in the kernel) + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + + # Compute beta (sigmoid) and g (log-sigmoid decay) + beta = torch.sigmoid(self.b_proj(x)) # [B, T, num_heads] + g = F.logsigmoid(self.a_proj(x)) # [B, T, num_heads] - negative values + + # Compute lambda: softplus + lower_bound, reshaped to [num_heads, head_dim] + lamb = F.softplus(self.lambda_params) + self.lambda_lower_bound + lamb = lamb.view(self.num_heads, self.head_dim) # [num_heads, head_dim] + + # ============================================ + # MesaNet Core: Test-Time Training via CG + # ============================================ + o = self._mesa_net_attention(q, k, v, g, beta, lamb) + + # Output normalization + if self.use_output_gate: + gate = self.g_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) + o = self._gated_rms_norm(o, gate, self.o_norm_weight) + else: + o = self._rms_norm(o, self.o_norm_weight) + + # Reshape and project output + o = o.view(batch_size, seq_len, self.value_dim) + o 
= self.o_proj(o) + + return o + + def _mesa_net_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + lamb: torch.Tensor + ) -> torch.Tensor: + """ + MesaNet attention using Conjugate Gradient solver. + + Steps: + 1. Build state matrices h_kk and h_kv recurrently + 2. Solve (h_kk + lambda*I) @ q_star = q using CG + 3. Output: o = h_kv @ q_star + + Args: + q: [B, T, num_heads, head_dim] + k: [B, T, num_heads, head_dim] + v: [B, T, num_heads, head_dim] + g: [B, T, num_heads] - log-sigmoid decay + beta: [B, T, num_heads] - scaling + lamb: [num_heads, head_dim] - regularization + + Returns: + o: [B, T, num_heads, head_dim] + """ + B, T, H, D = q.shape + + # Work in float32 for stability + q = q.float() + k = k.float() + v = v.float() + g = g.float() + beta = beta.float() + lamb = lamb.float() + + outputs = [] + + for b in range(B): + batch_outputs = [] + + for h in range(H): + # Initialize state matrices + h_kk = torch.zeros(D, D, device=q.device, dtype=torch.float32) # [head_dim, head_dim] + h_kv = torch.zeros(D, D, device=q.device, dtype=torch.float32) # [head_dim, head_dim] + + # Build state matrices recurrently + h_kk_all = [] + h_kv_all = [] + + for t in range(T): + k_t = k[b, t, h] # [D] + v_t = v[b, t, h] # [D] + g_t = g[b, t, h] # scalar + beta_t = beta[b, t, h] # scalar + + # Update states with exponential decay + # h_kk = h_kk * exp(g) + outer(k*beta, k) + k_beta = k_t * beta_t + h_kk = h_kk * torch.exp(g_t) + torch.outer(k_beta, k_t) + h_kv = h_kv * torch.exp(g_t) + torch.outer(k_beta, v_t) + + h_kk_all.append(h_kk.clone()) + h_kv_all.append(h_kv.clone()) + + # Stack states: [T, D, D] + h_kk_all = torch.stack(h_kk_all, dim=0) + h_kv_all = torch.stack(h_kv_all, dim=0) + + # Get lambda for this head: [D] + lamb_h = lamb[h] # [D] + + # Solve for each time step using CG + head_outputs = [] + + for t in range(T): + q_t = q[b, t, h] # [D] + h_kk_t = h_kk_all[t] # [D, D] + h_kv_t = h_kv_all[t] # [D, D] + + # Solve: (h_kk + lambda*I) @ q_star = q using Conjugate Gradient + q_star = self._conjugate_gradient_solve( + A=h_kk_t, + b=q_t, + lamb=lamb_h, + max_iter=self.max_cg_iteration + ) + + # Output: o = h_kv @ q_star + o_t = h_kv_t @ q_star + head_outputs.append(o_t) + + # Stack outputs for this head: [T, D] + head_outputs = torch.stack(head_outputs, dim=0) + batch_outputs.append(head_outputs) + + # Stack outputs for this batch: [T, H, D] + batch_outputs = torch.stack(batch_outputs, dim=1) + outputs.append(batch_outputs) + + # Stack all batches: [B, T, H, D] + outputs = torch.stack(outputs, dim=0) + + return outputs + + def _conjugate_gradient_solve( + self, + A: torch.Tensor, + b: torch.Tensor, + lamb: torch.Tensor, + max_iter: int = 30 + ) -> torch.Tensor: + """ + Solve (A + lambda*I) @ x = b using Conjugate Gradient method. + + This is the "test-time training" aspect: iteratively refine the solution + to adapt to the current input. 
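+ The solve starts from a Jacobi (diagonal) initial guess b / (diag(A) + lambda) and
+ exits early once the squared residual norm drops below 1e-10.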
+ + Args: + A: [D, D] - matrix + b: [D] - right-hand side + lamb: [D] - diagonal regularization (per dimension) + max_iter: Maximum CG iterations + + Returns: + x: [D] - solution + """ + D = b.shape[0] + + # Matrix-vector product function: (A + lambda*I) @ x + def matvec(x): + return A @ x + lamb * x + + # Initial guess: x = b / (diag(A) + lambda) + diag_A = torch.diagonal(A) + x = b / (diag_A + lamb + 1e-8) + + # Initial residual: r = b - (A + lambda*I) @ x + r = b - matvec(x) + p = r.clone() + delta_old = torch.dot(r, r) + + # CG iterations + for i in range(max_iter): + # Check convergence + if delta_old < 1e-10: + break + + # Compute: q = (A + lambda*I) @ p + q = matvec(p) + + # Compute step size: alpha = delta_old / (p @ q) + p_dot_q = torch.dot(p, q) + if abs(p_dot_q) < 1e-10: + break + alpha = delta_old / p_dot_q + + # Update solution: x = x + alpha * p + x = x + alpha * p + + # Update residual: r = r - alpha * q + r = r - alpha * q + + # Compute new delta: delta_new = r @ r + delta_new = torch.dot(r, r) + + # Compute beta: beta = delta_new / delta_old + if delta_old < 1e-10: + break + beta = delta_new / delta_old + + # Update search direction: p = r + beta * p + p = r + beta * p + + delta_old = delta_new + + return x + + def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """RMSNorm implementation.""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + return (x.float() / rms * weight).to(x.dtype) + + def _gated_rms_norm(self, x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """Gated RMSNorm: RMSNorm(x) * sigmoid(g).""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + x_norm = (x.float() / rms) * weight + return (x_norm * torch.sigmoid(g.float())).to(x.dtype) + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +num_heads = 16 +head_dim = 128 +use_output_gate = False +lambda_lower_bound = 0.25 +max_cg_iteration = 30 + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, num_heads, head_dim, use_output_gate, lambda_lower_bound, max_cg_iteration] + diff --git a/KernelBench/level9/mom_reference.py b/KernelBench/level9/mom_reference.py new file mode 100644 index 00000000..9371e4b8 --- /dev/null +++ b/KernelBench/level9/mom_reference.py @@ -0,0 +1,312 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + MoM (Mixture-of-Memories) Attention - A reference implementation using pure PyTorch. + + MoM combines: + 1. Top-k routing: Each token selects top-k memory "experts" + 2. Gated Delta Rule: A linear attention variant with exponential decay and delta updates + 3. Transform/Reconstruct: Reorganize tokens by memory slot, process, then scatter back + + The Gated Delta Rule recurrence for each memory: + h[t] = h[t-1] * exp(g[t]) + k[t] @ (beta[t] * (v[t] - h[t-1].T @ k[t])).T + o[t] = h[t].T @ q[t] + + Where h is a [K, V] outer-product state matrix. 
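+ Each memory keeps its own [head_dim, head_v_dim] state per head, and a token only
+ updates the memories it was routed to; the per-memory outputs are scattered back and
+ mixed with the renormalized top-k routing weights.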
+ + Based on: "MoM: Linear Sequence Modeling with Mixture-of-Memories" + https://arxiv.org/abs/2502.13685 + """ + + def __init__( + self, + hidden_size: int = 2048, + head_dim: int = 256, + num_heads: int = 4, + expand_v: float = 2.0, + num_memories: int = 8, + topk: int = 2, + ): + super().__init__() + + self.hidden_size = hidden_size + self.head_dim = head_dim + self.num_heads = num_heads + self.expand_v = expand_v + self.num_memories = num_memories + self.topk = topk + + self.key_dim = num_heads * head_dim + self.value_dim = int(self.key_dim * expand_v) + self.head_v_dim = int(head_dim * expand_v) + + # Router gate + self.gate = nn.Linear(hidden_size, num_memories, bias=False) + + # Projections (shared across memories for simplicity) + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.b_proj = nn.Linear(hidden_size, num_heads, bias=False) # beta projection + self.a_proj = nn.Linear(hidden_size, num_heads, bias=False) # gate projection + + # Output projections + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) # output gate + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + # Learnable decay parameters (A_log and dt_bias) + A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + dt_min, dt_max = 0.001, 0.1 + dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = torch.clamp(dt, min=1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + + # RMSNorm for output + self.o_norm_weight = nn.Parameter(torch.ones(self.head_v_dim)) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + attention_mask: Optional mask of shape (batch_size, seq_len), 1=valid, 0=padding + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + batch_size, seq_len, _ = x.shape + + # ============================================ + # Step 1: Top-k Routing + # ============================================ + router_logits = self.gate(x) # [B, T, num_memories] + scores = F.softmax(router_logits, dim=-1) + routing_weights, selected_memories = torch.topk(scores, self.topk, dim=-1) # [B, T, topk] + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # Normalize + + # Create routing mask: [B, T, num_memories] + routing_mask = torch.zeros(batch_size, seq_len, self.num_memories, device=x.device, dtype=torch.bool) + routing_mask.scatter_(-1, selected_memories, True) + + if attention_mask is not None: + # Mask out padding tokens + routing_mask = routing_mask & attention_mask.unsqueeze(-1).bool() + + # ============================================ + # Step 2: Transform - Reorganize tokens by memory + # ============================================ + # For each memory, gather the tokens routed to it + memory_outputs = [] + memory_indices = [] # Track which tokens went to which memory + + for mem_idx in range(self.num_memories): + # Find tokens routed to this memory + # For simplicity, process per-batch + batch_outputs = [] + batch_indices = [] + + for b in range(batch_size): + # Get mask for this batch and memory + mem_mask = routing_mask[b, :, mem_idx] # [T] + token_indices = torch.where(mem_mask)[0] # Indices of tokens routed to this memory + + if 
len(token_indices) == 0: + batch_outputs.append(torch.zeros(0, self.value_dim, device=x.device, dtype=x.dtype)) + batch_indices.append(token_indices) + continue + + # Gather tokens for this memory + tokens = x[b, token_indices] # [num_tokens, hidden_size] + + # ============================================ + # Step 3: Gated Delta Rule for this memory + # ============================================ + mem_output = self._gated_delta_rule(tokens) # [num_tokens, value_dim] + + batch_outputs.append(mem_output) + batch_indices.append(token_indices) + + memory_outputs.append(batch_outputs) + memory_indices.append(batch_indices) + + # ============================================ + # Step 4: Reconstruct - Scatter back and mix + # ============================================ + output = torch.zeros(batch_size, seq_len, self.value_dim, device=x.device, dtype=x.dtype) + + for mem_idx in range(self.num_memories): + for b in range(batch_size): + indices = memory_indices[mem_idx][b] + if len(indices) == 0: + continue + + mem_out = memory_outputs[mem_idx][b] # [num_tokens, value_dim] + + # Get routing weights for these tokens to this memory + # Find which topk slot this memory corresponds to + mem_weights = torch.zeros(len(indices), device=x.device, dtype=x.dtype) + for i, idx in enumerate(indices): + # Find the weight for this memory in the topk selection + for k in range(self.topk): + if selected_memories[b, idx, k] == mem_idx: + mem_weights[i] = routing_weights[b, idx, k] + break + + # Weighted scatter-add + output[b, indices] += mem_out * mem_weights.unsqueeze(-1) + + # ============================================ + # Step 5: Output projection with gating + # ============================================ + # Reshape for head-wise processing + output = output.view(batch_size, seq_len, self.num_heads, self.head_v_dim) + + # Output gate + g = self.g_proj(x).view(batch_size, seq_len, self.num_heads, self.head_v_dim) + + # Gated RMSNorm + output = self._gated_rms_norm(output, g, self.o_norm_weight) + + # Final projection + output = output.view(batch_size, seq_len, self.value_dim) + output = self.o_proj(output) + + return output + + def _gated_delta_rule(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Apply Gated Delta Rule attention to a sequence of tokens. 
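+ Here `tokens` is the order-preserving subsequence of positions that one batch element
+ routed to a single memory; the recurrence below runs over that subsequence.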
+ + The recurrence: + h[t] = h[t-1] * exp(g[t]) + k[t] @ (beta[t] * (v[t] - h[t-1].T @ k[t])).T + o[t] = h[t].T @ q[t] + + Args: + tokens: [num_tokens, hidden_size] + + Returns: + outputs: [num_tokens, value_dim] + """ + num_tokens = tokens.shape[0] + if num_tokens == 0: + return tokens.new_zeros(0, self.value_dim) + + # Project to Q, K, V + q = self.q_proj(tokens) # [T, key_dim] + k = self.k_proj(tokens) # [T, key_dim] + v = self.v_proj(tokens) # [T, value_dim] + + # Compute beta (sigmoid-scaled) and gate g + beta = self.b_proj(tokens).sigmoid() # [T, num_heads] + a = self.a_proj(tokens) # [T, num_heads] + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) # [T, num_heads] + + # Apply SiLU activation (as in the original implementation with conv) + q = F.silu(q) + k = F.silu(k) + v = F.silu(v) + + # Reshape to multi-head: [T, num_heads, head_dim] + q = q.view(num_tokens, self.num_heads, self.head_dim) + k = k.view(num_tokens, self.num_heads, self.head_dim) + v = v.view(num_tokens, self.num_heads, self.head_v_dim) + + # L2 normalize Q and K (as used in the kernel) + q = F.normalize(q, p=2, dim=-1) + k = F.normalize(k, p=2, dim=-1) + + # Scale Q + scale = self.head_dim ** -0.5 + q = q * scale + + # Initialize state: h is [num_heads, head_dim, head_v_dim] + h = torch.zeros(self.num_heads, self.head_dim, self.head_v_dim, + device=tokens.device, dtype=torch.float32) + + outputs = [] + + for t in range(num_tokens): + q_t = q[t] # [num_heads, head_dim] + k_t = k[t] # [num_heads, head_dim] + v_t = v[t] # [num_heads, head_v_dim] + beta_t = beta[t] # [num_heads] + g_t = g[t] # [num_heads] + + # Apply decay: h = h * exp(g) + # g is negative (decay), so exp(g) < 1 + decay = torch.exp(g_t).unsqueeze(-1).unsqueeze(-1) # [num_heads, 1, 1] + h = h * decay + + # Delta rule update: + # prediction = h.T @ k => [num_heads, head_v_dim] + # For each head: prediction[h] = h[h].T @ k_t[h] = einsum('kv,k->v', h[h], k_t[h]) + prediction = torch.einsum('hkv,hk->hv', h, k_t) # [num_heads, head_v_dim] + + # v_new = beta * (v - prediction) + v_new = beta_t.unsqueeze(-1) * (v_t.float() - prediction) # [num_heads, head_v_dim] + + # Update state: h = h + outer(k, v_new) + # For each head: h[h] += k_t[h][:, None] @ v_new[h][None, :] + h = h + torch.einsum('hk,hv->hkv', k_t.float(), v_new) # [num_heads, head_dim, head_v_dim] + + # Output: o = h.T @ q => [num_heads, head_v_dim] + o_t = torch.einsum('hkv,hk->hv', h, q_t) # [num_heads, head_v_dim] + + outputs.append(o_t) + + # Stack outputs: [T, num_heads, head_v_dim] + outputs = torch.stack(outputs, dim=0) + + # Reshape to [T, value_dim] + outputs = outputs.view(num_tokens, self.value_dim).to(tokens.dtype) + + return outputs + + def _gated_rms_norm(self, x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Gated RMSNorm: RMSNorm(x) * sigmoid(g) + + Args: + x: [B, T, H, D] + g: [B, T, H, D] gate + weight: [D] norm weight + + Returns: + output: [B, T, H, D] + """ + # RMSNorm + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + x_norm = (x.float() / rms) * weight + + # Gate + output = x_norm * torch.sigmoid(g.float()) + + return output.to(x.dtype) + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +head_dim = 256 +num_heads = 4 +expand_v = 2.0 +num_memories = 8 +topk = 2 + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, head_dim, num_heads, expand_v, num_memories, topk] + diff --git 
a/KernelBench/level9/nsa_reference.py b/KernelBench/level9/nsa_reference.py new file mode 100644 index 00000000..395a4eec --- /dev/null +++ b/KernelBench/level9/nsa_reference.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class Model(nn.Module): + """ + Reference implementation of Native Sparse Attention (NSA). + Combines sliding window, compressed, and selected attention. + """ + def __init__( + self, + hidden_size: int = 1024, + num_heads: int = 32, + num_kv_heads: int = 4, + head_dim: int = 64, + block_size: int = 64, + window_size: int = 512, + num_blocks: int = 16, + ): + super(Model, self).__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.block_size = block_size + self.window_size = window_size + self.num_blocks = num_blocks + + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, num_heads * 3, bias=False) + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: [batch_size, seq_len, hidden_size] + """ + B, T, _ = x.shape + H, HKV, D = self.num_heads, self.num_kv_heads, self.head_dim + G = H // HKV + + q = self.q_proj(x).view(B, T, H, D).transpose(1, 2) # [B, H, T, D] + k = self.k_proj(x).view(B, T, HKV, D).transpose(1, 2) # [B, HKV, T, D] + v = self.v_proj(x).view(B, T, HKV, D).transpose(1, 2) # [B, HKV, T, D] + g = self.g_proj(x).view(B, T, H, 3).sigmoid() + g_cmp, g_slc, g_swa = g[..., 0], g[..., 1], g[..., 2] + + if G > 1: + k = k.repeat_interleave(G, dim=1) + v = v.repeat_interleave(G, dim=1) + + scale = D ** -0.5 + + # 1. Sliding Window Attention (SWA) + # Full attention scores with causal and window mask + scores_swa = torch.matmul(q, k.transpose(-1, -2)) * scale + mask = torch.tril(torch.ones(T, T, device=x.device)) + if self.window_size > 0: + window_mask = torch.arange(T, device=x.device).unsqueeze(-1) - torch.arange(T, device=x.device) < self.window_size + mask = mask * window_mask + scores_swa = scores_swa.masked_fill(mask == 0, float('-inf')) + o_swa = torch.matmul(F.softmax(scores_swa, dim=-1), v) + + # 2. Compressed Attention (CMP) + # Pooling K and V in blocks + TC = (T + self.block_size - 1) // self.block_size + k_pad = F.pad(k, (0, 0, 0, TC * self.block_size - T)) + v_pad = F.pad(v, (0, 0, 0, TC * self.block_size - T)) + + # [B, H, TC, BS, D] -> [B, H, TC, D] + k_cmp = k_pad.view(B, H, TC, self.block_size, D).mean(dim=-2) + v_cmp = v_pad.view(B, H, TC, self.block_size, D).mean(dim=-2) + + scores_cmp = torch.matmul(q, k_cmp.transpose(-1, -2)) * scale + # Causal mask for compression: token t can see compressed block c if c*BS < t + cmp_mask = torch.arange(T, device=x.device).unsqueeze(-1) >= (torch.arange(TC, device=x.device) * self.block_size + self.block_size - 1) + scores_cmp = scores_cmp.masked_fill(cmp_mask == 0, float('-inf')) + o_cmp = torch.matmul(F.softmax(scores_cmp, dim=-1), v_cmp) + + # 3. Selected Attention (SLC) - Top-k blocks selection + # We use the compressed scores to select the top-k most important blocks for each query + # and then perform attention over the raw tokens in those blocks. + # This is the "Sparse" part of Native Sparse Attention. 
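+ # Compressed-branch touch-up: queries earlier than the first complete block have every
+ # compressed score masked to -inf, so the softmax above leaves NaN rows in o_cmp; zero
+ # them so that branch simply contributes nothing at those positions (a minimal way to
+ # keep the reference output finite).
+ o_cmp = torch.nan_to_num(o_cmp)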
+ + # For the reference implementation, we'll implement a simplified selected attention: + # For each query token, pick the top num_blocks blocks from scores_cmp. + # Then for those blocks, do attention over original tokens. + + # scores_cmp: [B, H, T, TC] + _, top_block_indices = scores_cmp.topk(min(self.num_blocks, TC), dim=-1) # [B, H, T, S] + + o_slc = torch.zeros_like(o_swa) + # Loop over batches and heads for selected part to keep it simple and correct in reference + for b in range(B): + for h in range(H): + for t in range(T): + # Selected blocks for query t + blocks = top_block_indices[b, h, t] + indices = [] + for blk in blocks: + start = blk.item() * self.block_size + end = min(start + self.block_size, T) + # Only include tokens <= t + if start <= t: + indices.extend(range(start, min(end, t + 1))) + + if not indices: + continue + + indices = torch.tensor(indices, device=x.device) + q_t = q[b, h, t] * scale # [D] + k_t = k[b, h, indices] # [N, D] + v_t = v[b, h, indices] # [N, D] + + attn = F.softmax(torch.matmul(k_t, q_t), dim=0) # [N] + o_slc[b, h, t] = torch.matmul(attn, v_t) + + # Final combination using gating coefficients + # g: [B, T, H] + o = (o_swa * g_swa.transpose(1, 2).unsqueeze(-1) + + o_cmp * g_cmp.transpose(1, 2).unsqueeze(-1) + + o_slc * g_slc.transpose(1, 2).unsqueeze(-1)) + + o = o.transpose(1, 2).reshape(B, T, -1) + return self.o_proj(o.to(x.dtype)) + +# Kernelbench Parameters +batch_size = 1 # Keep it small for the selection loop +seq_len = 128 +hidden_size = 512 +num_heads = 8 +num_kv_heads = 2 +head_dim = 64 +block_size = 32 +window_size = 64 +num_blocks = 4 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_heads, num_kv_heads, head_dim, block_size, window_size, num_blocks] diff --git a/KernelBench/level9/parallel_forgetting_attn_reference.py b/KernelBench/level9/parallel_forgetting_attn_reference.py new file mode 100644 index 00000000..a6333b16 --- /dev/null +++ b/KernelBench/level9/parallel_forgetting_attn_reference.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class Model(nn.Module): + """ + Reference implementation of Parallel Forgetting Attention. 
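+ Standard causal softmax attention whose logits are biased by cumulative log forget
+ gates: score[i, j] = (q_i . k_j) * scale + G_i - G_j, where G = cumsum(g) over time.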
+ """ + def __init__(self, num_heads: int, num_kv_heads: int, head_dim: int, v_head_dim: int): + super(Model, self).__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.v_head_dim = v_head_dim + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, scale: float = None) -> torch.Tensor: + """ + Args: + q (torch.Tensor): [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): [batch_size, seq_len, num_kv_heads, head_dim] + v (torch.Tensor): [batch_size, seq_len, num_kv_heads, v_head_dim] + g (torch.Tensor): [batch_size, seq_len, num_heads] - log decay factors + scale (float): optional scale factor + Returns: + o (torch.Tensor): [batch_size, seq_len, num_heads, v_head_dim] + """ + if scale is None: + scale = q.shape[-1] ** -0.5 + + B, T, H, D = q.shape + _, _, HKV, V = v.shape + G = H // HKV + + # Compute cumsum of g for forgetting factors + g_cumsum = torch.cumsum(g, dim=1) # [B, T, H] + + # Reshape/transpose for attention + q = q.transpose(1, 2) # [B, H, T, D] + k = k.transpose(1, 2) # [B, HKV, T, D] + v = v.transpose(1, 2) # [B, HKV, T, V] + g_cumsum = g_cumsum.transpose(1, 2) # [B, H, T] + + # Handle GQA by repeating k and v + if G > 1: + k = k.repeat_interleave(G, dim=1) + v = v.repeat_interleave(G, dim=1) + + # Compute attention scores + # scores: [B, H, T, T] + attn_scores = torch.matmul(q, k.transpose(-1, -2)) * scale + + # Add forgetting factors (decay) + # score[i, j] += cumsum(g)_i - cumsum(g)_j + attn_scores += g_cumsum.unsqueeze(-1) - g_cumsum.unsqueeze(-2) + + # Causal mask + mask = torch.tril(torch.ones(T, T, device=q.device, dtype=torch.bool)) + attn_scores = attn_scores.masked_fill(~mask, float('-inf')) + + # Softmax + attn_weights = F.softmax(attn_scores, dim=-1) + + # Output + o = torch.matmul(attn_weights, v) # [B, H, T, V] + + return o.transpose(1, 2) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +num_heads = 8 +num_kv_heads = 4 +head_dim = 64 +v_head_dim = 64 + +def get_inputs(): + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim) + v = torch.randn(batch_size, seq_len, num_kv_heads, v_head_dim) + g = torch.randn(batch_size, seq_len, num_heads) + return [q, k, v, g] + +def get_init_inputs(): + return [num_heads, num_kv_heads, head_dim, v_head_dim] diff --git a/KernelBench/level9/path_attn_reference.py b/KernelBench/level9/path_attn_reference.py new file mode 100644 index 00000000..b31c4ad2 --- /dev/null +++ b/KernelBench/level9/path_attn_reference.py @@ -0,0 +1,310 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + PaTH (Path-dependent Transformer with Householder) Attention - Reference implementation. + + PaTH Attention applies cumulative Householder transformations to both Q and K, + making the attention "path-dependent" - earlier tokens affect the representation + of later tokens through these transformations. + + Core idea: + 1. Householder reflection: x_new = x - beta * (x · w) * w, where w is L2-normalized + 2. These reflections are applied cumulatively: each position i sees the effect of + all Householder reflections from positions j < i + 3. Transformed Q and K are then used in standard softmax attention + 4. 
Optional forget gate (log-sigmoid) adds exponential decay + + The transformation can be understood as: + - For each position i, transform k[i] by applying Householder reflections from all j < i + - Similarly transform q[i] + - Then compute standard causal softmax attention + + Note: This is a simplified reference. The actual kernel uses chunked computation + with matrix T = solve_tril(beta * w @ w.T) for efficiency. + """ + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: int = None, + use_forget_gate: bool = False, + use_qk_norm: bool = False, + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.head_dim = hidden_size // num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + + self.use_forget_gate = use_forget_gate + self.use_qk_norm = use_qk_norm + + # Projections + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, self.kv_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.kv_dim, bias=False) + + # W projection for Householder vectors (low-rank for parameter efficiency) + self.w_proj = nn.Sequential( + nn.Linear(hidden_size, 32, bias=False), + nn.Linear(32, self.kv_dim, bias=False), + ) + + # Beta projection: controls Householder reflection strength + # sigmoid * 2 allows range [0, 2] for potentially negative eigenvalues + self.bt_proj = nn.Linear(hidden_size, self.num_kv_heads, bias=True) + + # Optional forget gate + if use_forget_gate: + self.g_proj = nn.Linear(hidden_size, num_heads, bias=True) + + # Optional QK norm + if use_qk_norm: + self.q_norm_weight = nn.Parameter(torch.ones(self.head_dim)) + self.k_norm_weight = nn.Parameter(torch.ones(self.head_dim)) + + # Output projection + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size) + attention_mask: Optional mask (unused in this reference, kept for API compatibility) + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_size) + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V, W + q = self.q_proj(x) # [B, T, hidden_size] + k = self.k_proj(x) # [B, T, kv_dim] + v = self.v_proj(x) # [B, T, kv_dim] + w = self.w_proj(x) # [B, T, kv_dim] + + # Beta: controls Householder reflection strength + # Range [0, 2] allows negative eigenvalues (reflection can flip direction) + beta = torch.sigmoid(self.bt_proj(x).float()) * 2 # [B, T, num_kv_heads] + + # Optional forget gate + if self.use_forget_gate: + g = F.logsigmoid(self.g_proj(x).float()) # [B, T, num_heads] + else: + g = None + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + w = w.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Optional QK norm + if self.use_qk_norm: + q = self._rms_norm(q, self.q_norm_weight) + k = self._rms_norm(k, self.k_norm_weight) + + # L2 normalize W (critical for Householder reflections) + w = F.normalize(w.float(), p=2, dim=-1) # [B, T, num_kv_heads, head_dim] + + # Apply SiLU activation to W (as done in short conv path) + w = F.silu(w) + w = F.normalize(w, p=2, dim=-1) # Re-normalize after activation + + # 
============================================ + # PaTH Attention Core + # ============================================ + + # Apply cumulative Householder transformations + q_transformed, k_transformed = self._apply_cumulative_householder(q, k, w, beta) + + # Apply causal softmax attention with optional forget gate + o = self._causal_attention_with_gate(q_transformed, k_transformed, v, g) + + # Reshape and project output + o = o.reshape(batch_size, seq_len, -1) + o = self.o_proj(o) + + return o + + def _apply_cumulative_householder( + self, + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + beta: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply cumulative Householder transformations to Q and K. + + For each position i, the transformation is: + - k_new[i] = k[i] - sum_{j torch.Tensor: + """ + Causal softmax attention with optional forget gate. + + If g is provided, attention scores are modified: + score[i,j] = q[i] @ k[j] + g_cumsum[i] - g_cumsum[j] + + where g_cumsum is the cumulative sum of log-sigmoid gate values. + + Args: + q: [B, T, num_heads, head_dim] + k: [B, T, num_kv_heads, head_dim] + v: [B, T, num_kv_heads, head_dim] + g: [B, T, num_heads] optional forget gate (log-sigmoid values) + + Returns: + o: [B, T, num_heads, head_dim] + """ + B, T, num_heads, head_dim = q.shape + num_kv_heads = k.shape[2] + num_groups = num_heads // num_kv_heads + + scale = head_dim ** -0.5 + + # Expand K and V for GQA if needed + if num_groups > 1: + k = k.repeat_interleave(num_groups, dim=2) + v = v.repeat_interleave(num_groups, dim=2) + + # Transpose for attention: [B, num_heads, T, head_dim] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute attention scores + scores = torch.matmul(q, k.transpose(-2, -1)) * scale # [B, num_heads, T, T] + + # Apply forget gate if provided + if g is not None: + # g_cumsum: cumulative sum of log-sigmoid values + g_cumsum = torch.cumsum(g, dim=1) # [B, T, num_heads] + g_cumsum = g_cumsum.transpose(1, 2) # [B, num_heads, T] + + # Modify scores: score[i,j] += g_cumsum[i] - g_cumsum[j] + scores = scores + g_cumsum.unsqueeze(-1) - g_cumsum.unsqueeze(-2) + + # Apply causal mask + causal_mask = torch.triu(torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1) + scores = scores.masked_fill(causal_mask, float('-inf')) + + # Softmax and apply to values + attn_weights = F.softmax(scores, dim=-1) + o = torch.matmul(attn_weights, v) # [B, num_heads, T, head_dim] + + # Transpose back: [B, T, num_heads, head_dim] + o = o.transpose(1, 2) + + return o + + def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """RMSNorm implementation.""" + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-6) + return (x.float() / rms * weight).to(x.dtype) + + +# Problem dimensions +batch_size = 4 +seq_len = 512 +hidden_size = 2048 +num_heads = 32 +num_kv_heads = 8 # GQA +use_forget_gate = True +use_qk_norm = False + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + + +def get_init_inputs(): + return [hidden_size, num_heads, num_kv_heads, use_forget_gate, use_qk_norm] + diff --git a/KernelBench/level9/rebased_reference.py b/KernelBench/level9/rebased_reference.py new file mode 100644 index 00000000..68f18b02 --- /dev/null +++ b/KernelBench/level9/rebased_reference.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, 
y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + +class RebasedFeatureMap(nn.Module): + def __init__(self, head_dim: int, use_gamma: bool = True, use_beta: bool = True, normalize: bool = True): + super().__init__() + self.head_dim = head_dim + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + if use_gamma: + self.gamma = nn.Parameter(torch.ones(head_dim)) + else: + self.gamma = None + if use_beta: + self.beta = nn.Parameter(torch.zeros(head_dim)) + else: + self.beta = None + + def forward(self, x: torch.Tensor): + if self.normalize: + x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta) + elif self.use_gamma and self.use_beta: + x = x * self.gamma + self.beta + elif self.use_gamma: + x = x * self.gamma + + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + # rebased use learnable parameters to approximate any quadratic function + return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1) + +class Model(nn.Module): + """ + Reference implementation of ReBased Linear Attention. + """ + def __init__(self, hidden_size: int, feature_dim: int = 16, num_heads: int = 16, eps: float = 1e-5): + super().__init__() + self.hidden_size = hidden_size + self.feature_dim = feature_dim + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.eps = eps + + self.feature_map = RebasedFeatureMap(feature_dim) + self.q_proj = nn.Linear(hidden_size, feature_dim * num_heads, bias=False) + self.k_proj = nn.Linear(hidden_size, feature_dim * num_heads, bias=False) + self.v_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor): + b, t, _ = hidden_states.size() + q = self.q_proj(hidden_states).view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + k = self.k_proj(hidden_states).view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(b, t, self.num_heads, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + + # q, k: [b, h, t, m] + # v: [b, h, t, d] + + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + # q: [b, h, t, 1, m] + # k: [b, h, t, 1, m] + # v: [b, h, t, d, 1] + + # Compute attention: (q * (k * v).cumsum) / (q * k.cumsum) + # (k * v) is [b, h, t, d, m] + kv_cum = (k * v).cumsum(2) + num = (q * kv_cum).sum(-1) # [b, h, t, d] + + k_cum = k.cumsum(2) + den = (q * k_cum).sum(-1) + self.eps # [b, h, t, 1] + + y = num / den + y = y.transpose(1, 2).reshape(b, t, -1) + return self.o_proj(y.to(hidden_states.dtype)) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 1024 +feature_dim = 16 +num_heads = 16 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [hidden_size, feature_dim, num_heads] diff --git a/KernelBench/level9/retnet_reference.py b/KernelBench/level9/retnet_reference.py new file mode 100644 index 00000000..87e1b73c --- /dev/null +++ b/KernelBench/level9/retnet_reference.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + 
self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_len): + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + # emb: [T, D] + # x: [B, H, T, D] + cos = emb.cos()[None, None, :, :] + sin = emb.sin()[None, None, :, :] + + def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + return x * cos + rotate_half(x) * sin + +class Model(nn.Module): + """ + Reference implementation of RetNet (MultiScale Retention). + """ + def __init__(self, hidden_size: int = 1024, num_heads: int = 8, head_dim: int = 64, v_head_dim: int = 128): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.v_head_dim = v_head_dim + + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_heads * v_head_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, num_heads * v_head_dim, bias=False) + self.o_proj = nn.Linear(num_heads * v_head_dim, hidden_size, bias=False) + + self.rotary = RotaryEmbedding(head_dim) + self.norm = nn.GroupNorm(num_heads, num_heads * v_head_dim, eps=1e-5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: [batch_size, seq_len, hidden_size] + """ + B, T, _ = x.shape + H, D, V = self.num_heads, self.head_dim, self.v_head_dim + + q = self.q_proj(x).view(B, T, H, D).transpose(1, 2) # [B, H, T, D] + k = self.k_proj(x).view(B, T, H, D).transpose(1, 2) # [B, H, T, D] + v = self.v_proj(x).view(B, T, H, V).transpose(1, 2) # [B, H, T, V] + g = self.g_proj(x) # [B, T, H*V] + + # Apply Rotary Position Embeddings + q = self.rotary(q, T) + k = self.rotary(k, T) + + # Parallel Retention computation + q, k, v = q.float(), k.float(), v.float() + + # Head-specific decay rates: gamma = 1 - 2^(-5-h) + gamma = 1.0 - torch.pow(2.0, -5.0 - torch.arange(H, device=x.device, dtype=torch.float32)) + s = gamma.log2() + + # Decay matrix: D[i, j] = gamma^(i-j) for i >= j else 0 + n_idx = torch.arange(T, device=x.device, dtype=torch.float32) + decay = torch.exp2((n_idx.unsqueeze(-1) - n_idx) * s.view(-1, 1, 1)) + mask = n_idx.unsqueeze(-1) >= n_idx + decay = decay.masked_fill(~mask, 0) + + # Scaled dot-product attention with decay + # scores: [B, H, T, T] + scores = torch.matmul(q, k.transpose(-1, -2)) * (D ** -0.5) + scores = scores * decay + + # Aggregate values + o = torch.matmul(scores, v) # [B, H, T, V] + + # GroupNorm followed by gating + o = o.transpose(1, 2).reshape(B * T, H * V) + o = self.norm(o).view(B, T, H * V) + o = o * F.silu(g) # Multi-scale gating + + return self.o_proj(o.to(x.dtype)) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 1024 +num_heads = 8 +head_dim = 64 +v_head_dim = 128 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_heads, head_dim, v_head_dim] diff --git a/KernelBench/level9/rk4.py b/KernelBench/level9/rk4.py new file mode 100644 index 00000000..ae719f41 --- /dev/null +++ b/KernelBench/level9/rk4.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + """ + A model that performs 3D heat diffusion using a 9-point stencil + and 4th-order Runge-Kutta (RK4) time integration. 
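+ The time step is chosen from a diffusive CFL-style bound,
+ dt = c / (alpha * (1/hx^2 + 1/hy^2 + 1/hz^2)) with c = 0.05.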
+ + We solve: u_t = alpha * Laplacian(u) + + The Laplacian is computed with a 1D 9-point (radius-4) stencil in x, y, and z. + We only update the interior points in each dim, leaving the boundary values unchanged. + """ + + def __init__(self, alpha: float, hx: float, hy: float, hz: float, n_steps: int): + super(Model, self).__init__() + self.alpha = alpha + self.hx = hx + self.hy = hy + self.hz = hz + self.n_steps = n_steps + + def forward(self, u0: torch.Tensor) -> torch.Tensor: + """ + Performs 3D heat diffusion simulation using RK4 time integration. + + Args: + u0: Initial 3D field tensor of shape [grid_size, grid_size, grid_size] + + Returns: + Final field after n_steps RK4 updates of the 3D heat equation + """ + # 3D 8th-order 2nd-derivative Laplacian coefficients + c0 = -205.0 / 72.0 + c1 = 8.0 / 5.0 + c2 = -1.0 / 5.0 + c3 = 8.0 / 315.0 + c4 = -1.0 / 560.0 + + # CFL stability (same constant, but now for RK4) + c = 0.05 + + # Current field + u = u0.clone() + device, dtype = u.device, u.dtype + + # Move scalars to same device/dtype as u + alpha = torch.as_tensor(self.alpha, device=device, dtype=dtype) + hx = torch.as_tensor(self.hx, device=device, dtype=dtype) + hy = torch.as_tensor(self.hy, device=device, dtype=dtype) + hz = torch.as_tensor(self.hz, device=device, dtype=dtype) + + inv_hx2 = 1.0 / (hx * hx) + inv_hy2 = 1.0 / (hy * hy) + inv_hz2 = 1.0 / (hz * hz) + + S = inv_hx2 + inv_hy2 + inv_hz2 + dt = c / (alpha * S) + + # Radius of stencil + r = 4 + + # Interior slices (note that boundary values are not updated) + zc = slice(r, -r) + yc = slice(r, -r) + xc = slice(r, -r) + + # Helper to compute Laplacian(u) on interior region + def laplacian_8th(u_field: torch.Tensor) -> torch.Tensor: + uc = u_field[zc, yc, xc] + + # x-direction second derivative + u_xx = ( + c0 * uc + + c1 * (u_field[zc, yc, r + 1 : -r + 1] + u_field[zc, yc, r - 1 : -r - 1]) + + c2 * (u_field[zc, yc, r + 2 : -r + 2] + u_field[zc, yc, r - 2 : -r - 2]) + + c3 * (u_field[zc, yc, r + 3 : -r + 3] + u_field[zc, yc, r - 3 : -r - 3]) + + c4 * (u_field[zc, yc, r + 4 :] + u_field[zc, yc, : -r - 4]) + ) * inv_hx2 + + # y-direction second derivative + u_yy = ( + c0 * uc + + c1 * (u_field[zc, r + 1 : -r + 1, xc] + u_field[zc, r - 1 : -r - 1, xc]) + + c2 * (u_field[zc, r + 2 : -r + 2, xc] + u_field[zc, r - 2 : -r - 2, xc]) + + c3 * (u_field[zc, r + 3 : -r + 3, xc] + u_field[zc, r - 3 : -r - 3, xc]) + + c4 * (u_field[zc, r + 4 :, xc] + u_field[zc, : -r - 4, xc]) + ) * inv_hy2 + + # z-direction second derivative + u_zz = ( + c0 * uc + + c1 * (u_field[r + 1 : -r + 1, yc, xc] + u_field[r - 1 : -r - 1, yc, xc]) + + c2 * (u_field[r + 2 : -r + 2, yc, xc] + u_field[r - 2 : -r - 2, yc, xc]) + + c3 * (u_field[r + 3 : -r + 3, yc, xc] + u_field[r - 3 : -r - 3, yc, xc]) + + c4 * (u_field[r + 4 :, yc, xc] + u_field[: -r - 4, yc, xc]) + ) * inv_hz2 + + return u_xx + u_yy + u_zz + + # Workspace for next solution and intermediate stage field + f = torch.empty_like(u) # u_{n+1} each step + u_stage = torch.empty_like(u) # u_n + a*dt*k + # Stage vectors (interior only) + uc_shape = u[zc, yc, xc].shape + k1 = torch.empty(uc_shape, device=device, dtype=dtype) + k2 = torch.empty(uc_shape, device=device, dtype=dtype) + k3 = torch.empty(uc_shape, device=device, dtype=dtype) + k4 = torch.empty(uc_shape, device=device, dtype=dtype) + + for _ in range(self.n_steps): + # Interior view of current u + uc = u[zc, yc, xc] + + # ---- Stage 1: k1 = alpha * Lap(u_n), u_stage = u_n + dt/2 * k1 ---- + lap1 = laplacian_8th(u) + k1.copy_(alpha * lap1) + + 
u_stage.copy_(u) # start from u so boundaries are preserved + u_stage[zc, yc, xc] = uc + 0.5 * dt * k1 + + # ---- Stage 2: k2 = alpha * Lap(u_stage), u_stage = u_n + dt/2 * k2 ---- + lap2 = laplacian_8th(u_stage) + k2.copy_(alpha * lap2) + + u_stage.copy_(u) # reset from u again + u_stage[zc, yc, xc] = uc + 0.5 * dt * k2 + + # ---- Stage 3: k3 = alpha * Lap(u_stage), u_stage = u_n + dt * k3 ---- + lap3 = laplacian_8th(u_stage) + k3.copy_(alpha * lap3) + + u_stage.copy_(u) + u_stage[zc, yc, xc] = uc + dt * k3 + + # ---- Stage 4: k4 = alpha * Lap(u_stage) ---- + lap4 = laplacian_8th(u_stage) + k4.copy_(alpha * lap4) + + # ---- Final RK4 combination on interior ---- + f.copy_(u) # boundaries unchanged + f[zc, yc, xc] = uc + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4) + + # Swap for next step + u, f = f, u + + return u + + +# Problem configuration +grid_size = 64 +n_steps = 10 + + +def get_inputs(): + # Generate input field: [grid_size, grid_size, grid_size] + # Distribution: Standard normal (mean=0, std=1) via torch.randn() + u0 = torch.randn(grid_size, grid_size, grid_size, dtype=torch.float32).contiguous() + return [u0] + + +def get_init_inputs(): + # Random diffusion coefficient alpha in [0.1, 5.0] + alpha = torch.rand(1).item() * 4.9 + 0.1 + + # Random grid spacings hx, hy, hz in [0.5, 2.0] + hx = torch.rand(1).item() * 1.5 + 0.5 + hy = torch.rand(1).item() * 1.5 + 0.5 + hz = torch.rand(1).item() * 1.5 + 0.5 + + return [alpha, hx, hy, hz, n_steps] diff --git a/KernelBench/level9/rodimus_reference.py b/KernelBench/level9/rodimus_reference.py new file mode 100644 index 00000000..d4840c8d --- /dev/null +++ b/KernelBench/level9/rodimus_reference.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class Model(nn.Module): + """ + Rodimus Attention - Reference implementation. + + Rodimus combines: + 1. Gated Linear Attention (GLA) with learnable gates + 2. Special gate computation: g_gate, tau_gate, it_gate, rt_gate + 3. Input gating via i_gate_proj + 4. Residual connection with learnable weight + + The core GLA recurrence: + h[t] = h[t-1] * exp(rt_gate[t]) + outer(k[t], v[t]) + o[t] = sum(h[t] * q[t], dim=-2) + + Key features: + - k is normalized and scaled by it_gate (input gate for keys) + - rt_gate is computed as: -g_gate * tau_gate (forget gate) + - v is gated by i_gate_proj (input gate for values) + - Residual connection with learnable weight + + Based on the Rodimus architecture. 
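+    Inside the recurrence, queries are scaled by mem_size ** -0.5 and the running
+    state is kept in float32 for numerical stability.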
+    """
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        expand_ratio: int = 64,
+        input_gate_low_rank: int = 16,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        residual_in_fp32: bool = True,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.d_inner = int(hidden_size * 2)  # Expanded dimension
+        self.expand_ratio = expand_ratio
+        self.mem_size = expand_ratio  # Memory size (K and V dimension)
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.residual_in_fp32 = residual_in_fp32
+        self.input_gate_low_rank = input_gate_low_rank
+
+        # Main projections (MLP-like structure)
+        self.gate_proj = nn.Linear(hidden_size, self.d_inner, bias=False)
+        self.up_proj = nn.Linear(hidden_size, self.d_inner, bias=False)
+        self.down_proj = nn.Linear(self.d_inner, hidden_size, bias=False)
+
+        # Gated activation norm
+        self.activation_norm_weight = nn.Parameter(torch.ones(self.d_inner))
+
+        # Residual weight (learnable)
+        self.residual_weight = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
+
+        # Attention projections
+        self.q_proj = nn.Linear(self.d_inner, self.mem_size, bias=False)
+        self.k_proj = nn.Linear(self.d_inner, self.mem_size, bias=False)
+
+        # Gate projections
+        self.g_gate_proj = nn.Linear(self.d_inner, self.mem_size, bias=True)
+        self.tau_gate_proj = nn.Linear(self.d_inner, self.mem_size, bias=True)
+
+        # Input gate for values (low-rank)
+        self.i_gate_proj = nn.Sequential(
+            nn.Linear(self.d_inner, input_gate_low_rank, bias=False),
+            nn.Linear(input_gate_low_rank, self.d_inner, bias=True),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch_size, seq_len, hidden_size)
+            attention_mask: Optional mask (unused in this reference)
+
+        Returns:
+            Output tensor of shape (batch_size, seq_len, hidden_size)
+        """
+        batch_size, seq_len, _ = x.shape
+
+        # Up projection and gate projection
+        hidden_states = self.up_proj(x)  # [B, T, d_inner]
+        final_gate = self.gate_proj(x)  # [B, T, d_inner]
+
+        # Short convolution (simulated with SiLU)
+        if self.use_short_conv:
+            shift_hidden_states = F.silu(hidden_states)
+        else:
+            shift_hidden_states = hidden_states
+
+        # Project to Q, K
+        q = self.q_proj(shift_hidden_states)  # [B, T, mem_size]
+        k = self.k_proj(shift_hidden_states)  # [B, T, mem_size]
+
+        # Input gate for values
+        v = self.i_gate_proj(hidden_states) * hidden_states  # [B, T, d_inner]
+
+        # Compute gates
+        g_gate = F.linear(shift_hidden_states, self.g_gate_proj.weight) + self.g_gate_proj.bias.float()
+        tau_gate = F.linear(shift_hidden_states, self.tau_gate_proj.weight) + self.tau_gate_proj.bias.float()
+
+        # Process gates
+        g_gate = F.softplus(g_gate)  # [B, T, mem_size]
+        tau_gate = torch.sigmoid(tau_gate)  # [B, T, mem_size]
+
+        # Input gate for keys: it_gate = g_gate^tau_gate
+        it_gate = g_gate ** tau_gate  # [B, T, mem_size]
+
+        # Forget gate (for state decay): rt_gate_log = -g_gate * tau_gate
+        rt_gate_log = -g_gate * tau_gate  # [B, T, mem_size]
+
+        # Normalize and scale k by it_gate
+        k = F.normalize(k.float(), dim=-1, eps=1e-12) * it_gate  # [B, T, mem_size]
+
+        # Add a singleton head dimension for the recurrence: [B, T, *] -> [B, 1, T, *]
+        q = q.unsqueeze(1)  # [B, 1, T, mem_size]
+        k = k.unsqueeze(1)  # [B, 1, T, mem_size]
+        v = v.unsqueeze(1)  # [B, 1, T, d_inner]
+        rt_gate_log = rt_gate_log.unsqueeze(1)  # [B, 1, T, mem_size]
+
+        # ============================================
+        # GLA 
(Gated Linear Attention)
+        # ============================================
+        o = self._gla_attention(q, k, v, rt_gate_log)  # [B, 1, T, d_inner]
+
+        # Reshape back: [B, 1, T, d_inner] -> [B, T, d_inner]
+        o = o.squeeze(1)  # [B, T, d_inner]
+
+        # Residual connection with learnable weight
+        if self.residual_in_fp32:
+            residual = shift_hidden_states.float() * self.residual_weight
+        else:
+            residual = shift_hidden_states * self.residual_weight
+        o = (o + residual).to(o.dtype)
+
+        # Gated activation norm
+        o = self._gated_rms_norm(o, final_gate, self.activation_norm_weight)
+
+        # Down projection
+        o = self.down_proj(o)
+
+        return o
+
+    def _gla_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        gk: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Gated Linear Attention (GLA) recurrence.
+
+        The recurrence:
+            h[t] = h[t-1] * exp(gk[t]) + outer(k[t], v[t])
+            o[t] = sum(h[t] * q[t], dim=-2)
+
+        Args:
+            q: [B, 1, T, mem_size] - queries
+            k: [B, 1, T, mem_size] - keys
+            v: [B, 1, T, d_inner] - values
+            gk: [B, 1, T, mem_size] - forget gates (in log space)
+
+        Returns:
+            o: [B, 1, T, d_inner] - output
+        """
+        B, H, T, K = q.shape
+        V = v.shape[-1]
+
+        # Work in float32 for stability
+        q = q.float()
+        k = k.float()
+        v = v.float()
+        gk = gk.float()
+
+        scale = K ** -0.5
+
+        # Initialize state: [B, H, K, V]
+        h = torch.zeros(B, H, K, V, device=q.device, dtype=torch.float32)
+
+        outputs = []
+
+        for t in range(T):
+            q_t = q[:, :, t, :]  # [B, H, K]
+            k_t = k[:, :, t, :]  # [B, H, K]
+            v_t = v[:, :, t, :]  # [B, H, V]
+            gk_t = gk[:, :, t, :]  # [B, H, K]
+
+            # Scale query
+            q_t = q_t * scale
+
+            # Decay state: h = h * exp(gk)
+            # gk is negative (forget gate), so exp(gk) < 1
+            decay = torch.exp(gk_t).unsqueeze(-1)  # [B, H, K, 1]
+            h = h * decay
+
+            # Update state: h = h + outer(k, v)
+            # For each batch and head: h[b,h] += outer(k_t[b,h], v_t[b,h])
+            # h[b,h] is [K, V], k_t[b,h] is [K], v_t[b,h] is [V]
+            kv_outer = torch.einsum('bhk,bhv->bhkv', k_t, v_t)  # [B, H, K, V]
+            h = h + kv_outer
+
+            # Output: o = sum(h * q, dim=-2)
+            # For each batch and head: o[b,h] = sum(h[b,h] * q_t[b,h], dim=0)
+            # h[b,h] is [K, V], q_t[b,h] is [K]
+            o_t = torch.einsum('bhk,bhkv->bhv', q_t, h)  # [B, H, V]
+
+            outputs.append(o_t)
+
+        # Stack outputs: [T, B, H, V]
+        outputs = torch.stack(outputs, dim=0)
+
+        # Transpose to [B, H, T, V]
+        outputs = outputs.transpose(0, 1).transpose(1, 2)
+
+        return outputs
+
+    def _gated_rms_norm(self, x: torch.Tensor, g: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Gated RMSNorm: RMSNorm(x) * sigmoid(g)
+
+        Args:
+            x: [B, T, D]
+            g: [B, T, D] - gate
+            weight: [D] - norm weight
+
+        Returns:
+            output: [B, T, D]
+        """
+        rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + 1e-5)
+        x_norm = (x.float() / rms) * weight
+        return (x_norm * torch.sigmoid(g.float())).to(x.dtype)
+
+
+# Problem dimensions
+batch_size = 4
+seq_len = 512
+hidden_size = 1024
+expand_ratio = 64
+input_gate_low_rank = 16
+use_short_conv = True
+conv_size = 4
+residual_in_fp32 = True
+
+
+def get_inputs():
+    x = torch.randn(batch_size, seq_len, hidden_size)
+    return [x]
+
+
+def get_init_inputs():
+    return [hidden_size, expand_ratio, input_gate_low_rank, use_short_conv, conv_size, residual_in_fp32]
+
diff --git a/KernelBench/level9/rotary_reference.py b/KernelBench/level9/rotary_reference.py
new file mode 100644
index 00000000..427d92df
--- /dev/null
+++ b/KernelBench/level9/rotary_reference.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import math
+
+class 
Model(nn.Module): + """ + Reference implementation for Rotary Positional Embeddings (RoPE). + """ + def __init__(self, dim: int, base: float = 10000.0): + super(Model, self).__init__() + self.dim = dim + self.base = base + # Generate inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + """ + Args: + q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim] + Returns: + tuple[torch.Tensor, torch.Tensor]: Rotated Q and K + """ + B, T, H, D = q.shape + t = torch.arange(T, device=q.device, dtype=torch.float32) + # freqs: [T, D/2] + freqs = torch.outer(t, self.inv_freq.to(t.device)) + + # Standard RoPE: rotate pairs of (0, D/2), (1, D/2+1), ... + # Here we use the GPT-NeoX style: + # x1 = x[..., :D/2], x2 = x[..., D/2:] + # o = [x1*cos - x2*sin, x1*sin + x2*cos] + + cos = freqs.cos().view(1, T, 1, D // 2).to(q.dtype) + sin = freqs.sin().view(1, T, 1, D // 2).to(q.dtype) + + def apply_rotary(x): + x1 = x[..., :D // 2] + x2 = x[..., D // 2:] + return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) + + q_out = apply_rotary(q) + k_out = apply_rotary(k) + + # Kernelbench requires a single output tensor + return torch.cat([q_out, k_out], dim=-1) + +# Kernelbench Parameters +batch_size = 4 +seq_len = 2048 +num_heads = 32 +head_dim = 128 + +def get_inputs(): + q = torch.randn(batch_size, seq_len, num_heads, head_dim) + k = torch.randn(batch_size, seq_len, num_heads, head_dim) + return [q, k] + +def get_init_inputs(): + return [head_dim] diff --git a/KernelBench/level9/rwkv6_reference.py b/KernelBench/level9/rwkv6_reference.py new file mode 100644 index 00000000..3ca48264 --- /dev/null +++ b/KernelBench/level9/rwkv6_reference.py @@ -0,0 +1,237 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): + norm_x = torch.mean(x**2, dim=-1, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return x_normed * self.weight + +class GroupNorm(nn.Module): + def __init__(self, num_groups, num_channels, eps=1e-5, weight=True, bias=True): + super().__init__() + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + if weight: + self.weight = nn.Parameter(torch.ones(num_channels)) + else: + self.register_parameter('weight', None) + if bias: + self.bias = nn.Parameter(torch.zeros(num_channels)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + # x: [B, T, C] + B, T, C = x.shape + x = x.view(B, T, self.num_groups, -1) + mean = x.mean(dim=-1, keepdim=True) + var = x.var(dim=-1, keepdim=True, unbiased=False) + x = (x - mean) * torch.rsqrt(var + self.eps) + x = x.view(B, T, C) + if self.weight is not None: + x = x * self.weight + if self.bias is not None: + x = x + self.bias + return x + +class LoRA(nn.Module): + def __init__(self, input_dim, output_dim, low_rank_dim, bias=True): + super().__init__() + self.lora = nn.Sequential( + nn.Linear(input_dim, low_rank_dim, bias=False), + nn.Tanh(), + nn.Linear(low_rank_dim, output_dim, bias=bias), + ) + # Initialization + nn.init.zeros_(self.lora[0].weight) + nn.init.orthogonal_(self.lora[2].weight, gain=0.1) + if bias: + 
nn.init.zeros_(self.lora[2].bias)
+
+    def forward(self, x):
+        return self.lora(x)
+
+class LerpLinear(nn.Module):
+    def __init__(self, input_dim, output_dim, low_rank_dim=None):
+        super().__init__()
+        if low_rank_dim is None:
+            self.linear = nn.Linear(input_dim, output_dim, bias=False)
+        else:
+            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
+        self.mu = nn.Parameter(torch.zeros(input_dim))
+
+    def forward(self, x, delta):
+        return self.linear(x + delta * self.mu)
+
+class DDLerpLinear(nn.Module):
+    def __init__(self, input_dim, output_dim, low_rank_dim=None):
+        super().__init__()
+        if low_rank_dim is None:
+            self.linear = nn.Linear(input_dim, output_dim, bias=False)
+        else:
+            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
+
+    def forward(self, x, mu, delta):
+        return self.linear(x + delta * mu)
+
+class Model(nn.Module):
+    """
+    Reference implementation of RWKV-6 Attention.
+    """
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        expand_k: float = 0.5,
+        expand_v: float = 1.0,
+        num_heads: int = 4,
+        proj_low_rank_dim: int = 32,
+        gate_low_rank_dim: int = 64,
+    ):
+        super(Model, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.key_dim = int(hidden_size * expand_k)
+        self.value_dim = int(hidden_size * expand_v)
+        self.head_k_dim = self.key_dim // num_heads
+        self.head_v_dim = self.value_dim // num_heads
+        self.proj_low_rank_dim = proj_low_rank_dim
+
+        # Time-mix projections
+        self.x_proj_lora = LoRA(hidden_size, proj_low_rank_dim * 5, proj_low_rank_dim)
+        self.x_proj_mu = nn.Parameter(torch.zeros(hidden_size))
+        self.x_proj_out = nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
+        self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
+
+        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
+        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
+        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
+        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
+        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
+
+        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))
+
+        self.g_norm = GroupNorm(num_heads, self.value_dim)
+        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (torch.Tensor): [B, T, H]
+        Returns:
+            torch.Tensor: [B, T, H]
+        """
+        B, T, H = hidden_states.shape
+
+        # 1. Time Shift
+        shifted = F.pad(hidden_states, (0, 0, 1, -1))
+        # delta is (x_{t-1} - x_t) in FLA's RWKV6 implementation
+        delta = shifted - hidden_states
+
+        # 2. Extract the r, w, k, v, g time-mix parameters
+        x_lerp = hidden_states + delta * self.x_proj_mu
+        # Modeled on FLA's RWKV6 time-mix: project to 5 low-rank chunks, apply tanh,
+        # then map each chunk back to hidden size with a per-chunk slice of the
+        # x_proj_out weight. For reference, the FLA code is:
+
+        #   x = self.x_proj[0](hidden_states, delta).view(B, T, -1, self.proj_low_rank_dim)
+        #   x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
+        #   r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
+
+        # Reproduce that computation with this module's parameters.
+        # projected: [B, T, 5, R]
+        x_0 = self.x_proj_lora(x_lerp).view(B, T, 5, self.proj_low_rank_dim)
+        x_1 = torch.tanh(x_0)
+        # x_proj_out weight: [H, 5 * R] -> [H, 5, R]
+        # FLA uses x_proj[2].weight, which is [hidden_size, 5 * proj_low_rank_dim];
+        # here x_proj_out is nn.Linear(5 * R, H), so its weight has the same shape.
+        x_2 = torch.einsum('b t n r, h n r -> b t n h', x_1, self.x_proj_out.weight.view(H, 5, -1))
+        # Add the per-parameter bias and unbind the 5 parameters
+        r_mu, w_mu, k_mu, v_mu, g_mu = (x_2 + self.x_bias).unbind(-2)
+
+        r = self.r_proj(hidden_states, r_mu, delta)
+        w = self.w_proj(hidden_states, w_mu, delta)
+        k = self.k_proj(hidden_states, k_mu, delta)
+        v = self.v_proj(hidden_states, v_mu, delta)
+        g = self.g_proj(hidden_states, g_mu, delta)
+
+        # 3. Recurrence (RWKV-6)
+        # r, w, k, v are [B, T, D]
+        # Reshape to heads
+        NH = self.num_heads
+        DK = self.head_k_dim
+        DV = self.head_v_dim
+
+        r = r.view(B, T, NH, DK)
+        k = k.view(B, T, NH, DK)
+        v = v.view(B, T, NH, DV)
+        w = -torch.exp(w.view(B, T, NH, DK))  # Time-decay is negative exp
+        u = self.bonus  # [NH, DK]
+
+        # Functional recurrence
+        # S_t = S_{t-1} * exp(w_t) + k_t @ v_t^T
+        # o_t = r_t @ (S_{t-1} + u * k_t @ v_t^T)
+
+        S = torch.zeros(B, NH, DK, DV, device=hidden_states.device, dtype=torch.float32)
+        o = torch.zeros(B, T, NH, DV, device=hidden_states.device, dtype=torch.float32)
+
+        r, k, v, w = r.float(), k.float(), v.float(), w.float()
+
+        for t in range(T):
+            r_t = r[:, t]  # [B, NH, DK]
+            k_t = k[:, t]  # [B, NH, DK]
+            v_t = v[:, t]  # [B, NH, DV]
+            w_t = w[:, t].exp()  # [B, NH, DK]
+
+            # kv = k_t^T @ v_t
+            kv = torch.einsum('b h k, b h v -> b h k v', k_t, v_t)
+
+            # Output: o_t = r_t @ (S + u * kv), with u the per-head bonus [NH, DK];
+            # per batch and head, S is [DK, DV], r_t is [DK], and o_t is [DV].
+            rS = torch.einsum('b h k, b h k v -> b h v', r_t, S)
+            # r @ (u * kv) -> (r_t * u * k_t) @ v_t
+            ruk = r_t * u.view(1, NH, DK) * k_t
+            rukv = torch.einsum('b h k, b h v -> b h v', ruk, v_t)
+
+            o[:, t] = rS + rukv
+
+            # Update state: S = S * exp(w_t) + kv
+            S = S * w_t.unsqueeze(-1) + kv
+
+        # 4. 
Final output + o = o.view(B, T, -1) + o = self.g_norm(o) * F.silu(g) + return self.o_proj(o) + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 512 +num_heads = 4 +expand_k = 0.5 +expand_v = 1.0 + +def get_inputs(): + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + return [hidden_states] + +def get_init_inputs(): + return [hidden_size, expand_k, expand_v, num_heads] diff --git a/KernelBench/level9/rwkv7_reference.py b/KernelBench/level9/rwkv7_reference.py new file mode 100644 index 00000000..fefef827 --- /dev/null +++ b/KernelBench/level9/rwkv7_reference.py @@ -0,0 +1,264 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class LoRA(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: int, + bias: bool = True, + activation: str = 'tanh', + ): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + self.bias = bias + + if activation is None: + self.activation = nn.Identity() + elif activation == 'sigmoid': + self.activation = nn.Sigmoid() + elif activation == 'tanh': + self.activation = nn.Tanh() + else: + raise ValueError(f"Not supported activation `{activation}`.") + + self.lora = nn.Sequential( + nn.Linear(input_dim, low_rank_dim, bias=False), + self.activation, + nn.Linear(low_rank_dim, output_dim, bias=bias), + ) + self._initialize_weights() + + def _initialize_weights(self): + nn.init.zeros_(self.lora[0].weight) + shape = self.lora[2].weight.shape + weight_fp32 = torch.zeros(shape) + gain = math.sqrt(shape[1] / shape[0]) if shape[1] > shape[0] else 1 + nn.init.orthogonal_(weight_fp32, gain=gain * 0.1) + self.lora[2].weight.data.copy_(weight_fp32.to(self.lora[2].weight.dtype)) + if self.lora[2].bias is not None: + nn.init.zeros_(self.lora[2].bias) + + def set_bias_value(self, value): + if self.bias and self.lora[2].bias is not None: + if isinstance(value, torch.Tensor): + self.lora[2].bias.data.copy_(value.to(self.lora[2].bias.dtype)) + else: + nn.init.constant_(self.lora[2].bias, value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lora(x) + +class Model(nn.Module): + """ + Reference implementation of RWKV7 Linear Attention. 
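+
+    Per-head state recurrence implemented in the forward loop (kk = normalized key,
+    a = sigmoid in-context gate):
+        S_t = (I + (kk_t * a_t)(-kk_t)^T) diag(exp(w_t)) S_{t-1} + k_t v_t^T
+        o_t = r_t S_t, followed by GroupNorm, an (r * k * r_k)-weighted bonus on v,
+        gating by g, and the output projection.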
+ """ + def __init__( + self, + hidden_size: int = 1024, + head_dim: int = 64, + layer_idx: int = 0, + num_hidden_layers: int = 24, + ): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.head_dim = head_dim + self.num_heads = hidden_size // head_dim + self.layer_idx = layer_idx + self.num_hidden_layers = num_hidden_layers + + # LoRA dimensions as per RWKV7 implementation + factor = head_dim / 64 + decay_low_rank_dim = max(32, int(round((2.5 * (hidden_size**0.5)) * factor / 32) * 32)) + gate_low_rank_dim = max(32, int(round((5 * (hidden_size**0.5)) / 32) * 32)) + a_low_rank_dim = max(32, int(round((2.5 * (hidden_size**0.5)) * factor / 32) * 32)) + v_low_rank_dim = max(32, int(round((1.7 * (hidden_size**0.5)) * factor / 32) * 32)) + + self.x_r = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_w = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_k = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_v = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_a = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_g = nn.Parameter(torch.zeros(1, 1, hidden_size)) + + self.k_k = nn.Parameter(torch.zeros(hidden_size)) + self.k_a = nn.Parameter(torch.zeros(hidden_size)) + self.r_k = nn.Parameter(torch.zeros(self.num_heads, head_dim)) + + self.r_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + self.w_lora = LoRA(hidden_size, hidden_size, low_rank_dim=decay_low_rank_dim, activation='tanh') + self.a_lora = LoRA(hidden_size, hidden_size, low_rank_dim=a_low_rank_dim, activation=None) + self.g_lora = LoRA(hidden_size, hidden_size, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False) + + if layer_idx != 0: + self.v_lora = LoRA(hidden_size, hidden_size, low_rank_dim=v_low_rank_dim, activation=None) + + self.g_norm = nn.GroupNorm( + num_groups=self.num_heads, + num_channels=hidden_size, + eps=head_dim * 1e-5, + affine=True, + ) + + self._initialize_weights() + + def _initialize_weights(self): + ratio_0_to_1 = self.layer_idx / (self.num_hidden_layers - 1) + ratio_1_to_almost0 = 1.0 - (self.layer_idx / self.num_hidden_layers) + + ddd = torch.ones(1, 1, self.hidden_size) + www = torch.zeros(self.hidden_size) + zigzag = torch.zeros(self.hidden_size) + linear = torch.zeros(self.hidden_size) + for n in range(self.hidden_size): + linear[n] = n / (self.hidden_size-1) - 0.5 + zigzag[n] = ((n % self.head_dim) - ((self.head_dim-1) / 2)) / ((self.head_dim-1) / 2) + zigzag[n] = zigzag[n] * abs(zigzag[n]) + www[n] = -6 + 6 * (n / (self.hidden_size - 1)) ** (1 + 1 * ratio_0_to_1 ** 0.3) + ddd[0, 0, n] = n / self.hidden_size + + self.x_r.data = (1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) + self.x_w.data = (1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) + self.x_k.data = (1.0 - torch.pow(ddd, 0.7 * ratio_1_to_almost0)) + self.x_v.data = (1.0 - torch.pow(ddd, 0.7 * ratio_1_to_almost0)) + self.x_a.data = (1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) + self.x_g.data = (1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) + + nn.init.constant_(self.k_a, 1.02) + nn.init.constant_(self.r_k, -0.04) + self.k_k.data.copy_(0.71 - linear*0.1) + self.w_lora.set_bias_value(www + 0.5 + zigzag*2.5) + self.a_lora.set_bias_value(-0.19 + zigzag*0.3 + linear*0.4) + + if self.layer_idx != 0: + self.v_lora.set_bias_value(0.73 - linear*0.4) + + self.g_norm.weight.data[:] = ((self.layer_idx + 1) 
/ self.num_hidden_layers) ** 0.7 + nn.init.orthogonal_(self.r_proj.weight) + nn.init.orthogonal_(self.k_proj.weight, gain=0.1) + nn.init.orthogonal_(self.v_proj.weight) + self.o_proj.weight.data.zero_() + + def forward(self, hidden_states: torch.Tensor, v_first: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (torch.Tensor): Input states of shape (B, T, H*D). + v_first (torch.Tensor): Value from the first layer of shape (B, T, H*D). + Returns: + torch.Tensor: Output states of shape (B, T, H*D). + """ + B, T, C = hidden_states.shape + H = self.num_heads + D = self.head_dim + + # Token shift (time_shift - current) + shifted = F.pad(hidden_states, (0, 0, 1, -1)) + delta = shifted - hidden_states + + # Fused addcmul equivalent + xr = hidden_states + delta * self.x_r + xw = hidden_states + delta * self.x_w + xk = hidden_states + delta * self.x_k + xv = hidden_states + delta * self.x_v + xa = hidden_states + delta * self.x_a + xg = hidden_states + delta * self.x_g + + r = self.r_proj(xr) + w = -0.6065306597126334 * self.w_lora(xw).sigmoid() + k = self.k_proj(xk) + v = self.v_proj(xv) + + if self.layer_idx == 0: + v_first = v + else: + v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid()) + + a = self.a_lora(xa).sigmoid() + g = self.g_lora(xg) + + # K update: k = k * (1 + (a - 1) * k_a) + k = k + k * (a - 1) * self.k_a + + # kk needs the k BEFORE the fused_k_rwkv7 update + k_for_kk = self.k_proj(xk) + # kk = F.normalize(rearrange(k_for_kk * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0) + kk = F.normalize((k_for_kk * self.k_k).view(B, T, H, D), dim=-1, p=2.0) + + # Reshape for linear attention + r = r.view(B, T, H, D) + w = w.view(B, T, H, D) + k = k.view(B, T, H, D) + v = v.view(B, T, H, D) + a_lora = a.view(B, T, H, D) + + # Attention core (Recurrence) + # H_t = diag(exp(w_t)) H_{t-1} + (kk_t * a_t) @ (-kk_t^T @ H_{t-1}) + k_t @ v_t^T + a_recur = -kk + b_recur = kk * a_lora + + out = torch.zeros_like(v) + state = torch.zeros(B, H, D, D, device=hidden_states.device, dtype=hidden_states.dtype) + + for t in range(T): + r_t = r[:, t] # (B, H, D) + w_t = w[:, t] # (B, H, D) + k_t = k[:, t] # (B, H, D) + v_t = v[:, t] # (B, H, D) + a_t = a_recur[:, t] # (B, H, D) + b_t = b_recur[:, t] # (B, H, D) + + # state: (B, H, D, D) + # exp(w_t): (B, H, D) + state = state * torch.exp(w_t).unsqueeze(-1) + + # term_a = a_t^T @ state + # a_t: (B, H, D), state: (B, H, D, D) -> (B, H, 1, D) @ (B, H, D, D) -> (B, H, 1, D) + term_a = torch.matmul(a_t.unsqueeze(-2), state) + + # state = state + b_t @ term_a + # b_t: (B, H, D), term_a: (B, H, 1, D) -> (B, H, D, 1) @ (B, H, 1, D) -> (B, H, D, D) + state = state + torch.matmul(b_t.unsqueeze(-1), term_a) + + # state = state + k_t @ v_t^T + # k_t: (B, H, D), v_t: (B, H, D) -> (B, H, D, 1) @ (B, H, 1, D) -> (B, H, D, D) + state = state + torch.matmul(k_t.unsqueeze(-1), v_t.unsqueeze(-2)) + + # out_t = r_t @ state -> (B, H, 1, D) @ (B, H, D, D) -> (B, H, 1, D) + out_t = torch.matmul(r_t.unsqueeze(-2), state) + out[:, t] = out_t.squeeze(-2) + + o = out.reshape(B, T, -1) + + # Norm and output correction + o = self.g_norm(o.transpose(1, 2)).transpose(1, 2) + + r_k_b = self.r_k.view(1, 1, H, D) + correction_term = (r * k * r_k_b).sum(-1, keepdim=True) * v + o = (o + correction_term.reshape(B, T, -1)) * g + + o = self.o_proj(o) + return o + +# Benchmarking parameters +B = 8 +T = 64 # Small sequence for reference implementation (O(T) loop) +C = 1024 +H = 16 +D = 64 + +def get_inputs(): + hidden_states = torch.randn(B, T, C) + v_first = 
torch.randn(B, T, C)
+    return [hidden_states, v_first]
+
+def get_init_inputs():
+    return [C, D, 0, 24]
diff --git a/KernelBench/level9/samba_reference.py b/KernelBench/level9/samba_reference.py
new file mode 100644
index 00000000..282980f1
--- /dev/null
+++ b/KernelBench/level9/samba_reference.py
@@ -0,0 +1,263 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, x):
+        norm_x = torch.mean(x**2, dim=-1, keepdim=True)
+        x_normed = x * torch.rsqrt(norm_x + self.eps)
+        return x_normed * self.weight
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, base=10000.0):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, q, k):
+        # q, k: [B, T, H, D]
+        seq_len = q.shape[1]
+        t = torch.arange(seq_len, device=q.device, dtype=torch.float32)
+        freqs = torch.outer(t, self.inv_freq)
+        # freqs: [T, D/2]
+        cos = freqs.cos().to(q.dtype)
+        sin = freqs.sin().to(q.dtype)
+
+        def apply_rotary(x, cos, sin):
+            # x: [B, T, H, D]
+            # cos, sin: [T, D/2]
+            d = x.shape[-1]
+            x1 = x[..., :d//2]
+            x2 = x[..., d//2:]
+            # cos[T, D/2] -> [1, T, 1, D/2]
+            cos = cos.unsqueeze(0).unsqueeze(2)
+            sin = sin.unsqueeze(0).unsqueeze(2)
+            o1 = x1 * cos - x2 * sin
+            o2 = x1 * sin + x2 * cos
+            return torch.cat([o1, o2], dim=-1)
+
+        return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)
+
+class Attention(nn.Module):
+    def __init__(self, hidden_size, num_heads, num_kv_heads=None, qkv_bias=False, window_size=None, rope_theta=10000.0):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+        self.head_dim = hidden_size // num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.qkv_bias = qkv_bias
+        self.window_size = window_size
+
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
+        self.k_proj = nn.Linear(hidden_size, self.kv_dim, bias=qkv_bias)
+        self.v_proj = nn.Linear(hidden_size, self.kv_dim, bias=qkv_bias)
+        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
+
+        self.rotary = RotaryEmbedding(self.head_dim, base=rope_theta)
+
+    def forward(self, x):
+        B, T, C = x.shape
+        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
+        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
+
+        q, k = self.rotary(q, k)
+
+        # Transpose for scaled_dot_product_attention: [B, H, T, D]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Scaling is handled by scaled_dot_product_attention if scale=None
+        # We need a causal mask or a sliding-window mask
+        # Samba uses sliding window attention if window_size is set
+        attn_mask = None
+        if self.window_size is not None:
+            # Create custom mask for sliding window
+            mask = torch.tril(torch.ones(T, T, device=x.device), diagonal=0)
+            if self.window_size < T:
+                mask = torch.triu(mask, diagonal=-(self.window_size-1))
+            attn_mask = mask.bool()  # True = may attend (SDPA boolean-mask convention)
+
+        # SDPA
+        o = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask=attn_mask,
+            dropout_p=0.0,
+            is_causal=(self.window_size is None)
+        )
+
+        o = o.transpose(1, 2).contiguous().view(B, T, -1)
+        return self.o_proj(o)
+
+class MambaLayer(nn.Module):
+    def 
__init__(self, hidden_size, state_size=16, conv_kernel=4, intermediate_size=None, time_step_rank=None, use_bias=False): + super().__init__() + self.hidden_size = hidden_size + self.ssm_state_size = state_size + self.conv_kernel_size = conv_kernel + self.intermediate_size = intermediate_size if intermediate_size is not None else 2 * hidden_size + self.time_step_rank = time_step_rank if time_step_rank is not None else math.ceil(hidden_size / 16) + + self.in_proj = nn.Linear(hidden_size, self.intermediate_size * 2, bias=use_bias) + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=True, + kernel_size=conv_kernel, + groups=self.intermediate_size, + padding=conv_kernel - 1, + ) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32).repeat(self.intermediate_size, 1) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, hidden_size, bias=use_bias) + + def forward(self, x): + B, T, _ = x.shape + projected = self.in_proj(x).transpose(1, 2) # [B, 2*intermediate, T] + x_inner, gate = projected.chunk(2, dim=1) + + # Conv + x_inner = self.conv1d(x_inner)[..., :T] + x_inner = F.silu(x_inner) + + # SSM parameters + ssm_params = self.x_proj(x_inner.transpose(1, 2)) # [B, T, rank + 2*state] + dt, B_ssm, C_ssm = torch.split(ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) + + dt = F.softplus(self.dt_proj(dt)).transpose(1, 2) # [B, intermediate, T] + A = -torch.exp(self.A_log.float()) # [intermediate, state] + + # Scan (reference implementation) + # s_t = dA * s_{t-1} + dB * x_t + # dA = exp(dt * A) + # dB = dt * B + # y_t = C * s_t + + results = [] + ssm_state = torch.zeros(B, self.intermediate_size, self.ssm_state_size, device=x.device, dtype=x.dtype) + + for i in range(T): + dt_i = dt[:, :, i].unsqueeze(-1) # [B, intermediate, 1] + B_i = B_ssm[:, i, :].unsqueeze(1) # [B, 1, state] + C_i = C_ssm[:, i, :].unsqueeze(-1) # [B, state, 1] + x_i = x_inner[:, :, i].unsqueeze(-1) # [B, intermediate, 1] + + dA = torch.exp(A.unsqueeze(0) * dt_i) # [B, intermediate, state] + dB = dt_i * B_i # [B, intermediate, state] + + ssm_state = dA * ssm_state + dB * x_i + y_i = torch.matmul(ssm_state, C_i).squeeze(-1) # [B, intermediate] + results.append(y_i) + + y = torch.stack(results, dim=2) # [B, intermediate, T] + y = y + x_inner * self.D.view(1, -1, 1) + y = y * F.silu(gate) + + return self.out_proj(y.transpose(1, 2)) + +class GatedMLP(nn.Module): + def __init__(self, hidden_size, hidden_ratio=4): + super().__init__() + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + +class SambaBlock(nn.Module): + def __init__(self, hidden_size, layer_idx, config_attn): + super().__init__() + self.mixer_norm = RMSNorm(hidden_size) + if config_attn is not None and layer_idx in config_attn['layers']: + self.mixer = Attention( + hidden_size=hidden_size, + 
num_heads=config_attn['num_heads'], + num_kv_heads=config_attn['num_kv_heads'], + qkv_bias=config_attn['qkv_bias'], + window_size=config_attn['window_size'], + rope_theta=config_attn['rope_theta'] + ) + else: + self.mixer = MambaLayer(hidden_size=hidden_size) + + self.mlp_norm = RMSNorm(hidden_size) + self.mlp = GatedMLP(hidden_size=hidden_size) + + def forward(self, x): + residual = x + x = self.mixer_norm(x) + x = self.mixer(x) + x = x + residual + + residual = x + x = self.mlp_norm(x) + x = self.mlp(x) + x = x + residual + return x + +class Model(nn.Module): + """ + Samba Model Backbone. + Hybrid architecture alternating between Mamba and sliding window Attention. + """ + def __init__( + self, + hidden_size: int = 512, + num_hidden_layers: int = 4, + attn_layers: tuple = (1, 3), + num_heads: int = 8, + num_kv_heads: int = 8, + window_size: int = 256, + ): + super().__init__() + config_attn = { + 'layers': attn_layers, + 'num_heads': num_heads, + 'num_kv_heads': num_kv_heads, + 'qkv_bias': False, + 'window_size': window_size, + 'rope_theta': 10000.0 + } + self.layers = nn.ModuleList([ + SambaBlock(hidden_size, i, config_attn) for i in range(num_hidden_layers) + ]) + self.norm_f = RMSNorm(hidden_size) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + x = self.norm_f(x) + return x + +# Kernelbench Parameters +batch_size = 2 +seq_len = 128 +hidden_size = 512 +num_hidden_layers = 2 +attn_layers = (1,) +num_heads = 8 +num_kv_heads = 8 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, num_hidden_layers, attn_layers, num_heads, num_kv_heads] diff --git a/KernelBench/level9/short_conv_reference.py b/KernelBench/level9/short_conv_reference.py new file mode 100644 index 00000000..f74551bb --- /dev/null +++ b/KernelBench/level9/short_conv_reference.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Short Causal 1D Convolution. + """ + def __init__(self, hidden_size: int, kernel_size: int = 4, bias: bool = False): + super(Model, self).__init__() + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.conv = nn.Conv1d( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + padding=kernel_size - 1, + bias=bias + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + B, T, D = x.shape + # Conv1d expects [B, C, T] + x = x.transpose(1, 2) + # Apply convolution and trim the end to maintain causality + x = self.conv(x)[..., :T] + # Transpose back + return x.transpose(1, 2) + +# Kernelbench Parameters +batch_size = 16 +seq_len = 512 +hidden_size = 1024 +kernel_size = 4 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size, kernel_size] diff --git a/KernelBench/level9/token_shift_reference.py b/KernelBench/level9/token_shift_reference.py new file mode 100644 index 00000000..94720827 --- /dev/null +++ b/KernelBench/level9/token_shift_reference.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + """ + Reference implementation for Token Shift (as used in RWKV). 
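+
+    Output: y_t = (1 - sigmoid(mu)) * x_t + sigmoid(mu) * x_{t-1}, with x_{-1} = 0.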
+ """ + def __init__(self, hidden_size: int): + super(Model, self).__init__() + # Learnable mixing coefficient + self.mu = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] + Returns: + torch.Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] + """ + # x_shifted is x[t-1] + x_shifted = F.pad(x, (0, 0, 1, -1)) + # Linear interpolation between current and previous token + mu = torch.sigmoid(self.mu) # Often used with sigmoid to keep in [0, 1] + return x * (1 - mu) + x_shifted * mu + +# Kernelbench Parameters +batch_size = 32 +seq_len = 512 +hidden_size = 4096 + +def get_inputs(): + x = torch.randn(batch_size, seq_len, hidden_size) + return [x] + +def get_init_inputs(): + return [hidden_size] diff --git a/KernelBench/level9/trimul_reference.py b/KernelBench/level9/trimul_reference.py new file mode 100644 index 00000000..6b6487dc --- /dev/null +++ b/KernelBench/level9/trimul_reference.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from torch import einsum + + +class Model(nn.Module): + """ + Triangle Multiplicative Module (TriMul) - commonly used in protein structure prediction. + Performs gated projections followed by an einsum contraction over a shared dimension. + + Based on: https://github.com/lucidrains/triangle-multiplicative-module + """ + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + + self.norm = nn.LayerNorm(dim) + + self.left_proj = nn.Linear(dim, hidden_dim, bias=False) + self.right_proj = nn.Linear(dim, hidden_dim, bias=False) + + self.left_gate = nn.Linear(dim, hidden_dim, bias=False) + self.right_gate = nn.Linear(dim, hidden_dim, bias=False) + self.out_gate = nn.Linear(dim, hidden_dim, bias=False) + + self.to_out_norm = nn.LayerNorm(hidden_dim) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Performs the triangle multiplicative update. + + Args: + x: Input tensor of shape (batch_size, seq_len, seq_len, dim). + mask: Mask tensor of shape (batch_size, seq_len, seq_len). + + Returns: + Output tensor of shape (batch_size, seq_len, seq_len, dim). + """ + x = self.norm(x) + + left = self.left_proj(x) + right = self.right_proj(x) + + mask = mask.unsqueeze(-1) + left = left * mask + right = right * mask + + left_gate = self.left_gate(x).sigmoid() + right_gate = self.right_gate(x).sigmoid() + out_gate = self.out_gate(x).sigmoid() + + left = left * left_gate + right = right * right_gate + + # Core computation: contract over the k dimension + # out[b, i, j, d] = sum_k left[b, i, k, d] * right[b, j, k, d] + out = einsum('... i k d, ... j k d -> ... i j d', left, right) + + out = self.to_out_norm(out) + out = out * out_gate + return self.to_out(out) + + +# Problem dimensions +batch_size = 4 +seq_len = 64 +dim = 128 +hidden_dim = 64 + + +def get_inputs(): + x = torch.randn(batch_size, seq_len, seq_len, dim) + mask = torch.randint(0, 2, (batch_size, seq_len, seq_len), dtype=torch.float32) + return [x, mask] + + +def get_init_inputs(): + return [dim, hidden_dim] +
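+
+# Illustrative smoke test, guarded so it never runs under the benchmark harness. It is
+# an assumption about usage (not part of KernelBench itself): construct Model from
+# get_init_inputs() and feed it the tensors returned by get_inputs().
+if __name__ == "__main__":
+    model = Model(*get_init_inputs())
+    x, mask = get_inputs()
+    out = model(x, mask)
+    print(out.shape)  # expected: (batch_size, seq_len, seq_len, dim)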