diff --git a/.gitignore b/.gitignore index abb125d6..3956bf85 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/requirements.txt b/requirements.txt index f394e54e..253e57da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ modal # DSLs nvidia-cutlass-dsl tilelang +triton # helper tqdm @@ -17,6 +18,7 @@ packaging pydra_config pytest ninja +cupy-cuda12x # Numerics einops diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 2e39e3be..b28a3be0 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -113,6 +113,7 @@ def __init__(self): self.num_perf_trials = 100 self.timeout = 180 # in seconds self.measure_performance = True + self.timing_method = "cuda_event" # Eval Flow setting # To speedup evaluation, you can start building the kernel on CPU on disk as cache @@ -173,6 +174,7 @@ def evaluate_single_sample_modal( num_correct_trials: int = 5, num_perf_trials: int = 100, measure_performance: bool = True, + timing_method: str = "cuda_event", verbose: bool = False, backend: str = "cuda", precision: str = "fp32", @@ -212,6 +214,7 @@ def evaluate_single_sample_modal( original_model_src=ref_arch_src, custom_model_src=kernel_src, measure_performance=measure_performance, + timing_method=timing_method, verbose=verbose, num_correct_trials=num_correct_trials, num_perf_trials=num_perf_trials, @@ -324,6 +327,7 @@ def evaluate_single_sample( original_model_src=ref_arch_src, custom_model_src=kernel_src, measure_performance=configs.measure_performance, + timing_method=configs.timing_method, verbose=configs.verbose, num_correct_trials=configs.num_correct_trials, num_perf_trials=configs.num_perf_trials, @@ -384,6 +388,7 @@ def evaluate_single_sample_modal_direct( num_correct_trials=configs.num_correct_trials, num_perf_trials=configs.num_perf_trials, measure_performance=configs.measure_performance, + timing_method=configs.timing_method, verbose=configs.verbose, ) return eval_result @@ -502,6 +507,7 @@ def batch_eval_modal( num_correct_trials=config.num_correct_trials, num_perf_trials=config.num_perf_trials, measure_performance=config.measure_performance, + timing_method=config.timing_method, verbose=config.verbose, backend=config.backend, precision=config.precision, diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index 2b2d5301..2e110932 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -73,6 +73,7 @@ def __init__(self): self.log_eval_result = False self.backend = "cuda" + self.timing_method = "cuda_event" # see timing.py # Prompt construction self.prompt_option = "one_shot" # choices: zero_shot, one_shot, few_shot @@ -267,6 +268,7 @@ def main(config: EvalConfig): custom_kernel, verbose=config.verbose, measure_performance=True, + timing_method=config.timing_method, num_correct_trials=5, num_perf_trials=100, backend=config.backend, diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index 7628e0bf..f41ba95f 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -14,7 +14,6 @@ from datasets import load_dataset #from src.dataset import construct_kernelbench_dataset -from src.eval import eval_kernel_against_ref from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, 
create_inference_server_from_presets @@ -75,6 +74,7 @@ def __init__(self): self.log_eval_result = False self.backend = "cuda" + self.timing_method = "cuda_event" # see timing.py # Prompt generation settings self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot self.include_hardware_info = False @@ -110,7 +110,7 @@ def __repr__(self): class EvalFunc: @modal.method() - def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend, precision): + def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arch, backend, precision, timing_method): # 3. Evaluate Kernel # NOTE: no need to wrap around process here as only a single sample # see batch eval for examples of process isolation @@ -121,6 +121,7 @@ def eval_single_sample_modal(self, ref_arch_src, custom_kernel, verbose, gpu_arc modal_set_gpu_arch(gpu_arch) return eval_kernel_against_ref( ref_arch_src, custom_kernel, verbose=verbose, measure_performance=True, + timing_method=timing_method, num_correct_trials=5, num_perf_trials=100, backend=backend, precision=get_torch_dtype_from_string(precision) ) @@ -274,7 +275,7 @@ def main(config: EvalConfig): with app.run(): kernel_exec_result = EvalFunc.with_options(gpu=config.gpu)().eval_single_sample_modal.remote( - ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend, config.precision + ref_arch_src, custom_kernel, config.verbose, gpu_arch_mapping[config.gpu], config.backend, config.precision, config.timing_method ) print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}") diff --git a/scripts/generate_baseline_time.py b/scripts/generate_baseline_time.py index 5a68ea08..0a1f608b 100644 --- a/scripts/generate_baseline_time.py +++ b/scripts/generate_baseline_time.py @@ -2,11 +2,13 @@ import numpy as np from src.eval import ( load_original_model_and_inputs, - time_execution_with_cuda_event, - get_timing_stats, set_seed, fetch_ref_arch_from_problem_id, ) +from src.timing import ( + get_timing_function, + get_timing_stats, +) from src.dataset import construct_problem_dataset_from_problem_dir from src.utils import read_file import os @@ -81,6 +83,7 @@ def measure_program_time( torch_compile_options: str="default", device: torch.device="cuda:0", verbose: bool = False, + timing_method: str = "cuda_event", ) -> dict: """ Measure the time of a KernelBench reference architecture @@ -116,8 +119,11 @@ def measure_program_time( model = model.cuda(device=device) torch.cuda.synchronize(device=device) - elapsed_times = time_execution_with_cuda_event( - model, *inputs, num_trials=num_trials, verbose=verbose, device=device + + # run chosen timing function + timing_fn = get_timing_function(timing_method) + elapsed_times = timing_fn( + model, inputs, num_trials=num_trials, verbose=verbose, device=device ) runtime_stats = get_timing_stats(elapsed_times, device=device) diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 316b96ee..e0492938 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -57,6 +57,8 @@ Usage: 1. PyTorch reference is a local file (local eval) python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=local +python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/few_shot/model_ex_tiled_matmul.py kernel_src_path=src/prompts/few_shot/model_new_ex_tiled_matmul.py eval_mode=local + 2. 
PyTorch reference is a kernelbench problem (local eval) python3 scripts/run_and_check.py ref_origin=kernelbench level= problem_id= kernel_src_path= eval_mode=local @@ -101,6 +103,7 @@ def __init__(self): # verbose logging self.verbose = False self.measure_performance = True + self.timing_method = "cuda_event" # see timing.py self.build_dir_prefix = "" # if you want to specify a custom build directory self.clear_cache = False # TODO @@ -128,18 +131,23 @@ def evaluate_single_sample_src(ref_arch_src: str, kernel_src: str, configs: dict num_perf_trials = configs["num_perf_trials"] verbose = configs["verbose"] measure_performance = configs["measure_performance"] + timing_method = configs["timing_method"] + backend = configs["backend"] + precision = kernel_eval.get_torch_dtype_from_string(configs["precision"]) + try: eval_result = kernel_eval.eval_kernel_against_ref( original_model_src=ref_arch_src, custom_model_src=kernel_src, measure_performance=measure_performance, + timing_method=timing_method, verbose=verbose, num_correct_trials=num_correct_trials, num_perf_trials=num_perf_trials, build_dir=build_dir, device=device, - backend=configs["backend"], - precision=kernel_eval.get_torch_dtype_from_string(configs["precision"]) + backend=backend, + precision=precision ) return eval_result except Exception as e: @@ -180,17 +188,21 @@ def evaluate_single_sample_src_modal(self, ref_arch_src: str, kernel_src: str, c num_perf_trials = configs["num_perf_trials"] verbose = configs["verbose"] measure_performance = configs["measure_performance"] + timing_method = configs["timing_method"] + backend = configs["backend"] + precision = kernel_eval.get_torch_dtype_from_string(configs["precision"]) eval_result = eval_kernel_against_ref( original_model_src=ref_arch_src, custom_model_src=kernel_src, measure_performance=measure_performance, + timing_method=timing_method, verbose=verbose, num_correct_trials=num_correct_trials, num_perf_trials=num_perf_trials, device=device, - backend=configs["backend"], - precision=get_torch_dtype_from_string(configs["precision"]) + backend=backend, + precision=precision ) return eval_result diff --git a/src/eval.py b/src/eval.py index 4a072c89..5f1fe8d8 100644 --- a/src/eval.py +++ b/src/eval.py @@ -21,7 +21,7 @@ import torch.nn as nn from pydantic import BaseModel -from . import utils +from . 
import utils
+from . import utils, timing
 
 REPO_TOP_PATH = os.path.abspath(
     os.path.join(
@@ -393,8 +393,9 @@ def eval_kernel_against_ref(
     seed_num: int = 42,
     num_correct_trials: int = 1,
     num_perf_trials: int = 10,
-    verbose: bool = False,
     measure_performance: bool = False,
+    timing_method: str = "cuda_event",  # see timing.py
+    verbose: bool = False,
     build_dir: os.PathLike = None,
     device: Union[torch.device, int] = (
         torch.cuda.current_device() if torch.cuda.is_available() else None
     ),
@@ -405,11 +406,15 @@
     """
     Evaluate the custom kernel against the original model
 
+    NOTE: we are thinking about refactoring this to be more modular,
+    and adding more checks as our other ongoing PRs land
+
     num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass
     num_perf_trials: run the evalutation many times to take the average
     device: GPU (cuda) device to run the evalutation on
     backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute'
     precision: torch.dtype for computation (note: tilelang only supports fp16)
+    timing_method: str, method used to time the kernel, see timing.py for more details
     """
     # TODO: check device is busy
     assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval"
@@ -578,14 +583,16 @@
             model_new = custom_model.to(device=device, dtype=precision)
             torch.cuda.synchronize(device=device)
 
-            elapsed_times = time_execution_with_cuda_event(
+            # support multiple timing backends
+            timing_fn = timing.get_timing_function(timing_method)
+            elapsed_times = timing_fn(
                 model_new,
-                *inputs,
+                inputs,
                 num_trials=num_perf_trials,
                 verbose=verbose,
                 device=device,
             )
-            runtime_stats = get_timing_stats(elapsed_times, device=device)
+            runtime_stats = timing.get_timing_stats(elapsed_times, device=device)
 
             if verbose:
                 print(f"[Eval] Performance Stats: {runtime_stats}")
@@ -625,64 +632,6 @@ def register_and_format_exception(
     )
     return metadata
 
 
-def time_execution_with_cuda_event(
-    kernel_fn: callable,
-    *args,
-    num_warmup: int = 3,
-    num_trials: int = 10,
-    verbose: bool = True,
-    device: torch.device = None,
-) -> list[float]:
-    """
-    Time a CUDA kernel function over multiple trials using torch.cuda.Event
-
-    Args:
-        kernel_fn: Function to time
-        *args: Arguments to pass to kernel_fn
-        num_trials: Number of timing trials to run
-        verbose: Whether to print per-trial timing info
-        device: CUDA device to use, if None, use current device
-
-    Returns:
-        List of elapsed times in milliseconds
-    """
-    if device is None:
-        if verbose:
-            print(f"Using current device: {torch.cuda.current_device()}")
-        device = torch.cuda.current_device()
-
-    # Warm ups
-    for _ in range(num_warmup):
-        kernel_fn(*args)
-    torch.cuda.synchronize(device=device)
-
-    print(
-        f"[Profiling] Using device: {device} {torch.cuda.get_device_name(device)}, warm up {num_warmup}, trials {num_trials}"
-    )
-    elapsed_times = []
-
-    # Actual trials
-    for trial in range(num_trials):
-        # create event marker default is not interprocess
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-
-        start_event.record()
-        kernel_fn(*args)
-        end_event.record()
-
-        # Synchronize to ensure the events have completed
-        torch.cuda.synchronize(device=device)
-
-        # Calculate the elapsed time in milliseconds
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        if verbose:
-            print(f"Trial {trial + 1}: {elapsed_time_ms:.3g} ms")
-        elapsed_times.append(elapsed_time_ms)
-
-    return elapsed_times
-
-
 def run_and_check_correctness(
     original_model_instance: nn.Module,
     new_model_instance: nn.Module,
@@ -865,55 +814,6 @@ def convert_to_serializable(obj):
     return converted_metadata
 
 
-################################################################################
-# Performance Eval
-################################################################################
-
-
-def fetch_baseline_time(
-    level_name: str, problem_id: int, dataset: list[str], baseline_time_filepath: str
-) -> dict:
-    """
-    Fetch the baseline time from the time
-    """
-    if not os.path.exists(baseline_time_filepath):
-        raise FileNotFoundError(
-            f"Baseline time file not found at {baseline_time_filepath}"
-        )
-
-    with open(baseline_time_filepath, "r") as f:
-        baseline_json = json.load(f)
-
-    problem_name = dataset[problem_id].split("/")[-1]
-    baseline_time = baseline_json[level_name].get(problem_name, None)
-    return baseline_time
-
-
-def get_timing_stats(elapsed_times: list[float], device: torch.device = None) -> dict:
-    """Get timing statistics from a list of elapsed times.
-
-    Args:
-        elapsed_times: List of elapsed times in milliseconds
-        device: CUDA device, record device info
-    Returns:
-        Dict containing mean, std, min, max and num_trials
-        all timing are in ms
-    """
-
-    stats = {
-        "mean": float(f"{np.mean(elapsed_times):.3g}"),
-        "std": float(f"{np.std(elapsed_times):.3g}"),
-        "min": float(f"{np.min(elapsed_times):.3g}"),
-        "max": float(f"{np.max(elapsed_times):.3g}"),
-        "num_trials": len(elapsed_times),
-    }
-
-    if device:
-        stats["hardware"] = torch.cuda.get_device_name(device=device)
-        stats["device"] = str(device)  # for debugging
-
-    return stats
-
 # if __name__ == "__main__":
 #     fetch_kernel_from_database("kernelbench_prompt_v2_level_2", 1, 1, "http://localhost:9091")
diff --git a/src/timing.py b/src/timing.py
new file mode 100644
index 00000000..8a36522b
--- /dev/null
+++ b/src/timing.py
@@ -0,0 +1,443 @@
+import torch
+import json
+import numpy as np
+import time
+from typing import Any
+import os
+
+
+# we leverage triton's testing functionality for some timing methods
+from triton import runtime as triton_runtime
+from triton import testing as triton_testing
+
+################################################################################
+# timing.py
+# Various timing methods and utilities for performance evaluation
+# please make a PR if you have suggestions!
+
+# Try them out at src/unit_tests/test_eval_timing.py
+################################################################################
+
+def clear_l2_cache(device: torch.device | str = "cuda"):
+    """
+    Clear the L2 cache by thrashing it with a large tensor
+    Acknowledgement: GPU MODE reference-kernels repo:
+    https://github.com/gpu-mode/reference-kernels/commit/7c15075a39286e88939d99d3f3a60be88b8e6223#diff-3a30a71cbf8db2badd224f4d92f9a2546925a5b522632a31d353526b7a5f3338R158-R163
+    """
+    # don't reserve space for persisting lines
+    # cp.cuda.runtime.cudaDeviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
+
+    # Thrash the L2 cache by creating a large dummy tensor, effectively flushing the cache
+    # 32 * 1024 * 1024 * 8B = 256MB
+    # NOTE: we can make this more adaptive based on device
+    # L2 cache sizes: A100=40MB, H100=50MB, H200=90MB, RTX4090=72MB, L40S=48MB, Blackwell≈192MB → overwrite >200MB to fully thrash L2
+    dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device=device)
+    # write to the tensor with an in-place fill
+    dummy.fill_(42)
+    del dummy
+
+def clear_l2_cache_triton(cache=None, device: str = "cuda"):
+    """
+    Thrash the cache by making a large dummy tensor, using the Triton runtime's functionality
+    """
+    with torch.cuda.device(device):
+        cache = triton_runtime.driver.active.get_empty_cache_for_benchmark()
+        # this effectively thrashes the L2 cache under the hood too
+        triton_runtime.driver.active.clear_cache(cache)
+
+
+def get_timing_function(
+    method: str = "cuda_event",  # by default
+) -> callable:
+    """
+    Get timing function by method name.
+
+    Available methods:
+    - "cuda_event": torch.cuda.Event timing (default, explicit trial control)
+    - "do_bench": Use Triton's do_bench (adaptive trial count based on time budget)
+    - "do_bench_impl": Mirrors Triton's do_bench implementation (explicit control)
+    - "host_time": Host-side wall-clock timing (might include overhead)
+
+    Args:
+        method: Name of timing method to use
+
+    Returns:
+        Timing function with signature (kernel_fn, args, num_warmup, num_trials,
+        discard_first, verbose, device) -> list[float]
+    """
+    print(
+        f"[Profiling] Using timing method: {method}"
+    )
+    # NOTE: here are all the timing methods we support for now
+    match method:
+        case "cuda_event":
+            return time_execution_with_cuda_event
+        case "do_bench":
+            # caveat: just using do_bench as it is,
+            # we do not have precise control over the number of trials
+            return time_execution_with_do_bench_interface
+        case "do_bench_impl":
+            # do_bench-equivalent implementation for transparency and control
+            return time_execution_with_do_bench_impl
+        case "host_time":
+            return time_execution_with_host_time
+        # we might add other methods in the future
+        case _:
+            raise ValueError(f"Unsupported timing method: {method}")
+
+"""
+Kernel Timing Functions
+NOTE: we have a WIP blogpost on this topic covering the various timing approaches
+"""
+
+
+def time_execution_with_cuda_event(
+    kernel_fn: callable,
+    args: list[Any],
+    num_warmup: int = 3,
+    num_trials: int = 10,
+    discard_first: int = 1,  # set to 0 to disable
+    verbose: bool = True,
+    device: torch.device = None,
+) -> list[float]:
+    """
+    Time a CUDA kernel function over multiple trials using torch.cuda.Event
+    The first version of KernelBench used this for evaluation.
+    We care about cold-cache performance here.
+
+    Note: this version does not guard against adversarial CUDA streams yet.
+    It assumes computation is done on the current stream of the current device.
+    Stay tuned for future PRs.
+
+    Args:
+        kernel_fn: Function to time
+        args: Arguments to pass to kernel_fn
+        num_warmup: Number of warmup iterations before timing
+        num_trials: Number of timing trials to run
+        discard_first: Number of first trials to discard; for consistency with host_time, set to 0 to disable
+        verbose: Whether to print per-trial timing info
+        device: CUDA device to use, defaults to current device
+
+    Returns:
+        List of elapsed times in milliseconds (length = num_trials)
+    """
+    if device is None:
+        if verbose:
+            print(f"Using current device: {torch.cuda.current_device()}")
+        device = torch.cuda.current_device()
+
+    with torch.cuda.device(device):
+
+        # Warm ups
+        for _ in range(num_warmup):
+            kernel_fn(*args)
+        torch.cuda.synchronize(device=device)
+
+        # note: this only releases PyTorch's CUDA caching allocator, it does not necessarily clear the device's L2 cache
+        torch.cuda.empty_cache()
+
+        print(f"[Profiling] Using device: {device} {torch.cuda.get_device_name(device)}, warm up {num_warmup}, trials {num_trials}"
+        )
+
+        elapsed_times: list[float] = []  # in ms
+
+        # Timing trials
+        for trial in range(num_trials + discard_first):
+            torch.cuda.synchronize(device=device)  # block on all streams
+
+            # create event markers (default is not interprocess)
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+
+            clear_l2_cache(device=device)  # measuring cold-cache performance
+
+            # note: CUDA events are recorded on the current stream
+            start_event.record()
+            _ = kernel_fn(*args)
+            end_event.record()
+
+            # waits for all streams on that device,
+            # though it is important to note the events only record time on the current stream
+            # TODO: find ways to catch hacks that launch work on additional streams
+            torch.cuda.synchronize(device=device)
+
+            # Calculate the elapsed time in milliseconds
+            elapsed_time_ms = start_event.elapsed_time(end_event)
+
+            if trial >= discard_first:
+                if verbose:
+                    logical_idx = trial - discard_first + 1
+                    print(f"Trial {logical_idx}: {elapsed_time_ms:.3g} ms")
+                elapsed_times.append(elapsed_time_ms)
+
+
+    return elapsed_times
+
+
+def time_execution_with_do_bench_interface(
+    kernel_fn: callable,
+    args: list[Any],
+    # Not used, as triton do_bench handles adaptive trials
+    num_warmup: int = 3,
+    num_trials: int = 10,
+    discard_first: int = 1,  # not used here
+    verbose: bool = True,
+    device: torch.device | None = None) -> list[float]:
+    """
+    Wrapper around Triton's do_bench for kernel timing.
+
+    Uses Triton's adaptive benchmarking with fixed time budgets (warmup=25ms, rep=100ms) [Triton do_bench defaults].
+    The number of trials is determined automatically based on kernel runtime.
+
+    Note: num_warmup, num_trials, discard_first are ignored - included only for
+    API compatibility with other timing functions.
+
+    Args:
+        kernel_fn: Function to time
+        args: Arguments to pass to kernel_fn
+        num_warmup: (ignored) Triton controls warmup
+        num_trials: (ignored) Triton controls trial count
+        discard_first: (ignored) Not used
+        verbose: Whether to print timing info
+        device: CUDA device to use
+
+    Returns:
+        List of elapsed times in milliseconds
+
+    See: https://triton-lang.org/main/python-api/generated/triton.testing.do_bench.html
+    """
+    if device is None:
+        if verbose:
+            print(f"Using current device: {torch.cuda.current_device()}")
+        device = torch.cuda.current_device()
+
+
+    do_bench_fn = lambda: kernel_fn(*args)  # wrap the function with its arguments
+    with torch.cuda.device(device):
+        return triton_testing.do_bench(fn=do_bench_fn,
+                                       warmup=25,
+                                       rep=100,
+                                       grad_to_none=None,
+                                       quantiles=None,
+                                       return_mode="all")
+
+
+def time_execution_with_do_bench_impl(
+    kernel_fn: callable,
+    args: list[Any],
+    num_warmup: int = 3,
+    num_trials: int = 10,
+    discard_first: int = 1,  # not used here
+    verbose: bool = True,
+    device: torch.device | None = None) -> list[float]:
+    """
+    A copy of Triton's do_bench implementation, modified for explicit trial control.
+    See Triton's implementation for more details:
+    https://github.com/triton-lang/triton/blob/9073370d5979218d1afa44ec895bbd80e7419a8c/python/triton/testing.py#L127
+
+    Note: we duplicate Triton's implementation and modify / comment out parts so that
+    num_warmup and num_trials explicitly follow what the user defines here, instead of
+    do_bench's approach of deriving the warmup and profiling counts from total warmup
+    and repetition time budgets.
+
+    We commented out unused parts and kept only what's needed for KernelBench timing eval
+    Args:
+        kernel_fn: Function to time
+        args: Arguments to pass to kernel_fn
+        num_warmup: Number of warmup iterations
+        num_trials: Number of timing trials
+        discard_first: (not used) Trials to discard
+        verbose: Whether to print timing info
+        device: CUDA device to use, defaults to current device
+    Returns:
+        List of elapsed times in milliseconds (length = num_trials)
+    """
+
+    device = device if device is not None else torch.cuda.current_device()
+    if verbose:
+        print(f"Using do_bench to evaluate kernel on {device}")
+
+
+    # added to constrain work to this device
+    with torch.cuda.device(device):
+
+        # specify device interface (supports both NVIDIA and AMD)
+        # under the hood, di is torch.cuda (AMD uses a cuda-compatible interface)
+        di = triton_runtime.driver.active.get_device_interface()
+
+        kernel_fn(*args)
+        di.synchronize(device=device)
+
+        # buffer used to clear the L2 cache
+        cache = triton_runtime.driver.active.get_empty_cache_for_benchmark()
+
+        # do_bench estimates the runtime of the function to size its warmup/repeat counts
+        # We do not use that here, since the warmup and repeat counts are set by the user
+        # start_event = di.Event(enable_timing=True)
+        # end_event = di.Event(enable_timing=True)
+        # start_event.record()
+        # for _ in range(5):
+        #     triton_runtime.driver.active.clear_cache(cache)
+        #     kernel_fn(*args)
+        # end_event.record()
+        # di.synchronize()
+        # estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        # compute number of warmup and repeat
+        # Change
+        # n_warmup = max(1, int(warmup / estimate_ms))
+        # n_repeat = max(1, int(rep / estimate_ms))
+        # n_warmup = warmup
+        # n_repeat = rep
+        # end of change
+        start_event = [di.Event(enable_timing=True) for i in range(num_trials)]
+        end_event = [di.Event(enable_timing=True) for i in range(num_trials)]
+        # Warm-up
+        for _ in range(num_warmup):
+            kernel_fn(*args)
+        di.synchronize(device=device)
+
+        # Benchmark
+        for i in range(num_trials):
+            # All KernelBench functions are forward passes, so we don't need to reset gradients
+            # we don't want `fn` to accumulate gradient values
+            # if it contains a backward pass. So we clear the
+            # provided gradients
+            # if grad_to_none is not None:
+            #     for x in grad_to_none:
+            #         x.grad = None
+
+            # we clear the L2 cache before each run
+            triton_runtime.driver.active.clear_cache(cache)
+            # record time of `fn`
+            start_event[i].record()
+            kernel_fn(*args)
+            end_event[i].record()
+        # Record clocks
+        di.synchronize(device=device)
+        times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
+
+    if verbose: print('Done with do_bench evaluation')
+    return times
+
+
+def time_execution_with_host_time(
+    kernel_fn: callable,
+    args: list[Any],
+    num_warmup: int = 3,
+    num_trials: int = 10,
+    discard_first: int = 1,  # to reduce the impact of initialization overhead
+    verbose: bool = True,
+    device: torch.device | None = None,
+) -> list[float]:
+    """
+    Time a CUDA kernel function over multiple trials using host (CPU) side timing
+
+    This measures host-side wall-clock time, i.e. the end-to-end latency observed by the host.
+    Note that this can include Python overhead, CUDA launch/runtime costs, synchronization,
+    all GPU work across all streams, and host OS overhead.
+    Hence results might be longer than device-side (CUDA event) timings
+
+    Args:
+        kernel_fn: Function to time
+        args: Arguments to pass to kernel_fn
+        num_warmup: Number of warmup iterations before timing
+        num_trials: Number of timing trials to run
+        discard_first: Number of first few trials to discard (due to some initialization overhead)
+        verbose: Whether to print per-trial timing info
+        device: CUDA device to use, if None, use current device
+
+    Returns:
+        List of elapsed times in milliseconds
+    """
+    if device is None:
+        if verbose:
+            print(f"Using current device: {torch.cuda.current_device()}")
+        device = torch.cuda.current_device()
+
+    # Warm ups
+    for _ in range(num_warmup):
+        kernel_fn(*args)
+    torch.cuda.synchronize(device=device)
+
+    print(f"[Profiling] Using device: {device} {torch.cuda.get_device_name(device)}, warm up {num_warmup}, trials {num_trials}")
+    elapsed_times = []
+
+    # clear PyTorch allocator cache
+    torch.cuda.empty_cache()
+
+    # Actual trials
+    for trial in range(num_trials + discard_first):
+        # block all streams on device
+        torch.cuda.synchronize(device=device)
+
+        # focus on cold-cache performance
+        clear_l2_cache(device=device)
+
+        # CPU-side wall clock time using perf_counter (high-resolution timer)
+        start_time = time.perf_counter()
+        kernel_fn(*args)
+        torch.cuda.synchronize(device=device)  # wait for all streams to finish
+        # this blocks the CPU until all GPU work on the device is done,
+        # meaning all kernels on all streams
+        end_time = time.perf_counter()
+
+        # Calculate the elapsed time in milliseconds
+        elapsed_time_ms = (end_time - start_time) * 1000
+        if trial >= discard_first:
+            if verbose:
+                logical_idx = trial - discard_first + 1
+                print(f"Trial {logical_idx}: {elapsed_time_ms:.3g} ms")
+            elapsed_times.append(elapsed_time_ms)
+
+    return elapsed_times
+
+########################################################
+# Timing stats
+# tools to help compute speedup and other timing statistics
+########################################################
+def fetch_baseline_time(
+    level_name: str, problem_id: int, dataset: list[str], baseline_time_filepath: str
+) -> dict:
+    """
+    Fetch the baseline time from the baseline timing file
+
+    Note: it might be better to just run the reference with torch eager (and sometimes torch.compile) instead.
+    We will add this as functionality in the eval revamp
+    """
+    if not os.path.exists(baseline_time_filepath):
+        raise FileNotFoundError(
+            f"Baseline time file not found at {baseline_time_filepath}"
+        )
+
+    with open(baseline_time_filepath, "r") as f:
+        baseline_json = json.load(f)
+
+    # TODO: replace with the new Dataset object that Omar will merge in
+    problem_name = dataset[problem_id].split("/")[-1]
+    baseline_time = baseline_json[level_name].get(problem_name, None)
+    return baseline_time
+
+
+def get_timing_stats(elapsed_times: list[float], device: torch.device = None) -> dict:
+    """Get timing statistics from a list of elapsed times.
+
+    Args:
+        elapsed_times: List of elapsed times in milliseconds
+        device: CUDA device, record device info
+    Returns:
+        Dict containing mean, std, min, max and num_trials
+        all timings are in ms
+    """
+
+    stats = {
+        "mean": float(f"{np.mean(elapsed_times):.3g}"),
+        "std": float(f"{np.std(elapsed_times):.3g}"),
+        "min": float(f"{np.min(elapsed_times):.3g}"),
+        "max": float(f"{np.max(elapsed_times):.3g}"),
+        "num_trials": len(elapsed_times),
+    }
+
+    if device:
+        stats["hardware"] = torch.cuda.get_device_name(device=device)
+        stats["device"] = str(device)  # for debugging
+
+    return stats
diff --git a/src/unit_tests/test_eval_timing.py b/src/unit_tests/test_eval_timing.py
new file mode 100644
index 00000000..84921a37
--- /dev/null
+++ b/src/unit_tests/test_eval_timing.py
@@ -0,0 +1,112 @@
+import os
+import sys
+import torch
+import pytest
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import timing
+
+"""
+Test Timing
+We want to systematically study different timing methodologies.
+"""
+REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+# use examples in the few_shot directory
+EXAMPLES_PATH = os.path.join(REPO_PATH, "src", "prompts", "few_shot")
+
+# Configure your test cases here
+TEST_REF_FILE = "model_ex_tiled_matmul.py"
+TEST_KERNEL_FILE = "model_new_ex_tiled_matmul.py"
+
+assert os.path.exists(os.path.join(EXAMPLES_PATH, TEST_REF_FILE)), f"Reference file {TEST_REF_FILE} does not exist in {EXAMPLES_PATH}"
+assert os.path.exists(os.path.join(EXAMPLES_PATH, TEST_KERNEL_FILE)), f"Kernel file {TEST_KERNEL_FILE} does not exist in {EXAMPLES_PATH}"
+
+
+def _run_timing_smoke_test_matmul(timing_func_name: str, device: str = "cuda"):
+    """
+    Scaffold function for timing smoke tests.
+    Smoke test using a 2048x2048x2048 matmul with 5 warmup iterations and 100 trials.
+
+    Args:
+        timing_func_name: Name of the timing method to test (passed to timing.get_timing_function)
+        device: CUDA device to run the smoke test on
+    """
+    # Skip if CUDA is not available
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available, skipping timing tests")
+
+    # Create simple test matrices
+    M = 2048
+    N = 2048
+    K = 2048
+    a = torch.randn(M, K, device=device)
+    b = torch.randn(K, N, device=device)
+
+    num_warmup = 5
+    num_trials = 100
+
+    # Define the kernel function to time
+    def matmul_kernel(a, b):
+        return torch.matmul(a, b)
+
+    timing_func = timing.get_timing_function(timing_func_name)
+    elapsed_times = timing_func(
+        matmul_kernel,
+        args=[a, b],
+        num_warmup=num_warmup,
+        num_trials=num_trials,
+        verbose=False,
+        device=device
+    )
+
+    # Validate results
+    assert isinstance(elapsed_times, list), "Expected list of elapsed times"
+
+    # disabled this check as do_bench does not use num_trials
+    # assert len(elapsed_times) == num_trials, f"Expected {num_trials} timing results, got {len(elapsed_times)}"
+    assert all(isinstance(t, float) for t in elapsed_times), "All timing results should be floats"
+    assert all(t > 0 for t in elapsed_times), "All timing results should be positive"
+    # DEBUG print times
+    # print(f"smoke test matmul elapsed times with {timing_func_name} (in ms): {elapsed_times}")
+
+    stats = timing.get_timing_stats(elapsed_times, device=device)
+    print("Timing stats")
+    print(stats)
+
+
+# test all currently available timing methods
+def run_all_timing_tests(device="cuda"):
+    timing_methods = ["cuda_event", "host_time", "do_bench", "do_bench_impl"]
+    # timing_methods = ["cuda_event", "do_bench_impl"]
+    for timing_method in timing_methods:
+        _run_timing_smoke_test_matmul(timing_method, device=device)
+
+
+def test_all_timing_methods():
+    run_all_timing_tests(device="cuda")
+
+
+if __name__ == "__main__":
+    # run manually, e.g. on a specific GPU of a multi-GPU box
+    test_device = torch.device("cuda:5")
+    run_all_timing_tests(test_device)
+
+
+def test_do_bench_simple_smoke():
+    """
+    Smoke test for do_bench itself on a simple CUDA operation.
+    Just checks it runs and returns timings.
+    """
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available, skipping do_bench smoke test")
+
+    from triton.testing import do_bench
+
+    x = torch.randn(1024, device="cuda")
+
+    def fn():
+        # simple GPU op; do_bench will sync/timestamp internally
+        return (x * 2).sum()
+
+    # warmup and rep are time budgets in ms; do_bench decides the trial count itself
+    times = do_bench(fn, warmup=2, rep=5, return_mode="all")
+    assert isinstance(times, list)
+    assert len(times) >= 1
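Reviewer note (not part of the diff): a minimal sketch of how the new timing module can be exercised directly, assuming a CUDA-capable machine with the repo root on PYTHONPATH. Only get_timing_function and get_timing_stats come from this PR; the matmul kernel, sizes, and trial counts are placeholders.

# Sketch: timing an arbitrary callable with src/timing.py (illustrative, not from the PR)
import torch
from src import timing

device = torch.device("cuda:0")
a = torch.randn(4096, 4096, device=device)
b = torch.randn(4096, 4096, device=device)

def kernel_fn(x, y):
    return torch.matmul(x, y)

# any of: "cuda_event" (default), "do_bench", "do_bench_impl", "host_time"
timing_fn = timing.get_timing_function("cuda_event")
elapsed_ms = timing_fn(kernel_fn, [a, b], num_warmup=5, num_trials=50, verbose=False, device=device)
print(timing.get_timing_stats(elapsed_ms, device=device))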
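Reviewer note (not part of the diff): a sketch of how the new timing_method argument flows through eval_kernel_against_ref, mirroring the calls the updated scripts now make. The few_shot example files are the ones this PR already references; the rest of the call (trial counts, backend) is illustrative only and requires a CUDA device.

# Sketch: end-to-end eval with the new timing_method knob (illustrative, not from the PR)
from src.eval import eval_kernel_against_ref
from src.utils import read_file

ref_arch_src = read_file("src/prompts/few_shot/model_ex_tiled_matmul.py")
custom_kernel_src = read_file("src/prompts/few_shot/model_new_ex_tiled_matmul.py")

result = eval_kernel_against_ref(
    ref_arch_src,
    custom_kernel_src,
    measure_performance=True,
    timing_method="do_bench",  # or "cuda_event" (default), "do_bench_impl", "host_time"
    num_correct_trials=5,
    num_perf_trials=100,
    verbose=False,
    backend="cuda",
)
print(result)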
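Reviewer note (not part of the diff): the clear_l2_cache NOTE about making the flush size adaptive could look roughly like the sketch below. The L2_cache_size attribute of torch.cuda.get_device_properties, the 64MB fallback, and the 4x safety margin are assumptions for illustration, not something this PR specifies.

# Sketch: sizing the L2-thrash buffer from the device's reported L2 size (illustrative)
import torch

def clear_l2_cache_adaptive(device: torch.device | str = "cuda") -> None:
    props = torch.cuda.get_device_properties(device)
    # fall back to 64MB if the attribute is unavailable on this PyTorch build
    l2_bytes = getattr(props, "L2_cache_size", 64 * 1024 * 1024)
    num_elems = (4 * l2_bytes) // 8  # overwrite ~4x the L2 size with int64 elements
    dummy = torch.empty(num_elems, dtype=torch.int64, device=device)
    dummy.fill_(42)  # actually touch the memory so the writes displace cached lines
    del dummy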