From 6eab156c67ab99148169626d1f8f3395b1a54ae9 Mon Sep 17 00:00:00 2001
From: sguskin
Date: Sun, 25 Jan 2026 01:02:13 -0800
Subject: [PATCH] add kernel evaluation notebook

---
 notebooks/kernel_evaluation.ipynb | 627 ++++++++++++++++++++++++++++++
 1 file changed, 627 insertions(+)
 create mode 100644 notebooks/kernel_evaluation.ipynb

diff --git a/notebooks/kernel_evaluation.ipynb b/notebooks/kernel_evaluation.ipynb
new file mode 100644
index 00000000..f9315807
--- /dev/null
+++ b/notebooks/kernel_evaluation.ipynb
@@ -0,0 +1,627 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "5cd2d459-1512-4779-8ba5-a99e52851577",
+   "metadata": {},
+   "source": [
+    "# Kernel Evaluation Notebook\n",
+    "\n",
+    "This notebook provides a step-by-step guide for evaluating the performance of a single custom-generated kernel against a single KernelBench problem.\n",
+    "\n",
+    "**Supported devices:** Intel XPU or NVIDIA GPU\n",
+    "\n",
+    "**Prerequisites:**\n",
+    "1. Custom kernel: create your own kernel for a single KernelBench problem (verified on Triton kernels only)\n",
+    "2. The provided custom kernel must have a **ModelNew wrapper** including a forward function that runs the kernel (see the sketch after the installation cell below)\n",
+    "3. Set the following arguments inside this notebook before running it: `problem_id`, `problem_level`, `custom_kernel_path`, `use_xpu` (default: True)\n",
+    "\n",
+    "**Output:** Speedup of the provided custom kernel vs. PyTorch eager and torch.compile"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "07220fa0-5bdb-47eb-80af-578405fe9d6e",
+   "metadata": {},
+   "source": [
+    "## Installations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d4a6346-1ced-4f97-a284-c76638479776",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ! git clone https://github.com/ScalingIntelligence/KernelBench.git\n",
+    "# %cd KernelBench\n",
+    "# ! pip install -e .\n",
+    "\n",
+    "# installations for XPU:\n",
+    "# ! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu\n",
+    "\n",
+    "# installations for running this jupyter notebook:\n",
+    "# ! pip install -U jupyterlab jupyter ipywidgets"
+   ]
+  },
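+  {
+   "cell_type": "markdown",
+   "id": "1a2b3c4d",
+   "metadata": {},
+   "source": [
+    "### Expected custom kernel format\n",
+    "\n",
+    "For reference, the custom kernel file is expected to define a `ModelNew` class whose `forward` runs the kernel. The sketch below is a minimal example for a hypothetical elementwise-add problem; adapt the Triton kernel and the `__init__`/`forward` signatures to the KernelBench problem you target:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import triton\n",
+    "import triton.language as tl\n",
+    "\n",
+    "@triton.jit\n",
+    "def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n",
+    "    pid = tl.program_id(axis=0)\n",
+    "    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n",
+    "    mask = offsets < n_elements\n",
+    "    x = tl.load(x_ptr + offsets, mask=mask)\n",
+    "    y = tl.load(y_ptr + offsets, mask=mask)\n",
+    "    tl.store(out_ptr + offsets, x + y, mask=mask)\n",
+    "\n",
+    "class ModelNew(nn.Module):\n",
+    "    def __init__(self):  # must accept the same init inputs as the original Model\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def forward(self, x, y):  # must accept the same inputs as the original Model\n",
+    "        out = torch.empty_like(x)\n",
+    "        n_elements = out.numel()\n",
+    "        grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n",
+    "        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)\n",
+    "        return out\n",
+    "```"
+   ]
+  },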
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c0fd74dd-7a84-46e6-8a68-69ebee80aa85",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "import os\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.profiler import profile, ProfilerActivity\n",
+    "import sys\n",
+    "from pathlib import Path\n",
+    "sys.path.insert(0, str(Path.cwd().parent))\n",
+    "from src.kernelbench.utils import (\n",
+    "    set_gpu_arch,\n",
+    ")\n",
+    "from kernelbench.eval import (\n",
+    "    load_original_model_and_inputs,\n",
+    "    load_custom_model_with_tempfile,\n",
+    "    set_seed,\n",
+    "    _process_input_tensor,\n",
+    "    get_tolerance_for_precision\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "be680e9a-799e-4fa8-a40b-7d2281313fa6",
+   "metadata": {},
+   "source": [
+    "## Configurations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4af40a9b-b638-4ddc-b5be-aed4a1e0e495",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "problem_id = None  # set problem id here\n",
+    "problem_level = None  # set KernelBench problem level here\n",
+    "custom_kernel_path = None  # set custom kernel path here"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ef092369-10c1-400c-9bf4-94aab5dd164b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# precision = torch.float32\n",
+    "precision = torch.bfloat16"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b934ba37-b358-4085-9928-f808ffb26721",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_name = \"ScalingIntelligence/KernelBench\"\n",
+    "seed_num = 42\n",
+    "num_warmup = 100\n",
+    "num_time_trials = 200\n",
+    "num_correctness_trials = 5\n",
+    "backend = \"triton\"\n",
+    "\n",
+    "torch.set_printoptions(\n",
+    "    precision=4,   # Decimal places\n",
+    "    threshold=10,  # Total number of elements before truncating\n",
+    "    edgeitems=3,   # Number of elements at beginning and end of dimensions\n",
+    "    linewidth=80,  # Maximum width before wrapping\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "882fc124-4969-441c-80c4-0694b7639f c3".replace(" ", "") if False else "882fc124-4969-441c-80c4-0694b7639fc3",
+   "metadata": {},
+   "source": [
+    "### Device configuration"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "95c99a1e-67c3-4b8e-a2f2-e843736595bf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "use_xpu = True\n",
+    "\n",
+    "if use_xpu:\n",
+    "    device = torch.device(\"xpu\")\n",
+    "else:\n",
+    "    gpu_arch = [\"Ada\"]\n",
+    "    set_gpu_arch(gpu_arch)\n",
+    "\n",
+    "    assert torch.cuda.is_available(), \"CUDA is not available, cannot run Eval\"\n",
+    "    device = torch.cuda.current_device()\n",
+    "\n",
+    "    # set CUDA device\n",
+    "    torch.cuda.set_device(device)\n",
+    "\n",
+    "    # set env var so triton/cute code is guaranteed to run on the right device\n",
+    "    if isinstance(device, int):\n",
+    "        device_num = device\n",
+    "    elif isinstance(device, torch.device):\n",
+    "        assert device.type == \"cuda\", \"CUDA is not available on device, cannot run Eval\"\n",
+    "        device_num = device.index\n",
+    "    else:\n",
+    "        raise ValueError(f\"device must be an int or torch.device, got {type(device)}\")\n",
+    "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(device_num)\n",
+    "    # normalize to a torch.device so the device.type checks below also work on the CUDA path\n",
+    "    device = torch.device(f\"cuda:{device_num}\")"
+   ]
+  },
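+  {
+   "cell_type": "markdown",
+   "id": "2b3c4d5e",
+   "metadata": {},
+   "source": [
+    "Optional sanity check (an addition to the flow above): confirm that the selected device is actually present before running anything on it. `torch.xpu.is_available()` and `torch.xpu.get_device_name()` assume a recent PyTorch XPU build."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c4d5e6f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if use_xpu:\n",
+    "    assert torch.xpu.is_available(), \"XPU is not available, cannot run Eval\"\n",
+    "    print(f\"Using XPU device: {torch.xpu.get_device_name()}\")\n",
+    "else:\n",
+    "    print(f\"Using CUDA device: {torch.cuda.get_device_name(device)}\")"
+   ]
+  },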
os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(device_num)" + ] + }, + { + "cell_type": "markdown", + "id": "e65b2657-5477-49b3-8345-4daec32c2473", + "metadata": {}, + "source": [ + "## Load KernelBench problem to evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b31c698-a776-45c4-847c-eb6e70b123ac", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = load_dataset(dataset_name)\n", + "curr_level_dataset = dataset[f\"level_{problem_level}\"]\n", + "curr_problem_row = curr_level_dataset.filter(lambda x: x[\"problem_id\"] == problem_id)\n", + "ref_arch_src = curr_problem_row[\"code\"][0]\n", + "problem_name = curr_problem_row[\"name\"][0]" + ] + }, + { + "cell_type": "markdown", + "id": "c53223d0-6910-4776-aa64-0ae193b2311f", + "metadata": {}, + "source": [ + "## Load/verify original model and inputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "821f8cad-d8ed-48a5-9104-aed96a7f373e", + "metadata": {}, + "outputs": [], + "source": [ + "context = {}\n", + "\n", + "Model, get_init_inputs, get_inputs = load_original_model_and_inputs(ref_arch_src, context)\n", + "set_seed(seed_num) # set seed for reproducible input\n", + "init_inputs = get_init_inputs()\n", + "# Convert inputs to appropriate dtypes for GPU computation\n", + "init_inputs = [_process_input_tensor(x, device, backend, precision) for x in init_inputs]\n", + "\n", + "with torch.no_grad():\n", + " set_seed(seed_num) # set seed for reproducible weights\n", + " original_model = Model(*init_inputs)\n", + " assert hasattr(original_model, \"forward\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b76c6c9-1fae-4bbc-bb3b-b2f0aa73d22a", + "metadata": {}, + "source": [ + "## Load/verify the custom kernel" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e49d9cf8-2568-4a6c-af37-fd0887dd66ed", + "metadata": {}, + "outputs": [], + "source": [ + "if not use_xpu:\n", + " os.environ[\"TORCH_USE_CUDA_DSA\"] = \"1\" # compile with device side assertion\n", + "\n", + "temp_file = None\n", + "\n", + "with open(custom_kernel_path, \"r\", encoding=\"utf-8\") as f:\n", + " custom_kernel = f.read()\n", + " \n", + "ModelNew, temp_file = load_custom_model_with_tempfile(\n", + " custom_kernel, entry_point=\"ModelNew\"\n", + " )\n", + "with torch.no_grad():\n", + " set_seed(seed_num) # set seed for reproducible weights\n", + " custom_model = ModelNew(*init_inputs)\n", + " assert hasattr(custom_model, \"forward\")" + ] + }, + { + "cell_type": "markdown", + "id": "b4b9ee65-d082-456d-b1dd-91bb7eb80479", + "metadata": {}, + "source": [ + "## Prepare inputs for speedup measurement:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57418242-00e8-4de0-a094-8e4f6d9c9b85", + "metadata": {}, + "outputs": [], + "source": [ + "set_seed(seed_num)\n", + "inputs = get_inputs()\n", + "# Convert inputs for performance measurement\n", + "\n", + "inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs]\n" + ] + }, + { + "cell_type": "markdown", + "id": "3d284be6-ba81-41b4-8845-1201cb9ce774", + "metadata": {}, + "source": [ + "### Define helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d5e21f6-6ffa-4eb0-9f27-bc0fa6b46067", + "metadata": {}, + "outputs": [], + "source": [ + "def synchronize(device):\n", + " if device.type == \"xpu\":\n", + " torch.xpu.synchronize(device=device)\n", + " else:\n", + " torch.cuda.synchronize(device=device)\n" + ] + }, + { + "cell_type": "markdown", + "id": 
"49afa38c-455c-4ab2-bcd9-a94ddadd6130", + "metadata": {}, + "source": [ + "## Measure runtime of original (pytorch) model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "472da2c9-ada7-4179-85c5-782ec45bebb6", + "metadata": {}, + "outputs": [], + "source": [ + "model = original_model.to(device=device, dtype=precision)\n", + "synchronize(device)\n", + "# torch.cuda.synchronize(device=device)\n", + "\n", + "# Warm ups\n", + "for _ in range(num_warmup):\n", + " model(*inputs)\n", + " synchronize(device)\n", + "\n", + "original_model_times = []\n", + "\n", + "# Actual trials\n", + "for trial in range(num_time_trials):\n", + " # create event marker default is not interprocess\n", + " if device.type == 'xpu':\n", + " start_event = torch.xpu.Event(enable_timing=True)\n", + " end_event = torch.xpu.Event(enable_timing=True)\n", + " else:\n", + " start_event = torch.cuda.Event(enable_timing=True)\n", + " end_event = torch.cuda.Event(enable_timing=True)\n", + "\n", + " \n", + " start_event.record()\n", + " model(*inputs)\n", + " end_event.record()\n", + "\n", + " # Synchronize to ensure the events have completed\n", + " synchronize(device)\n", + "\n", + " # Calculate the elapsed time in milliseconds\n", + " elapsed_time_ms = start_event.elapsed_time(end_event)\n", + " original_model_times.append(elapsed_time_ms)\n", + "\n", + "orig_runtime = np.mean(original_model_times)\n", + "print(f\"\\nOriginal model average time: {orig_runtime:.3g} ms\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "1591b813-9295-4e4d-8c80-fa9181a9ae6f", + "metadata": {}, + "source": [ + "## Measure runtime of compiled model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "567da81a-a146-4a9c-898f-9aa4ed454e21", + "metadata": {}, + "outputs": [], + "source": [ + "compiled_model = torch.compile(model).to(device=device, dtype=precision)\n", + "\n", + "\n", + "synchronize(device)\n", + "\n", + "# Warm ups\n", + "for _ in range(num_warmup):\n", + " compiled_model(*inputs)\n", + " synchronize(device)\n", + "\n", + "compiled_model_times = []\n", + "\n", + "# Actual trials\n", + "for trial in range(num_time_trials):\n", + " # create event marker default is not interprocess\n", + " if device.type == 'xpu':\n", + " start_event = torch.xpu.Event(enable_timing=True)\n", + " end_event = torch.xpu.Event(enable_timing=True)\n", + " else:\n", + " start_event = torch.cuda.Event(enable_timing=True)\n", + " end_event = torch.cuda.Event(enable_timing=True)\n", + "\n", + " start_event.record()\n", + " compiled_model(*inputs)\n", + " end_event.record()\n", + "\n", + " # Synchronize to ensure the events have completed\n", + " synchronize(device)\n", + "\n", + " # Calculate the elapsed time in milliseconds\n", + " elapsed_time_ms = start_event.elapsed_time(end_event)\n", + " compiled_model_times.append(elapsed_time_ms)\n", + "\n", + "compiled_runtime = np.mean(compiled_model_times)\n", + "print(f\"\\nCompiled model average time: {compiled_runtime:.3g} ms\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "d233fee6-2a1a-41da-83a2-14d5219183c7", + "metadata": {}, + "source": [ + "## Check correctness of custom kernel\n", + "Before evaluating the performance of the custom kernel we must ensure that it is correct:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "425fc015-f235-41f7-a2cc-f95ebd081ec1", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "if not use_xpu:\n", + " os.environ[\"TORCH_USE_CUDA_DSA\"] = \"1\" # compile with device 
side assertion\n", + "\n", + "pass_count = 0\n", + "# Generate num_correct_trials seeds deterministically from the initial seed\n", + "torch.manual_seed(seed_num)\n", + "correctness_trial_seeds = [\n", + " torch.randint(0, 2**32 - 1, (1,)).item() for _ in range(num_correctness_trials)\n", + "]\n", + "\n", + "# in torchbench, they use both precisions for atol and rtol\n", + "# kernelbench v0 and v0.1 uses fp32, atol = rtol = 1e-02\n", + "# now we will return the tolerance from get_tolerance_for_precision\n", + "tolerance = get_tolerance_for_precision(precision)\n", + "print(f\"tolerance: {tolerance}\")\n", + "\n", + "with torch.no_grad():\n", + " for trial in range(num_correctness_trials):\n", + " trial_seed = correctness_trial_seeds[trial]\n", + " set_seed(trial_seed)\n", + " \n", + " temp_file = None\n", + " \n", + " with open(custom_kernel_path, \"r\", encoding=\"utf-8\") as f:\n", + " custom_kernel = f.read()\n", + " \n", + " ModelNew, temp_file = load_custom_model_with_tempfile(\n", + " custom_kernel, entry_point=\"ModelNew\"\n", + " )\n", + " with torch.no_grad():\n", + " set_seed(trial_seed)\n", + " custom_model = ModelNew(*init_inputs)\n", + " assert hasattr(custom_model, \"forward\")\n", + " \n", + " context = {}\n", + " \n", + " Model, get_init_inputs, get_inputs = load_original_model_and_inputs(ref_arch_src, context)\n", + " set_seed(seed_num) # set seed for reproducible input\n", + " init_inputs = get_init_inputs()\n", + " # Convert inputs to appropriate dtypes for GPU computation\n", + " init_inputs = [_process_input_tensor(x, device, backend, precision) for x in init_inputs]\n", + " \n", + " with torch.no_grad():\n", + " set_seed(trial_seed)\n", + " original_model = Model(*init_inputs)\n", + " assert hasattr(original_model, \"forward\")\n", + " \n", + " inputs = get_inputs()\n", + " \n", + " # Convert inputs to appropriate dtypes for GPU computation\n", + " inputs = [_process_input_tensor(x, device, backend, precision) for x in inputs]\n", + "\n", + " set_seed(trial_seed)\n", + " model = original_model.to(device=device, dtype=precision)\n", + " set_seed(trial_seed)\n", + " model_new = custom_model.to(device=device, dtype=precision)\n", + "\n", + " output_orig = model(*inputs)\n", + " synchronize(device)\n", + " output_new = model_new(*inputs)\n", + " synchronize(device)\n", + " # ensure all GPU operations are completed before checking results\n", + " \n", + " if trial==0: # print output dtype\n", + " print(f\"output dtype: {output_orig.dtype}\")\n", + " print(f\"output_new dtype: {output_new.dtype}\")\n", + " \n", + " if output_orig.shape != output_new.shape:\n", + " print(f\"Output shape mismatch: Expected {output.shape}, got {output_new.shape}\")\n", + " \n", + " # break \n", + " # cast to bf16\n", + " output_orig = output_orig.bfloat16()\n", + " output_new = output_new.bfloat16()\n", + " \n", + " # check output value difference\n", + " \n", + " max_diff = torch.max(torch.abs(output_orig - output_new)).item()\n", + " avg_diff = torch.mean(torch.abs(output_orig - output_new)).item()\n", + " \n", + " if not torch.allclose(output_orig, output_new, atol=tolerance, rtol=tolerance): # fail\n", + " print(f\"Failed trial {trial}: Output mismatch. max_difference: {max_diff:.6f}, avg_difference: {avg_diff:.6f}\")\n", + " else: # pass\n", + " pass_count += 1\n", + " print(f\"[PASS] trial {trial}: ModelNew matches Model! 
+  {
+   "cell_type": "markdown",
+   "id": "32dd5a83-c353-4171-968e-efa24b449302",
+   "metadata": {},
+   "source": [
+    "## Print speedup results\n",
+    "We can finally report the speedups:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3885a024-fef1-4638-9ff6-2fba9488cbf5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(f\"speedup custom model vs. PyTorch eager model = {(orig_runtime / custom_runtime):.3g}x\")\n",
+    "print(f\"speedup torch.compiled model vs. PyTorch eager model = {(orig_runtime / compiled_runtime):.3g}x\")\n",
+    "print(f\"speedup custom model vs. torch.compiled model = {(compiled_runtime / custom_runtime):.3g}x\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4538bea0",
+   "metadata": {},
+   "source": [
+    "### Support"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "27259760",
+   "metadata": {},
+   "source": [
+    "For any issues or comments, please write to shira.guskin@intel.com"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}