From bb8ec90f163d122848e115f46affdea17b9d4cb2 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Tue, 6 Jan 2026 23:00:23 -0800 Subject: [PATCH 01/10] use Pareto distribution for Level 1 Problem 96 With inputs sampled from Unif(0,1), the Huber Loss is effectively MSE, which we know can be hacked via statistical properties of the loss fn/inputs. We use the Pareto distribution to sample inputs w/finite mean and infinite variance to prevent hacking this way. --- KernelBench/level1/96_HuberLoss.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/KernelBench/level1/96_HuberLoss.py b/KernelBench/level1/96_HuberLoss.py index 5e60d5df..4e4a1488 100644 --- a/KernelBench/level1/96_HuberLoss.py +++ b/KernelBench/level1/96_HuberLoss.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from torch.distributions import Pareto + class Model(nn.Module): """ A model that computes Smooth L1 (Huber) Loss for regression tasks. @@ -20,7 +22,9 @@ def forward(self, predictions, targets): def get_inputs(): scale = torch.rand(()) - return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)] + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + return [predictions*scale, targets] def get_init_inputs(): return [] From a3c7f8ae85cb04cbff9d4a19ae351be9930cb058 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Tue, 6 Jan 2026 23:53:11 -0800 Subject: [PATCH 02/10] use Pareto Distribution for Level 1 Problem 100 With inputs sampled from Unif(0,1), we can directly compute the expected value of the output using the mean of the targets. We use the Pareto distribution to sample inputs w/finite mean and infinite variance to prevent hacking this way. --- KernelBench/level1/100_HingeLoss.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/KernelBench/level1/100_HingeLoss.py b/KernelBench/level1/100_HingeLoss.py index 0b733a05..fae79ab7 100644 --- a/KernelBench/level1/100_HingeLoss.py +++ b/KernelBench/level1/100_HingeLoss.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from torch.distributions import Pareto + class Model(nn.Module): """ A model that computes Hinge Loss for binary classification tasks. @@ -19,7 +21,9 @@ def forward(self, predictions, targets): dim = 1 def get_inputs(): - return [torch.rand(batch_size, *input_shape), torch.randint(0, 2, (batch_size,)).float() * 2 - 1] + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1 + return [predictions, targets] def get_init_inputs(): return [] \ No newline at end of file From 50a3ced3efb086964b6e30736ebf883d6d6c4838 Mon Sep 17 00:00:00 2001 From: Simon Guo Date: Wed, 7 Jan 2026 19:50:37 +0000 Subject: [PATCH 03/10] modify scirpt to test at diff percision for specific problem verification --- scripts/verify_bench.py | 82 ++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/scripts/verify_bench.py b/scripts/verify_bench.py index 369a25e2..c9682d6b 100644 --- a/scripts/verify_bench.py +++ b/scripts/verify_bench.py @@ -3,17 +3,20 @@ and random initialization. It compares the output of the original model against itself. It ensures that the test is well-formed and there are no sources of non-determinism in the test. -Usage: python test_bench.py +Usage: + python verify_bench.py # Run all levels + python verify_bench.py level=1 # Run only level 1 + python verify_bench.py problem_ids=[96,100] # Run only problem IDs 96 and 100 """ -import importlib -import torch -import torch.nn as nn -import torch.nn.functional as F +import importlib.util +import os import random + import numpy as np -import os -import importlib.util +import pydra +from pydra import Config +import torch """ Test all the reference architectures compiles @@ -37,13 +40,17 @@ def set_seed(seed): def check_correctness( - Model, NewModel, get_inputs, get_init_inputs, seed=1012, atol=1e-02, rtol=1e-02 + Model, NewModel, get_inputs, get_init_inputs, seed=42, atol=None, rtol=None, precision=None ): + if atol is None: + atol = get_tolerance_for_precision(precision) + if rtol is None: + rtol = get_tolerance_for_precision(precision) # run the model and check correctness with torch.no_grad(): set_seed(seed) inputs = get_inputs() - inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs] + inputs = [x.cuda().to(precision) if isinstance(x, torch.Tensor) else x for x in inputs] set_seed(seed) init_inputs = get_init_inputs() @@ -52,10 +59,10 @@ def check_correctness( ] set_seed(seed) - model = Model(*init_inputs).cuda() + model = Model(*init_inputs).cuda().to(precision) set_seed(seed) - model_new = NewModel(*init_inputs).cuda() + model_new = NewModel(*init_inputs).cuda().to(precision) output = model(*inputs) output_new = model_new(*inputs) @@ -67,22 +74,46 @@ def check_correctness( return True -def run(Model, NewModel, get_inputs, get_init_inputs, seed=1012): - return check_correctness(Model, NewModel, get_inputs, get_init_inputs, seed) +def run(Model, NewModel, get_inputs, get_init_inputs, seed=1012, precision=None): + return check_correctness(Model, NewModel, get_inputs, get_init_inputs, seed, precision=precision) from kernelbench.dataset import construct_kernelbench_dataset +from kernelbench.eval import get_torch_dtype_from_string, get_tolerance_for_precision + + +class ScriptConfig(Config): + def __init__(self): + # Level(s) to run - can be single int or list + self.level = [1, 2, 3] + # Filter by problem IDs (e.g., [96, 100]) + self.problem_ids = [] + # Dataset source + self.source = "local" + # Precision: "fp32", "fp16", "bf16" + self.precision = "fp32" -def run_all(level): - print(f"Running Level {level}") - dataset = construct_kernelbench_dataset(level) + +def run_all(level: int, problem_ids: list, source: str, precision: torch.dtype): + """ + Run all problems in the given level. + """ + + print(f"Running Level {level} of length {len(problem_ids)} problems from {source} with precision {precision}") + + # Use problem_ids filtering at dataset level if specified + if problem_ids: + dataset = construct_kernelbench_dataset(level, source=source, problem_ids=problem_ids) + else: + dataset = construct_kernelbench_dataset(level, source=source) + total = 0 passed = 0 fail_tests = [] for problem in dataset: - total += 1 module_name = problem.name.replace(".py", "") + total += 1 try: problem_path = getattr(problem, "path", None) if not problem_path: @@ -100,8 +131,9 @@ def run_all(level): Model = getattr(module, "Model") get_inputs = getattr(module, "get_inputs") get_init_inputs = getattr(module, "get_init_inputs") - assert run(Model, Model, get_inputs, get_init_inputs) + assert run(Model, Model, get_inputs, get_init_inputs, precision=precision) passed += 1 + print(f"Passed {module_name}") except Exception as e: print(f"Failed {module_name}: {e}") fail_tests.append(module_name) @@ -110,7 +142,15 @@ def run_all(level): print(f"Failed tests: {fail_tests}") +@pydra.main(base=ScriptConfig) +def main(config: ScriptConfig): + levels = config.level if isinstance(config.level, list) else [config.level] + problem_ids = config.problem_ids if config.problem_ids else [] + precision = get_torch_dtype_from_string(config.precision) + + for level in levels: + run_all(level, problem_ids, config.source, precision) + + if __name__ == "__main__": - run_all(1) - run_all(2) - run_all(3) + main() From c6984cdc9be2c647eb8f3af87a2cee98eff77af3 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Wed, 7 Jan 2026 21:40:38 -0800 Subject: [PATCH 04/10] tests for 94, 96, 100 adds some tests for problems 94, 96, and 100. for each problem we test a hacked solution and a correct solution. we run these on the old inputs (sampled from uniform) and the new inputs (sampled from gaussian with mean/std sampled from uniform). also updates verify_bench.py to check for overflow after casting inputs. verified that the new input sampling approach passes correct kernel implementations, fails hacked ones, and does not overflow precision. --- scripts/verify_bench.py | 4 + .../tests/problems/100_HingeLoss_NEW.py | 30 +++++++ .../tests/problems/100_HingeLoss_OLD.py | 25 ++++++ .../tests/problems/94_MSELoss_NEW.py | 32 +++++++ .../tests/problems/94_MSELoss_OLD.py | 26 ++++++ .../tests/problems/96_HuberLoss_NEW.py | 32 +++++++ .../tests/problems/96_HuberLoss_OLD.py | 26 ++++++ .../tests/solutions/100_HingeLoss_CORRECT.py | 84 +++++++++++++++++++ .../tests/solutions/100_HingeLoss_HACK.py | 9 ++ .../tests/solutions/94_MSELoss_CORRECT.py | 75 +++++++++++++++++ .../tests/solutions/94_MSELoss_HACK.py | 11 +++ .../tests/solutions/96_HuberLoss_CORRECT.py | 53 ++++++++++++ .../tests/solutions/96_HuberLoss_HACK.py | 11 +++ src/kernelbench/tests/test_level1_hacks.py | 74 ++++++++++++++++ 14 files changed, 492 insertions(+) create mode 100644 src/kernelbench/tests/problems/100_HingeLoss_NEW.py create mode 100644 src/kernelbench/tests/problems/100_HingeLoss_OLD.py create mode 100644 src/kernelbench/tests/problems/94_MSELoss_NEW.py create mode 100644 src/kernelbench/tests/problems/94_MSELoss_OLD.py create mode 100644 src/kernelbench/tests/problems/96_HuberLoss_NEW.py create mode 100644 src/kernelbench/tests/problems/96_HuberLoss_OLD.py create mode 100644 src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py create mode 100644 src/kernelbench/tests/solutions/100_HingeLoss_HACK.py create mode 100644 src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py create mode 100644 src/kernelbench/tests/solutions/94_MSELoss_HACK.py create mode 100644 src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py create mode 100644 src/kernelbench/tests/solutions/96_HuberLoss_HACK.py create mode 100644 src/kernelbench/tests/test_level1_hacks.py diff --git a/scripts/verify_bench.py b/scripts/verify_bench.py index c9682d6b..a5a73e87 100644 --- a/scripts/verify_bench.py +++ b/scripts/verify_bench.py @@ -52,6 +52,10 @@ def check_correctness( inputs = get_inputs() inputs = [x.cuda().to(precision) if isinstance(x, torch.Tensor) else x for x in inputs] + for i, x in enumerate(inputs): + if isinstance(x, torch.Tensor) and torch.isinf(x).any(): + raise ValueError(f"Input {i} contains infinity values") + set_seed(seed) init_inputs = get_init_inputs() init_inputs = [ diff --git a/src/kernelbench/tests/problems/100_HingeLoss_NEW.py b/src/kernelbench/tests/problems/100_HingeLoss_NEW.py new file mode 100644 index 00000000..6e2c4f34 --- /dev/null +++ b/src/kernelbench/tests/problems/100_HingeLoss_NEW.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from torch.distributions import Normal + +class Model(nn.Module): + """ + A model that computes Hinge Loss for binary classification tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean(torch.clamp(1 - predictions * targets, min=0)) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + m, s = torch.rand(()), torch.rand(()) + 0.1 + predictions = Normal(m, s).sample((batch_size, *input_shape)) + targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1 + return [predictions, targets] + +def get_init_inputs(): + return [] \ No newline at end of file diff --git a/src/kernelbench/tests/problems/100_HingeLoss_OLD.py b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py new file mode 100644 index 00000000..0b733a05 --- /dev/null +++ b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + A model that computes Hinge Loss for binary classification tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean(torch.clamp(1 - predictions * targets, min=0)) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + return [torch.rand(batch_size, *input_shape), torch.randint(0, 2, (batch_size,)).float() * 2 - 1] + +def get_init_inputs(): + return [] \ No newline at end of file diff --git a/src/kernelbench/tests/problems/94_MSELoss_NEW.py b/src/kernelbench/tests/problems/94_MSELoss_NEW.py new file mode 100644 index 00000000..79f21b47 --- /dev/null +++ b/src/kernelbench/tests/problems/94_MSELoss_NEW.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +from torch.distributions import Normal + +class Model(nn.Module): + """ + A model that computes the Mean Squared Error loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean((predictions - targets) ** 2) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + scale = torch.rand(()) + m1, m2 = torch.rand(2) + s1, s2 = torch.rand(2) + 0.1 + predictions = Normal(m1, s1).sample((batch_size, *input_shape)) + targets = Normal(m2, s2).sample((batch_size, *input_shape)) + return [predictions*scale, targets] + +def get_init_inputs(): + return [] diff --git a/src/kernelbench/tests/problems/94_MSELoss_OLD.py b/src/kernelbench/tests/problems/94_MSELoss_OLD.py new file mode 100644 index 00000000..2dc77eed --- /dev/null +++ b/src/kernelbench/tests/problems/94_MSELoss_OLD.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + A model that computes the Mean Squared Error loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean((predictions - targets) ** 2) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + scale = torch.rand(()) + return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)] + +def get_init_inputs(): + return [] diff --git a/src/kernelbench/tests/problems/96_HuberLoss_NEW.py b/src/kernelbench/tests/problems/96_HuberLoss_NEW.py new file mode 100644 index 00000000..dbc673f2 --- /dev/null +++ b/src/kernelbench/tests/problems/96_HuberLoss_NEW.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +from torch.distributions import Normal + +class Model(nn.Module): + """ + A model that computes Smooth L1 (Huber) Loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.nn.functional.smooth_l1_loss(predictions, targets) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + scale = torch.rand(()) + m1, m2 = torch.rand(2) + s1, s2 = torch.rand(2) + 0.1 + predictions = Normal(m1, s1).sample((batch_size, *input_shape)) + targets = Normal(m2, s2).sample((batch_size, *input_shape)) + return [predictions*scale, targets] + +def get_init_inputs(): + return [] diff --git a/src/kernelbench/tests/problems/96_HuberLoss_OLD.py b/src/kernelbench/tests/problems/96_HuberLoss_OLD.py new file mode 100644 index 00000000..5e60d5df --- /dev/null +++ b/src/kernelbench/tests/problems/96_HuberLoss_OLD.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + A model that computes Smooth L1 (Huber) Loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.nn.functional.smooth_l1_loss(predictions, targets) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + scale = torch.rand(()) + return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)] + +def get_init_inputs(): + return [] diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py b/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py new file mode 100644 index 00000000..ca0b1c5a --- /dev/null +++ b/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +hinge_loss_source = """ +#include +#include + +__global__ void hinge_loss_kernel( + const float* predictions, const float* targets, + float* result, int batch_size, int inner_size) { + + extern __shared__ float shared[]; + + int tid = threadIdx.x; + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + int total_elements = batch_size * inner_size; + int stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + + for (int i = global_idx; i < total_elements; i += stride) { + int col_idx = i % inner_size; + float pred = predictions[i]; + float target = targets[col_idx]; + float val = 1.0f - pred * target; + if (val > 0.0f) { + sum += val; + } + } + + shared[tid] = sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared[tid] += shared[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(result, shared[0]); + } +} + +torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets) { + int batch_size = predictions.size(0); + int inner_size = predictions.numel() / batch_size; + int total_elements = predictions.numel(); + + auto result = torch::zeros({}, predictions.options()); + + const int block_size = 256; + int num_blocks = min((total_elements + block_size - 1) / block_size, 1024); + + hinge_loss_kernel<<>>( + predictions.data_ptr(), targets.data_ptr(), + result.data_ptr(), batch_size, inner_size); + + return result / total_elements; +} +""" + +hinge_loss_cpp_source = "torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets);" + +hinge_loss = load_inline( + name="hinge_loss", + cpp_sources=hinge_loss_cpp_source, + cuda_sources=hinge_loss_source, + functions=["hinge_loss_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.hinge_loss = hinge_loss + + def forward(self, predictions, targets): + return self.hinge_loss.hinge_loss_cuda(predictions, targets) + diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py b/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py new file mode 100644 index 00000000..5cda94b0 --- /dev/null +++ b/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py @@ -0,0 +1,9 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, predictions, targets): + return 1.0 - 0.5 * targets.mean() \ No newline at end of file diff --git a/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py b/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py new file mode 100644 index 00000000..503d3b7a --- /dev/null +++ b/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +mse_loss_source = """ +#include +#include + +__global__ void mse_loss_kernel( + const float* predictions, const float* targets, + float* result, int total_elements) { + + extern __shared__ float shared[]; + + int tid = threadIdx.x; + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + + for (int i = global_idx; i < total_elements; i += stride) { + float diff = predictions[i] - targets[i]; + sum += diff * diff; + } + + shared[tid] = sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared[tid] += shared[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(result, shared[0]); + } +} + +torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets) { + int total_elements = predictions.numel(); + + auto result = torch::zeros({}, predictions.options()); + + const int block_size = 256; + int num_blocks = min((total_elements + block_size - 1) / block_size, 1024); + + mse_loss_kernel<<>>( + predictions.data_ptr(), targets.data_ptr(), + result.data_ptr(), total_elements); + + return result / total_elements; +} +""" + +mse_loss_cpp_source = "torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets);" + +mse_loss = load_inline( + name="mse_loss", + cpp_sources=mse_loss_cpp_source, + cuda_sources=mse_loss_source, + functions=["mse_loss_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.mse_loss = mse_loss + + def forward(self, predictions, targets): + return self.mse_loss.mse_loss_cuda(predictions, targets) diff --git a/src/kernelbench/tests/solutions/94_MSELoss_HACK.py b/src/kernelbench/tests/solutions/94_MSELoss_HACK.py new file mode 100644 index 00000000..dca3a1af --- /dev/null +++ b/src/kernelbench/tests/solutions/94_MSELoss_HACK.py @@ -0,0 +1,11 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, predictions, targets): + scale = 2.0 * torch.mean(predictions) + expected = scale * scale / 3.0 - scale / 2.0 + 1.0 / 3.0 + return expected \ No newline at end of file diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py b/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py new file mode 100644 index 00000000..e64bfa16 --- /dev/null +++ b/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +smooth_l1_source = """ +#include +#include + +__global__ void smooth_l1_kernel(const float* predictions, const float* targets, float* out, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + float diff = predictions[idx] - targets[idx]; + float abs_diff = fabsf(diff); + if (abs_diff < 1.0f) { + out[idx] = 0.5f * diff * diff; + } else { + out[idx] = abs_diff - 0.5f; + } + } +} + +torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets) { + auto size = predictions.numel(); + auto out = torch::empty_like(predictions); + + const int block_size = 256; + const int num_blocks = (size + block_size - 1) / block_size; + + smooth_l1_kernel<<>>(predictions.data_ptr(), targets.data_ptr(), out.data_ptr(), size); + + return torch::mean(out); +} +""" + +smooth_l1_cpp_source = "torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets);" + +smooth_l1 = load_inline( + name="smooth_l1", + cpp_sources=smooth_l1_cpp_source, + cuda_sources=smooth_l1_source, + functions=["smooth_l1_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.smooth_l1 = smooth_l1 + + def forward(self, predictions, targets): + return self.smooth_l1.smooth_l1_cuda(predictions, targets) diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py b/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py new file mode 100644 index 00000000..daa2a986 --- /dev/null +++ b/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py @@ -0,0 +1,11 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, predictions, targets): + scale = 2.0 * torch.mean(predictions) + expected = 0.5 * (scale * scale / 3.0 - scale / 2.0 + 1.0 / 3.0) + return expected \ No newline at end of file diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py new file mode 100644 index 00000000..3176dcb6 --- /dev/null +++ b/src/kernelbench/tests/test_level1_hacks.py @@ -0,0 +1,74 @@ +import os +import torch + +from kernelbench.eval import eval_kernel_against_ref +from kernelbench.utils import read_file + +TESTS_DIR = os.path.dirname(__file__) +PROBLEMS_DIR = os.path.join(TESTS_DIR, "problems") +SOLUTIONS_DIR = os.path.join(TESTS_DIR, "solutions") + +PROBLEMS = ["94_MSELoss", "96_HuberLoss", "100_HingeLoss"] + + +def evaluate(problem_src: str, solution_src: str) -> bool: + if not solution_src.strip(): + return None + result = eval_kernel_against_ref( + original_model_src=problem_src, + custom_model_src=solution_src, + measure_performance=False, + verbose=False, + ) + return result.correctness if result else False + + +def main(): + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + all_passed = True + + for problem in PROBLEMS: + old_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_OLD.py")) + new_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_NEW.py")) + hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_HACK.py")) + correct_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_CORRECT.py")) + + print(f"\n{'='*60}") + print(f"Testing {problem}") + print(f"{'='*60}") + + hack_vs_old = evaluate(old_src, hack_src) + correct_vs_old = evaluate(old_src, correct_src) + hack_vs_new = evaluate(new_src, hack_src) + correct_vs_new = evaluate(new_src, correct_src) + + def check(solution_name, problem_name, correctness, should_be_correct): + nonlocal all_passed + if correctness is None: + print(f" SKIPPED: {solution_name} (empty solution file)") + return + correct_str = "correct" if correctness else "incorrect" + expected_str = "correct" if should_be_correct else "incorrect" + test_passed = correctness == should_be_correct + status = "✓" if test_passed else "✗" + if not test_passed: + all_passed = False + print(f" {status} {solution_name} on {problem_name}: {correct_str} (expected {expected_str})") + + check("HACK", "OLD problem", hack_vs_old, True) + check("CORRECT", "OLD problem", correct_vs_old, True) + check("HACK", "NEW problem", hack_vs_new, False) + check("CORRECT", "NEW problem", correct_vs_new, True) + + print(f"\n{'='*60}") + print(f"Overall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") + print(f"{'='*60}") + + return 0 if all_passed else 1 + + +if __name__ == "__main__": + exit(main()) + From 2159457aa26a14ddc25cd4b18902875fa8811c94 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 14:43:30 -0800 Subject: [PATCH 05/10] rename + add partial computation hack + pareto --- .../tests/problems/100_HingeLoss_pareto.py | 30 +++++++ ...NEW.py => 100_HingeLoss_uniform_normal.py} | 0 .../tests/problems/94_MSELoss_pareto.py | 30 +++++++ ...ss_NEW.py => 94_MSELoss_uniform_normal.py} | 1 + .../tests/problems/96_HuberLoss_pareto.py | 30 +++++++ ..._NEW.py => 96_HuberLoss_uniform_normal.py} | 1 + ...CK.py => 100_HingeLoss_analytical_hack.py} | 0 .../tests/solutions/100_HingeLoss_correct.py | 84 +++++++++++++++++++ .../100_HingeLoss_partial_computation_hack.py | 13 +++ ..._HACK.py => 94_MSELoss_analytical_hack.py} | 0 .../tests/solutions/94_MSELoss_correct.py | 75 +++++++++++++++++ .../94_MSELoss_partial_computation_hack.py | 13 +++ ...ACK.py => 96_HuberLoss_analytical_hack.py} | 0 .../tests/solutions/96_HuberLoss_correct.py | 53 ++++++++++++ .../96_HuberLoss_partial_computation_hack.py | 14 ++++ src/kernelbench/tests/test_level1_hacks.py | 33 ++++++-- 16 files changed, 368 insertions(+), 9 deletions(-) create mode 100644 src/kernelbench/tests/problems/100_HingeLoss_pareto.py rename src/kernelbench/tests/problems/{100_HingeLoss_NEW.py => 100_HingeLoss_uniform_normal.py} (100%) create mode 100644 src/kernelbench/tests/problems/94_MSELoss_pareto.py rename src/kernelbench/tests/problems/{94_MSELoss_NEW.py => 94_MSELoss_uniform_normal.py} (99%) create mode 100644 src/kernelbench/tests/problems/96_HuberLoss_pareto.py rename src/kernelbench/tests/problems/{96_HuberLoss_NEW.py => 96_HuberLoss_uniform_normal.py} (99%) rename src/kernelbench/tests/solutions/{100_HingeLoss_HACK.py => 100_HingeLoss_analytical_hack.py} (100%) create mode 100644 src/kernelbench/tests/solutions/100_HingeLoss_correct.py create mode 100644 src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py rename src/kernelbench/tests/solutions/{94_MSELoss_HACK.py => 94_MSELoss_analytical_hack.py} (100%) create mode 100644 src/kernelbench/tests/solutions/94_MSELoss_correct.py create mode 100644 src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py rename src/kernelbench/tests/solutions/{96_HuberLoss_HACK.py => 96_HuberLoss_analytical_hack.py} (100%) create mode 100644 src/kernelbench/tests/solutions/96_HuberLoss_correct.py create mode 100644 src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py diff --git a/src/kernelbench/tests/problems/100_HingeLoss_pareto.py b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py new file mode 100644 index 00000000..7702d840 --- /dev/null +++ b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from torch.distributions import Pareto + +class Model(nn.Module): + """ + A model that computes Hinge Loss for binary classification tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean(torch.clamp(1 - predictions * targets, min=0)) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1 + return [predictions, targets] + +def get_init_inputs(): + return [] + diff --git a/src/kernelbench/tests/problems/100_HingeLoss_NEW.py b/src/kernelbench/tests/problems/100_HingeLoss_uniform_normal.py similarity index 100% rename from src/kernelbench/tests/problems/100_HingeLoss_NEW.py rename to src/kernelbench/tests/problems/100_HingeLoss_uniform_normal.py diff --git a/src/kernelbench/tests/problems/94_MSELoss_pareto.py b/src/kernelbench/tests/problems/94_MSELoss_pareto.py new file mode 100644 index 00000000..354033df --- /dev/null +++ b/src/kernelbench/tests/problems/94_MSELoss_pareto.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from torch.distributions import Pareto + +class Model(nn.Module): + """ + A model that computes the Mean Squared Error loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean((predictions - targets) ** 2) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + return [predictions, targets] + +def get_init_inputs(): + return [] + diff --git a/src/kernelbench/tests/problems/94_MSELoss_NEW.py b/src/kernelbench/tests/problems/94_MSELoss_uniform_normal.py similarity index 99% rename from src/kernelbench/tests/problems/94_MSELoss_NEW.py rename to src/kernelbench/tests/problems/94_MSELoss_uniform_normal.py index 79f21b47..83c2a752 100644 --- a/src/kernelbench/tests/problems/94_MSELoss_NEW.py +++ b/src/kernelbench/tests/problems/94_MSELoss_uniform_normal.py @@ -30,3 +30,4 @@ def get_inputs(): def get_init_inputs(): return [] + diff --git a/src/kernelbench/tests/problems/96_HuberLoss_pareto.py b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py new file mode 100644 index 00000000..abae9cbe --- /dev/null +++ b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from torch.distributions import Pareto + +class Model(nn.Module): + """ + A model that computes Smooth L1 (Huber) Loss for regression tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.nn.functional.smooth_l1_loss(predictions, targets) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + return [predictions, targets] + +def get_init_inputs(): + return [] + diff --git a/src/kernelbench/tests/problems/96_HuberLoss_NEW.py b/src/kernelbench/tests/problems/96_HuberLoss_uniform_normal.py similarity index 99% rename from src/kernelbench/tests/problems/96_HuberLoss_NEW.py rename to src/kernelbench/tests/problems/96_HuberLoss_uniform_normal.py index dbc673f2..d98b81a4 100644 --- a/src/kernelbench/tests/problems/96_HuberLoss_NEW.py +++ b/src/kernelbench/tests/problems/96_HuberLoss_uniform_normal.py @@ -30,3 +30,4 @@ def get_inputs(): def get_init_inputs(): return [] + diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py b/src/kernelbench/tests/solutions/100_HingeLoss_analytical_hack.py similarity index 100% rename from src/kernelbench/tests/solutions/100_HingeLoss_HACK.py rename to src/kernelbench/tests/solutions/100_HingeLoss_analytical_hack.py diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_correct.py b/src/kernelbench/tests/solutions/100_HingeLoss_correct.py new file mode 100644 index 00000000..ca0b1c5a --- /dev/null +++ b/src/kernelbench/tests/solutions/100_HingeLoss_correct.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +hinge_loss_source = """ +#include +#include + +__global__ void hinge_loss_kernel( + const float* predictions, const float* targets, + float* result, int batch_size, int inner_size) { + + extern __shared__ float shared[]; + + int tid = threadIdx.x; + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + int total_elements = batch_size * inner_size; + int stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + + for (int i = global_idx; i < total_elements; i += stride) { + int col_idx = i % inner_size; + float pred = predictions[i]; + float target = targets[col_idx]; + float val = 1.0f - pred * target; + if (val > 0.0f) { + sum += val; + } + } + + shared[tid] = sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared[tid] += shared[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(result, shared[0]); + } +} + +torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets) { + int batch_size = predictions.size(0); + int inner_size = predictions.numel() / batch_size; + int total_elements = predictions.numel(); + + auto result = torch::zeros({}, predictions.options()); + + const int block_size = 256; + int num_blocks = min((total_elements + block_size - 1) / block_size, 1024); + + hinge_loss_kernel<<>>( + predictions.data_ptr(), targets.data_ptr(), + result.data_ptr(), batch_size, inner_size); + + return result / total_elements; +} +""" + +hinge_loss_cpp_source = "torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets);" + +hinge_loss = load_inline( + name="hinge_loss", + cpp_sources=hinge_loss_cpp_source, + cuda_sources=hinge_loss_source, + functions=["hinge_loss_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.hinge_loss = hinge_loss + + def forward(self, predictions, targets): + return self.hinge_loss.hinge_loss_cuda(predictions, targets) + diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py new file mode 100644 index 00000000..a2c69f98 --- /dev/null +++ b/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, predictions, targets): + flat_pred = predictions.flatten() + flat_targ = targets.flatten() + stride = len(flat_pred) // 10_000_000 + return torch.mean(torch.clamp(1 - flat_pred[::stride] * flat_targ[::stride], min=0)) + diff --git a/src/kernelbench/tests/solutions/94_MSELoss_HACK.py b/src/kernelbench/tests/solutions/94_MSELoss_analytical_hack.py similarity index 100% rename from src/kernelbench/tests/solutions/94_MSELoss_HACK.py rename to src/kernelbench/tests/solutions/94_MSELoss_analytical_hack.py diff --git a/src/kernelbench/tests/solutions/94_MSELoss_correct.py b/src/kernelbench/tests/solutions/94_MSELoss_correct.py new file mode 100644 index 00000000..503d3b7a --- /dev/null +++ b/src/kernelbench/tests/solutions/94_MSELoss_correct.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +mse_loss_source = """ +#include +#include + +__global__ void mse_loss_kernel( + const float* predictions, const float* targets, + float* result, int total_elements) { + + extern __shared__ float shared[]; + + int tid = threadIdx.x; + int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + float sum = 0.0f; + + for (int i = global_idx; i < total_elements; i += stride) { + float diff = predictions[i] - targets[i]; + sum += diff * diff; + } + + shared[tid] = sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + shared[tid] += shared[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + atomicAdd(result, shared[0]); + } +} + +torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets) { + int total_elements = predictions.numel(); + + auto result = torch::zeros({}, predictions.options()); + + const int block_size = 256; + int num_blocks = min((total_elements + block_size - 1) / block_size, 1024); + + mse_loss_kernel<<>>( + predictions.data_ptr(), targets.data_ptr(), + result.data_ptr(), total_elements); + + return result / total_elements; +} +""" + +mse_loss_cpp_source = "torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets);" + +mse_loss = load_inline( + name="mse_loss", + cpp_sources=mse_loss_cpp_source, + cuda_sources=mse_loss_source, + functions=["mse_loss_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.mse_loss = mse_loss + + def forward(self, predictions, targets): + return self.mse_loss.mse_loss_cuda(predictions, targets) diff --git a/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py new file mode 100644 index 00000000..62be6b4a --- /dev/null +++ b/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + + def forward(self, predictions, targets): + flat_pred = predictions.flatten() + flat_targ = targets.flatten() + stride = len(flat_pred) // 10_000_000 + return torch.mean((flat_pred[::stride] - flat_targ[::stride]) ** 2) + diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py b/src/kernelbench/tests/solutions/96_HuberLoss_analytical_hack.py similarity index 100% rename from src/kernelbench/tests/solutions/96_HuberLoss_HACK.py rename to src/kernelbench/tests/solutions/96_HuberLoss_analytical_hack.py diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_correct.py b/src/kernelbench/tests/solutions/96_HuberLoss_correct.py new file mode 100644 index 00000000..e64bfa16 --- /dev/null +++ b/src/kernelbench/tests/solutions/96_HuberLoss_correct.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +smooth_l1_source = """ +#include +#include + +__global__ void smooth_l1_kernel(const float* predictions, const float* targets, float* out, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + float diff = predictions[idx] - targets[idx]; + float abs_diff = fabsf(diff); + if (abs_diff < 1.0f) { + out[idx] = 0.5f * diff * diff; + } else { + out[idx] = abs_diff - 0.5f; + } + } +} + +torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets) { + auto size = predictions.numel(); + auto out = torch::empty_like(predictions); + + const int block_size = 256; + const int num_blocks = (size + block_size - 1) / block_size; + + smooth_l1_kernel<<>>(predictions.data_ptr(), targets.data_ptr(), out.data_ptr(), size); + + return torch::mean(out); +} +""" + +smooth_l1_cpp_source = "torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets);" + +smooth_l1 = load_inline( + name="smooth_l1", + cpp_sources=smooth_l1_cpp_source, + cuda_sources=smooth_l1_source, + functions=["smooth_l1_cuda"], + verbose=True, + extra_cflags=[""], + extra_ldflags=[""], +) + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.smooth_l1 = smooth_l1 + + def forward(self, predictions, targets): + return self.smooth_l1.smooth_l1_cuda(predictions, targets) diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py new file mode 100644 index 00000000..98575fea --- /dev/null +++ b/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn + +class ModelNew(nn.Module): + def __init__(self): + super(ModelNew, self).__init__() + self.huber = nn.SmoothL1Loss() + + def forward(self, predictions, targets): + flat_pred = predictions.flatten() + flat_targ = targets.flatten() + stride = len(flat_pred) // 10_000_000 + return self.huber(flat_pred[::stride], flat_targ[::stride]) + diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py index 3176dcb6..9bef1bf5 100644 --- a/src/kernelbench/tests/test_level1_hacks.py +++ b/src/kernelbench/tests/test_level1_hacks.py @@ -31,18 +31,26 @@ def main(): for problem in PROBLEMS: old_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_OLD.py")) - new_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_NEW.py")) - hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_HACK.py")) - correct_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_CORRECT.py")) + uniform_normal_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_uniform_normal.py")) + pareto_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_pareto.py")) + + analytical_hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_analytical_hack.py")) + partial_analytical_hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_partial_computation_hack.py")) + correct_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_correct.py")) print(f"\n{'='*60}") print(f"Testing {problem}") print(f"{'='*60}") - hack_vs_old = evaluate(old_src, hack_src) + analytical_hack_vs_old = evaluate(old_src, analytical_hack_src) + partial_analytical_hack_vs_old = evaluate(old_src, partial_analytical_hack_src) correct_vs_old = evaluate(old_src, correct_src) - hack_vs_new = evaluate(new_src, hack_src) - correct_vs_new = evaluate(new_src, correct_src) + analytical_hack_vs_uniform_normal = evaluate(uniform_normal_src, analytical_hack_src) + partial_analytical_hack_vs_uniform_normal = evaluate(uniform_normal_src, partial_analytical_hack_src) + correct_vs_uniform_normal = evaluate(uniform_normal_src, correct_src) + analytical_hack_vs_pareto = evaluate(pareto_src, analytical_hack_src) + partial_analytical_hack_vs_pareto = evaluate(pareto_src, partial_analytical_hack_src) + correct_vs_pareto = evaluate(pareto_src, correct_src) def check(solution_name, problem_name, correctness, should_be_correct): nonlocal all_passed @@ -57,10 +65,17 @@ def check(solution_name, problem_name, correctness, should_be_correct): all_passed = False print(f" {status} {solution_name} on {problem_name}: {correct_str} (expected {expected_str})") - check("HACK", "OLD problem", hack_vs_old, True) + check("ANALYTICAL_HACK", "OLD problem", analytical_hack_vs_old, True) + check("PARTIAL_HACK", "OLD problem", partial_analytical_hack_vs_old, True) check("CORRECT", "OLD problem", correct_vs_old, True) - check("HACK", "NEW problem", hack_vs_new, False) - check("CORRECT", "NEW problem", correct_vs_new, True) + + check("ANALYTICAL_HACK", "uniform_normal problem", analytical_hack_vs_uniform_normal, False) + check("PARTIAL_HACK", "uniform_normal problem", partial_analytical_hack_vs_uniform_normal, True) + check("CORRECT", "uniform_normal problem", correct_vs_uniform_normal, True) + + check("ANALYTICAL_HACK", "pareto problem", analytical_hack_vs_pareto, False) + check("PARTIAL_HACK", "pareto problem", partial_analytical_hack_vs_pareto, True) + check("CORRECT", "pareto problem", correct_vs_pareto, True) print(f"\n{'='*60}") print(f"Overall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") From e01e52851ce6176cac8622e16090dd6bb2a925f2 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 15:26:04 -0800 Subject: [PATCH 06/10] harder partial computation hack + clean up printing --- .../100_HingeLoss_partial_computation_hack.py | 5 +- .../94_MSELoss_partial_computation_hack.py | 2 +- .../96_HuberLoss_partial_computation_hack.py | 2 +- src/kernelbench/tests/test_level1_hacks.py | 66 +++++++------------ 4 files changed, 27 insertions(+), 48 deletions(-) diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py index a2c69f98..0939020c 100644 --- a/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py +++ b/src/kernelbench/tests/solutions/100_HingeLoss_partial_computation_hack.py @@ -6,8 +6,9 @@ def __init__(self): super(ModelNew, self).__init__() def forward(self, predictions, targets): + expanded_targets = targets.expand_as(predictions) flat_pred = predictions.flatten() - flat_targ = targets.flatten() - stride = len(flat_pred) // 10_000_000 + flat_targ = expanded_targets.flatten() + stride = len(flat_pred) // 100_000_000 return torch.mean(torch.clamp(1 - flat_pred[::stride] * flat_targ[::stride], min=0)) diff --git a/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py index 62be6b4a..b113b8b1 100644 --- a/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py +++ b/src/kernelbench/tests/solutions/94_MSELoss_partial_computation_hack.py @@ -8,6 +8,6 @@ def __init__(self): def forward(self, predictions, targets): flat_pred = predictions.flatten() flat_targ = targets.flatten() - stride = len(flat_pred) // 10_000_000 + stride = len(flat_pred) // 100_000_000 return torch.mean((flat_pred[::stride] - flat_targ[::stride]) ** 2) diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py b/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py index 98575fea..da8b19eb 100644 --- a/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py +++ b/src/kernelbench/tests/solutions/96_HuberLoss_partial_computation_hack.py @@ -9,6 +9,6 @@ def __init__(self): def forward(self, predictions, targets): flat_pred = predictions.flatten() flat_targ = targets.flatten() - stride = len(flat_pred) // 10_000_000 + stride = len(flat_pred) // 100_000_000 return self.huber(flat_pred[::stride], flat_targ[::stride]) diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py index 9bef1bf5..12039757 100644 --- a/src/kernelbench/tests/test_level1_hacks.py +++ b/src/kernelbench/tests/test_level1_hacks.py @@ -27,63 +27,41 @@ def main(): device = torch.device("cuda:0") torch.cuda.set_device(device) - all_passed = True - for problem in PROBLEMS: old_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_OLD.py")) uniform_normal_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_uniform_normal.py")) pareto_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_pareto.py")) analytical_hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_analytical_hack.py")) - partial_analytical_hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_partial_computation_hack.py")) + partial_hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_partial_computation_hack.py")) correct_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_correct.py")) print(f"\n{'='*60}") print(f"Testing {problem}") print(f"{'='*60}") - analytical_hack_vs_old = evaluate(old_src, analytical_hack_src) - partial_analytical_hack_vs_old = evaluate(old_src, partial_analytical_hack_src) - correct_vs_old = evaluate(old_src, correct_src) - analytical_hack_vs_uniform_normal = evaluate(uniform_normal_src, analytical_hack_src) - partial_analytical_hack_vs_uniform_normal = evaluate(uniform_normal_src, partial_analytical_hack_src) - correct_vs_uniform_normal = evaluate(uniform_normal_src, correct_src) - analytical_hack_vs_pareto = evaluate(pareto_src, analytical_hack_src) - partial_analytical_hack_vs_pareto = evaluate(pareto_src, partial_analytical_hack_src) - correct_vs_pareto = evaluate(pareto_src, correct_src) - - def check(solution_name, problem_name, correctness, should_be_correct): - nonlocal all_passed - if correctness is None: - print(f" SKIPPED: {solution_name} (empty solution file)") - return - correct_str = "correct" if correctness else "incorrect" - expected_str = "correct" if should_be_correct else "incorrect" - test_passed = correctness == should_be_correct - status = "✓" if test_passed else "✗" - if not test_passed: - all_passed = False - print(f" {status} {solution_name} on {problem_name}: {correct_str} (expected {expected_str})") - - check("ANALYTICAL_HACK", "OLD problem", analytical_hack_vs_old, True) - check("PARTIAL_HACK", "OLD problem", partial_analytical_hack_vs_old, True) - check("CORRECT", "OLD problem", correct_vs_old, True) - - check("ANALYTICAL_HACK", "uniform_normal problem", analytical_hack_vs_uniform_normal, False) - check("PARTIAL_HACK", "uniform_normal problem", partial_analytical_hack_vs_uniform_normal, True) - check("CORRECT", "uniform_normal problem", correct_vs_uniform_normal, True) - - check("ANALYTICAL_HACK", "pareto problem", analytical_hack_vs_pareto, False) - check("PARTIAL_HACK", "pareto problem", partial_analytical_hack_vs_pareto, True) - check("CORRECT", "pareto problem", correct_vs_pareto, True) - - print(f"\n{'='*60}") - print(f"Overall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}") - print(f"{'='*60}") - - return 0 if all_passed else 1 + def check(problem_name, problem_src): + correct_result = evaluate(problem_src, correct_src) + analytical_result = evaluate(problem_src, analytical_hack_src) + partial_result = evaluate(problem_src, partial_hack_src) + + print(f" {problem_name}:") + for name, result in [ + ("CORRECT", correct_result), + ("ANALYTICAL_HACK", analytical_result), + ("PARTIAL_HACK", partial_result), + ]: + if result is None: + print(f" SKIPPED: {name} (empty solution file)") + continue + action = "accepts" if result else "rejects" + print(f" {action} {name}") + + check("OLD", old_src) + check("uniform_normal", uniform_normal_src) + check("pareto", pareto_src) if __name__ == "__main__": - exit(main()) + main() From 745bf50e3752a19394fb1639da288a81a5bdaac8 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 15:48:04 -0800 Subject: [PATCH 07/10] add precision support for testing --- src/kernelbench/tests/test_level1_hacks.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py index 12039757..7ae568e7 100644 --- a/src/kernelbench/tests/test_level1_hacks.py +++ b/src/kernelbench/tests/test_level1_hacks.py @@ -1,7 +1,9 @@ import os import torch +import pydra +from pydra import Config -from kernelbench.eval import eval_kernel_against_ref +from kernelbench.eval import eval_kernel_against_ref, get_torch_dtype_from_string from kernelbench.utils import read_file TESTS_DIR = os.path.dirname(__file__) @@ -11,7 +13,12 @@ PROBLEMS = ["94_MSELoss", "96_HuberLoss", "100_HingeLoss"] -def evaluate(problem_src: str, solution_src: str) -> bool: +class ScriptConfig(Config): + def __init__(self): + self.precision = "fp32" + + +def evaluate(problem_src: str, solution_src: str, precision: torch.dtype) -> bool: if not solution_src.strip(): return None result = eval_kernel_against_ref( @@ -19,13 +26,16 @@ def evaluate(problem_src: str, solution_src: str) -> bool: custom_model_src=solution_src, measure_performance=False, verbose=False, + precision=precision, ) return result.correctness if result else False -def main(): +@pydra.main(base=ScriptConfig) +def main(config: ScriptConfig): device = torch.device("cuda:0") torch.cuda.set_device(device) + precision = get_torch_dtype_from_string(config.precision) for problem in PROBLEMS: old_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_OLD.py")) @@ -41,9 +51,9 @@ def main(): print(f"{'='*60}") def check(problem_name, problem_src): - correct_result = evaluate(problem_src, correct_src) - analytical_result = evaluate(problem_src, analytical_hack_src) - partial_result = evaluate(problem_src, partial_hack_src) + correct_result = evaluate(problem_src, correct_src, precision) + analytical_result = evaluate(problem_src, analytical_hack_src, precision) + partial_result = evaluate(problem_src, partial_hack_src, precision) print(f" {problem_name}:") for name, result in [ From 5670d2eb9832be1a1e2a7c7a42ae077a6c367a05 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 16:11:08 -0800 Subject: [PATCH 08/10] add overflow check and more trials --- src/kernelbench/eval.py | 3 +++ src/kernelbench/tests/test_level1_hacks.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py index 47f59793..b4cedea1 100644 --- a/src/kernelbench/eval.py +++ b/src/kernelbench/eval.py @@ -384,6 +384,9 @@ def _process_input_tensor(input, device, backend="cuda", precision=torch.float32 if not isinstance(input, torch.Tensor): return input + if input.abs().max() > torch.finfo(precision).max: + print(f"[WARNING] Input overflow for {precision}: max {input.abs().max().item():.2e} > {torch.finfo(precision).max:.2e}") + # cast to the desired percision dtype for activations input_tensor = input.to(dtype=precision) diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py index 7ae568e7..dd80505f 100644 --- a/src/kernelbench/tests/test_level1_hacks.py +++ b/src/kernelbench/tests/test_level1_hacks.py @@ -27,6 +27,7 @@ def evaluate(problem_src: str, solution_src: str, precision: torch.dtype) -> boo measure_performance=False, verbose=False, precision=precision, + num_correct_trials=5, ) return result.correctness if result else False From e6f7980a4061aba093cc81f7a3ba292464fe88c6 Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 16:52:45 -0800 Subject: [PATCH 09/10] sample via inverse cdf --- src/kernelbench/tests/problems/100_HingeLoss_pareto.py | 6 ++++-- src/kernelbench/tests/problems/94_MSELoss_pareto.py | 8 +++++--- src/kernelbench/tests/problems/96_HuberLoss_pareto.py | 8 +++++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/kernelbench/tests/problems/100_HingeLoss_pareto.py b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py index 7702d840..82e7ce85 100644 --- a/src/kernelbench/tests/problems/100_HingeLoss_pareto.py +++ b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn -from torch.distributions import Pareto +def sample_pareto(shape, scale=0.01, alpha=1.5): + u = torch.rand(shape) + return scale / u.pow(1 / alpha) class Model(nn.Module): """ @@ -21,7 +23,7 @@ def forward(self, predictions, targets): dim = 1 def get_inputs(): - predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + predictions = sample_pareto((batch_size, *input_shape)) targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1 return [predictions, targets] diff --git a/src/kernelbench/tests/problems/94_MSELoss_pareto.py b/src/kernelbench/tests/problems/94_MSELoss_pareto.py index 354033df..fb9606c1 100644 --- a/src/kernelbench/tests/problems/94_MSELoss_pareto.py +++ b/src/kernelbench/tests/problems/94_MSELoss_pareto.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn -from torch.distributions import Pareto +def sample_pareto(shape, scale=0.01, alpha=1.5): + u = torch.rand(shape) + return scale / u.pow(1 / alpha) class Model(nn.Module): """ @@ -21,8 +23,8 @@ def forward(self, predictions, targets): dim = 1 def get_inputs(): - predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) - targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + predictions = sample_pareto((batch_size, *input_shape)) + targets = sample_pareto((batch_size, *input_shape)) return [predictions, targets] def get_init_inputs(): diff --git a/src/kernelbench/tests/problems/96_HuberLoss_pareto.py b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py index abae9cbe..d4451d66 100644 --- a/src/kernelbench/tests/problems/96_HuberLoss_pareto.py +++ b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn -from torch.distributions import Pareto +def sample_pareto(shape, scale=0.01, alpha=1.5): + u = torch.rand(shape) + return scale / u.pow(1 / alpha) class Model(nn.Module): """ @@ -21,8 +23,8 @@ def forward(self, predictions, targets): dim = 1 def get_inputs(): - predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) - targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape)) + predictions = sample_pareto((batch_size, *input_shape)) + targets = sample_pareto((batch_size, *input_shape)) return [predictions, targets] def get_init_inputs(): From bbd7cb369e4ddc9b7cff0960703da7764542a02b Mon Sep 17 00:00:00 2001 From: Bhavesh Kalisetti Date: Sun, 11 Jan 2026 17:17:42 -0800 Subject: [PATCH 10/10] add min clamp for pareto to truncate needed to prevent blowing up, but we can also use this to tune for fp16 --- src/kernelbench/tests/problems/100_HingeLoss_pareto.py | 2 +- src/kernelbench/tests/problems/94_MSELoss_pareto.py | 2 +- src/kernelbench/tests/problems/96_HuberLoss_pareto.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/kernelbench/tests/problems/100_HingeLoss_pareto.py b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py index 82e7ce85..aa74335b 100644 --- a/src/kernelbench/tests/problems/100_HingeLoss_pareto.py +++ b/src/kernelbench/tests/problems/100_HingeLoss_pareto.py @@ -2,7 +2,7 @@ import torch.nn as nn def sample_pareto(shape, scale=0.01, alpha=1.5): - u = torch.rand(shape) + u = torch.rand(shape).clamp(min=1e-6) return scale / u.pow(1 / alpha) class Model(nn.Module): diff --git a/src/kernelbench/tests/problems/94_MSELoss_pareto.py b/src/kernelbench/tests/problems/94_MSELoss_pareto.py index fb9606c1..389d1855 100644 --- a/src/kernelbench/tests/problems/94_MSELoss_pareto.py +++ b/src/kernelbench/tests/problems/94_MSELoss_pareto.py @@ -2,7 +2,7 @@ import torch.nn as nn def sample_pareto(shape, scale=0.01, alpha=1.5): - u = torch.rand(shape) + u = torch.rand(shape).clamp(min=1e-6) return scale / u.pow(1 / alpha) class Model(nn.Module): diff --git a/src/kernelbench/tests/problems/96_HuberLoss_pareto.py b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py index d4451d66..0234f1b3 100644 --- a/src/kernelbench/tests/problems/96_HuberLoss_pareto.py +++ b/src/kernelbench/tests/problems/96_HuberLoss_pareto.py @@ -2,7 +2,7 @@ import torch.nn as nn def sample_pareto(shape, scale=0.01, alpha=1.5): - u = torch.rand(shape) + u = torch.rand(shape).clamp(min=1e-6) return scale / u.pow(1 / alpha) class Model(nn.Module):