diff --git a/README.md b/README.md index 7343e73b..61f7e8d8 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,10 @@ We have transitioned to using `pyproject.toml` and `uv` for dependency managemen # Install base dependencies (works without a local GPU) uv sync +# Install ROCm-enabled PyTorch (pick the correct ROCm version for your system): + +uv pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm7.1 + # Install with GPU dependencies (for local GPU evaluation) uv sync --extra gpu @@ -115,9 +119,9 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface ``` **What you might need to modify** -* **`gpu_arch`** - Depend on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware. +* **`gpu_arch`** - Depending on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware. Currently supported: `["gfx1100"]` (W7900D), `["gfx1201"]` (R9700). * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`. -* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`. +* **`backend`** - We also support other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`. Note: ROCm GPUs currently use `backend=triton`. Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, you need to git clone the ThunderKittens repo and set the following environment variable to point to your local ThunderKittens directory, `export THUNDERKITTENS_ROOT=`, and all ThunderKitten programs as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enable the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now.
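Not part of the patch, but a quick way to confirm that the ROCm-enabled PyTorch wheel described above is actually in use: a minimal sanity-check sketch using only `torch` attributes that ROCm builds expose and that the `eval.py` changes below rely on (`torch.version.hip`, `gcnArchName` is read defensively since CUDA-only builds may not define it).

```python
# Minimal sanity check for a ROCm-enabled PyTorch install (run after `uv sync`).
import torch

print("torch:", torch.__version__)         # ROCm wheels report e.g. 2.x.y+rocmZ
print("HIP runtime:", torch.version.hip)   # None on CUDA-only builds

if torch.cuda.is_available():              # ROCm reuses the torch.cuda API via HIP
    props = torch.cuda.get_device_properties(0)
    print("device:", torch.cuda.get_device_name(0))
    # gcnArchName (e.g. "gfx1100" / "gfx1201") is what `gpu_arch` should match
    print("arch:", getattr(props, "gcnArchName", "unknown"))
else:
    print("No GPU visible to PyTorch")
```
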
diff --git a/pyproject.toml b/pyproject.toml index bed37150..4eb6ea4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,11 @@ dependencies = [ # Frameworks "torch==2.9.0", + "pytorch-triton-rocm>=3.4.0", "transformers", "datasets", "modal", + "ruff", # helper "tqdm", @@ -37,12 +39,9 @@ dependencies = [ [project.optional-dependencies] gpu = [ - # GPU-specific dependencies (requires CUDA) + # GPU-specific dependencies (ROCm / AMD Radeon) "triton", - "nvidia-cutlass-dsl", "tilelang", - "cupy-cuda12x", - "nsight-python", ] dev = [ "pytest", @@ -55,4 +54,14 @@ where = ["src"] include = ["kernelbench*"] [tool.setuptools.package-data] -kernelbench = ["prompts/**/*"] \ No newline at end of file +kernelbench = ["prompts/**/*"] + +[tool.uv.sources] +torch = [{ index = "pytorch-rocm" }] +torchvision = [{ index = "pytorch-rocm" }] +pytorch-triton-rocm = [{ index = "pytorch-rocm" }] + +[[tool.uv.index]] +name = "pytorch-rocm" +url = "https://download.pytorch.org/whl/rocm6.4" +explicit = true \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 07603a86..805c61be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,35 +1,38 @@ -# ARCHIVED: We are transitioning to pyproject.toml and uv-based project management -# However, we provide this as a backup for now - # Frameworks -# we use latest PyTorch stable release -torch==2.9.* -triton==3.5.* - +# torch==2.5.0 # we shall upgrade torch for blackwell when it is stable -transformers>=4.57.3 -datasets>=4.4.2 -modal>=1.3.0 +# AMD ROCm note: install ROCm-enabled torch from the PyTorch ROCm index. +# Current ROCm env: +# torch==2.8.0+rocm7.1.1.gitcba8b9d2 +# HIP==7.1.52802-26aae437f6 +# ROCm SMI (concise): +# Device IDs: 0x7551 x4 +transformers +datasets +modal # DSLs -nvidia-cutlass-dsl -tilelang +# nvidia-cutlass-dsl +# triton (required for AMD ROCm kernels) +# helion (optional, Helion DSL; install separately if needed) # helper -tqdm>=4.67.1 +tqdm packaging -pydra-config -ninja>=1.13.0 -cupy-cuda12x==13.6.0 -tomli>=2.3.0 -tabulate>=0.9.0 -nsight-python +pydra_config +dill>=0.3.7,<0.4 +pytest +ninja # Numerics -einops>=0.8.1 -python-dotenv>=1.2.1 -numpy==2.4.0 +einops +dotenv +numpy + +# to deprecate with litellm +google-generativeai +together +openai +anthropic +pydantic==2.12.4 -# use litellm for cloud providers and openai for local -openai>=2.14.0 -litellm[proxy]>=1.80.10 \ No newline at end of file diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py index dd79b2c0..170a4bcf 100644 --- a/src/kernelbench/eval.py +++ b/src/kernelbench/eval.py @@ -1,5 +1,9 @@ """ Helpers for Evaluations + +Supports both NVIDIA CUDA and AMD ROCm GPUs. +ROCm support is provided through PyTorch's HIP backend, which exposes +the same torch.cuda API for AMD GPUs. """ import hashlib @@ -13,7 +17,7 @@ import traceback from contextlib import redirect_stderr, redirect_stdout from io import StringIO -from typing import Union, Optional +from typing import Union, Optional, Literal import numpy as np import requests @@ -23,6 +27,103 @@ from . import timing, dataset + +################################################################################ +# GPU Detection and Compatibility +################################################################################ + +def is_rocm_available() -> bool: + """ + Check if ROCm (AMD GPU) is available. + ROCm uses PyTorch's HIP backend which exposes torch.cuda API. 
+ """ + if not torch.cuda.is_available(): + return False + # Check for HIP version (ROCm indicator) + return hasattr(torch.version, 'hip') and torch.version.hip is not None + + +def is_cuda_available() -> bool: + """ + Check if NVIDIA CUDA is available (not ROCm). + """ + if not torch.cuda.is_available(): + return False + return not is_rocm_available() + + +def get_gpu_vendor() -> Literal["nvidia", "amd", "unknown"]: + """ + Detect the GPU vendor (NVIDIA or AMD). + """ + if not torch.cuda.is_available(): + return "unknown" + if is_rocm_available(): + return "amd" + return "nvidia" + + +def get_gpu_info(device: torch.device = None) -> dict: + """ + Get GPU information including vendor, name, and memory. + + Returns: + dict with keys: vendor, name, memory_total_gb, compute_capability (NVIDIA only) + """ + if device is None: + device = torch.cuda.current_device() + + info = { + "vendor": get_gpu_vendor(), + "name": torch.cuda.get_device_name(device), + "memory_total_gb": torch.cuda.get_device_properties(device).total_memory / (1024**3), + } + + # Add compute capability for NVIDIA GPUs + if info["vendor"] == "nvidia": + props = torch.cuda.get_device_properties(device) + info["compute_capability"] = f"{props.major}.{props.minor}" + + # Add ROCm-specific info for AMD GPUs + if info["vendor"] == "amd": + info["hip_version"] = torch.version.hip + # Try to get architecture info + try: + props = torch.cuda.get_device_properties(device) + info["gcn_arch"] = getattr(props, 'gcnArchName', 'unknown') + except: + pass + + return info + + +def check_gpu_available(verbose: bool = False) -> bool: + """ + Check if any GPU (CUDA or ROCm) is available. + + Args: + verbose: If True, print GPU information + + Returns: + True if GPU is available, False otherwise + """ + if not torch.cuda.is_available(): + if verbose: + print("[GPU] No GPU available") + return False + + if verbose: + gpu_info = get_gpu_info() + vendor_name = "AMD ROCm" if gpu_info["vendor"] == "amd" else "NVIDIA CUDA" + print(f"[GPU] {vendor_name} available: {gpu_info['name']}") + print(f"[GPU] Memory: {gpu_info['memory_total_gb']:.1f} GB") + if gpu_info["vendor"] == "amd": + print(f"[GPU] HIP Version: {gpu_info.get('hip_version', 'unknown')}") + else: + print(f"[GPU] Compute Capability: {gpu_info.get('compute_capability', 'unknown')}") + + return True + REPO_TOP_PATH = os.path.abspath( os.path.join( os.path.dirname(__file__), @@ -63,9 +164,19 @@ def fetch_ref_arch_from_level_problem_id(level, problem_id, with_name=False): def set_seed(seed: int): + """ + Set random seed for reproducibility. + Works with both NVIDIA CUDA and AMD ROCm GPUs. + """ torch.manual_seed(seed) - # NOTE: this only sets on current cuda device - torch.cuda.manual_seed(seed) + # NOTE: this sets on current GPU device (CUDA or ROCm via HIP) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # for multi-GPU + # Set deterministic behavior + # NOTE: cudnn settings may not be fully supported on ROCm + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def get_torch_dtype_from_string(precision: str) -> torch.dtype: """ @@ -225,24 +336,39 @@ def graceful_eval_cleanup( tempfile: tempfile.NamedTemporaryFile = None, ): """ - Clean up env, gpu cache, and compiled CUDA extensions after evaluation - """ # delete ran-specific function definitions before next eval run + Clean up environment, GPU cache, and compiled extensions after evaluation. + Works with both NVIDIA CUDA and AMD ROCm GPUs. 
+ """ + # Clean up linecache entries + fake_filenames = [k for k in linecache.cache.keys() if k.startswith(("x faster than the reference ) -> KernelExecResult: """ - Evaluate the custom kernel against the original model + Evaluate the custom kernel against the original model. + + Supports both NVIDIA CUDA and AMD ROCm GPUs. NOTE: we are thinking about refactor this to be more modularized - and we can add more checks as our other ongiong PRs are working on + and we can add more checks as our other ongoing PRs are working on - num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass - num_perf_trials: run the evalutation many times to take the average - device: GPU (cuda) device to run the evalutation on - backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute' - precision: torch.dtype for computation (note: tilelang only supports fp16) - timing_method: str, method to time kernel, see timing.py for more details + Args: + original_model_src: Source code of the reference PyTorch model + custom_model_src: Source code of the optimized model with custom kernels + seed_num: Random seed for reproducibility + num_correct_trials: Number of trials with different random inputs; pass only if all pass + num_perf_trials: Run the evaluation many times to take the average + measure_performance: Whether to measure and compare performance + timing_method: Method to time kernel, see timing.py for more details + verbose: Enable verbose logging + build_dir: Directory for caching compiled kernels + device: GPU device to run evaluation on (CUDA or ROCm) + backend: One of 'cuda', 'triton', 'tilelang', or 'cute' + precision: torch.dtype for computation (note: tilelang only supports fp16) + check_for_excessive_speedup: Guard against potential reward hacking + excessive_speedup_threshold: Flag if kernel is more than this faster than reference + + Returns: + KernelExecResult with compilation status, correctness, and performance metrics ONGOING EFFORT to refactor and modularize this, and adding more tests for eval. 
""" - # TODO: check device is busy - assert torch.cuda.is_available(), "CUDA is not available, cannot run Eval" + # Check GPU availability (works for both CUDA and ROCm) + if not check_gpu_available(verbose=verbose): + raise RuntimeError("No GPU available (CUDA or ROCm), cannot run Eval") + + # Get GPU vendor info for metadata + gpu_vendor = get_gpu_vendor() + gpu_info = get_gpu_info(device if isinstance(device, int) else None) if backend.lower() == "tilelang": assert precision == torch.float16 or precision == torch.bfloat16, "TileLang only supports fp16 or bfloat16" @@ -439,35 +584,55 @@ def eval_kernel_against_ref( linewidth=80, # Maximum width before wrapping ) - # set CUDA device + # Set GPU device (works for both CUDA and ROCm via HIP) torch.cuda.set_device(device) - # Backends that use tempfile approach and need CUDA_VISIBLE_DEVICES - # TileLang, Triton, and CuTe all use tempfile for proper module loading - uses_tempfile = backend.lower() in ["triton", "tilelang", "cute"] + # Backends that use tempfile approach + # - triton: @triton.jit decorator requires file-based import + # - cute: CUTLASS requires file-based compilation + # - tilelang: JIT requires file-based import + backend_lower = backend.lower() + uses_tempfile = backend_lower in ["triton", "tilelang", "cute"] metadata = {} # for storing result metadata metadata["hardware"] = torch.cuda.get_device_name(device=device) - metadata["device"] = str(device) # for debugging + metadata["device"] = str(device) + metadata["gpu_vendor"] = gpu_vendor + metadata["backend"] = backend_lower + + # Add vendor-specific info + if gpu_vendor == "amd": + metadata["hip_version"] = gpu_info.get("hip_version", "unknown") + metadata["gcn_arch"] = gpu_info.get("gcn_arch", "unknown") + else: + metadata["compute_capability"] = gpu_info.get("compute_capability", "unknown") if uses_tempfile: - # need to set env var for triton/cute code to guarantee no wrong device shenanigans + # Set device visibility for triton/cute/tilelang if isinstance(device, int): device_num = device elif isinstance(device, torch.device): assert ( device.type == "cuda" - ), "CUDA is not availible on device, cannot run Eval" - device_num = device.index + ), "GPU is not available on device, cannot run Eval" + device_num = device.index if device.index is not None else 0 else: raise ValueError( f"device must be an int or torch.device, got {type(device)}" ) - os.environ["CUDA_VISIBLE_DEVICES"] = str(device_num) + + # Set device visibility + # For ROCm, use HIP_VISIBLE_DEVICES; for CUDA, use CUDA_VISIBLE_DEVICES + if gpu_vendor == "amd": + os.environ["HIP_VISIBLE_DEVICES"] = str(device_num) + os.environ["ROCR_VISIBLE_DEVICES"] = str(device_num) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_num) context = {} if verbose: - print(f"[Eval] Start Evalulation! 
on device: {device}") + vendor_str = "AMD ROCm" if gpu_vendor == "amd" else "NVIDIA CUDA" + print(f"[Eval] Start Evaluation on device: {device} ({vendor_str})") print("[Eval] Loading Original Model") Model, get_init_inputs, get_inputs = load_original_model_and_inputs( @@ -495,7 +660,6 @@ def eval_kernel_against_ref( tempfile = None # add hash for later to distinguish between multi-turn kernels - backend_lower = backend.lower() if backend_lower in ["triton", "tilelang", "cute"]: # Use tempfile approach for triton, tilelang, and cute # These DSLs require proper module import for JIT decorators to work diff --git a/src/kernelbench/prompt_constructor_toml.py b/src/kernelbench/prompt_constructor_toml.py index 4349a74d..1de60288 100644 --- a/src/kernelbench/prompt_constructor_toml.py +++ b/src/kernelbench/prompt_constructor_toml.py @@ -24,6 +24,15 @@ "hardware_best_practices", ] +# AMD-specific hardware component keys +AMD_HARDWARE_COMPONENT_KEYS = [ + "hardware_header", + "hardware_specs", + "hardware_definitions", + "hardware_best_practices", + "amd_optimization_guidance", +] + @dataclass class PromptConfig: """ @@ -88,41 +97,95 @@ def compose_blocks(self, keys: List[str]) -> str: return "\n".join(text_parts).strip() + "\n" -def _gpu_context_from_gpu_specs(py_path: str, gpu_name: str) -> Dict[str, str]: +def _gpu_context_from_gpu_specs(py_path: str, gpu_name: str, vendor: str = "nvidia") -> Dict[str, str]: """ Load GPU_* dicts from the GPU specs file (no exec of raw strings; use runpy). + + Supports both NVIDIA and AMD GPUs. + Expected globals: - - GPU_SPEC_INFO: dict[str, dict] - - GPU_DEFINITIONS: dict[str, str] - - GPU_BEST_PRACTICES: list[str] OR {"list": [...]} for compatibility + For NVIDIA: + - GPU_SPEC_INFO: dict[str, dict] + - GPU_DEFINITIONS: dict[str, str] + - GPU_BEST_PRACTICES: list[str] OR {"list": [...]} for compatibility + For AMD: + - AMD_GPU_SPEC_INFO: dict[str, dict] + - AMD_GPU_DEFINITIONS: dict[str, str] + - AMD_GPU_BEST_PRACTICES: list[str] + Args: + py_path: Path to the gpu_specs.py file + gpu_name: GPU name to look up (e.g., "L40S", "MI355X", "R9700") + vendor: GPU vendor ("nvidia" or "amd") + + Returns: + Dict with context variables for prompt rendering """ mod = runpy.run_path(py_path) - spec_info = mod.get("GPU_SPEC_INFO", {}) - definitions = mod.get("GPU_DEFINITIONS", {}) - best = mod.get("GPU_BEST_PRACTICES", []) + + is_amd = vendor.lower() == "amd" + + if is_amd: + # Load AMD-specific specs + spec_info = mod.get("AMD_GPU_SPEC_INFO", {}) + definitions = mod.get("AMD_GPU_DEFINITIONS", {}) + best = mod.get("AMD_GPU_BEST_PRACTICES", []) + + # AMD-specific prompts-INPROGRESS + + + else: + # Load NVIDIA specs + spec_info = mod.get("GPU_SPEC_INFO", {}) + definitions = mod.get("GPU_DEFINITIONS", {}) + best = mod.get("GPU_BEST_PRACTICES", []) if not spec_info or not definitions or best is None: - raise ValueError("GPU_SPEC_INFO / GPU_DEFINITIONS / GPU_BEST_PRACTICES missing in gpu specs .py") + vendor_name = "AMD" if is_amd else "NVIDIA" + raise ValueError(f"{vendor_name} GPU_SPEC_INFO / GPU_DEFINITIONS / GPU_BEST_PRACTICES missing in gpu specs .py") if isinstance(best, dict) and "list" in best: best = best["list"] if gpu_name not in spec_info: - raise KeyError(f"GPU name {gpu_name} not found in GPU_SPEC_INFO") + # For AMD, try to find a matching key by partial match + if is_amd: + matched_key = None + for key in spec_info.keys(): + if key.lower() in gpu_name.lower() or gpu_name.lower() in key.lower(): + matched_key = key + break + if matched_key is None and spec_info: + 
matched_key = next(iter(spec_info)) # Use first entry as fallback + if matched_key: + gpu_name = matched_key + else: + raise KeyError(f"GPU name {gpu_name} not found in AMD_GPU_SPEC_INFO") + else: + raise KeyError(f"GPU name {gpu_name} not found in GPU_SPEC_INFO") curr = spec_info[gpu_name] gpu_architecture = curr.get("GPU Architecture", "Unknown") - specs_bullets = "\n".join([f"- We have {v} of {k}." for k, v in curr.items() if k != "GPU Architecture"]) + + if is_amd: + specs_bullets = "\n".join([f"- {k}: {v}" for k, v in curr.items()]) + vendor_display = "AMD" + else: + specs_bullets = "\n".join([f"- We have {v} of {k}." for k, v in curr.items() if k != "GPU Architecture"]) + vendor_display = "NVIDIA" + defs_bullets = "\n".join([f"- {k}: {v}" for k, v in definitions.items()]) best_bullets = "\n".join([f"- {x}" for x in (best or [])]) - return { + context = { "gpu_name": gpu_name, "gpu_architecture": gpu_architecture, "gpu_specs_bullets": specs_bullets, "gpu_definitions_bullets": defs_bullets, "gpu_best_practices_bullets": best_bullets, - } + "gpu_vendor": vendor.lower(), + "gpu_vendor_display": vendor_display, + } + return context def render_prompt_by_option( *, @@ -135,10 +198,13 @@ def render_prompt_by_option( precision: Optional[str] = None, include_hardware: bool = False, components_override: Optional[List[str]] = None, + vendor: str = "nvidia", ) -> str: """ Render a prompt using backends.X and options.Y structure from TOML. + Supports both NVIDIA and AMD GPUs. + Args: prompts_toml: Path to the prompts.toml file backend: The kernel backend (triton, cuda, cute, tilelang) @@ -154,12 +220,15 @@ def render_prompt_by_option( components_override: When provided, users can arrange prompt components from the toml file in any order they want. Components must exist under templates.common or be hardware_* entries. + vendor: GPU vendor ("nvidia" or "amd") - affects hardware info and prompt content Returns: The rendered prompt string """ cfg = PromptConfig.from_toml(prompts_toml) + is_amd = vendor.lower() == "amd" + # Get backend-specific content try: backend_data = cfg.data["backends"][backend] @@ -172,15 +241,19 @@ def render_prompt_by_option( except KeyError: raise KeyError(f"Unknown option: {option}") + # Determine which hardware component keys to use + hardware_keys = AMD_HARDWARE_COMPONENT_KEYS if is_amd else HARDWARE_COMPONENT_KEYS + component_sequence = list(components_override or option_data["components"]) if include_hardware: if components_override is None: insert_idx = component_sequence.index("arch_block") if "arch_block" in component_sequence else len(component_sequence) - component_sequence[insert_idx:insert_idx] = HARDWARE_COMPONENT_KEYS + component_sequence[insert_idx:insert_idx] = hardware_keys else: # Custom sequences must explicitly have hardware blocks present in their prompt if they # have set they are including hardware info. 
- if not any(component in HARDWARE_COMPONENT_KEYS for component in component_sequence): + all_hardware_keys = set(HARDWARE_COMPONENT_KEYS) | set(AMD_HARDWARE_COMPONENT_KEYS) + if not any(component in all_hardware_keys for component in component_sequence): raise ValueError( "components_override must contain at least one hardware_* entry when include_hardware=True" ) @@ -288,7 +361,7 @@ def render_example_entry(input_code: str, output_code: str, example_label: str) raise ValueError( f"Hardware info requested for option '{option}'; provide gpu_specs_py and gpu_name" ) - context = {**context, **_gpu_context_from_gpu_specs(resolve_path(gpu_specs_py), gpu_name)} + context = {**context, **_gpu_context_from_gpu_specs(resolve_path(gpu_specs_py), gpu_name, vendor=vendor)} # Builds the prompt from the components in the toml file. prompt_parts = [] @@ -326,17 +399,21 @@ def get_prompt_for_backend( precision: Optional[str] = None, include_hardware: bool = False, gpu_name: Optional[str] = None, + vendor: str = "nvidia", ) -> str: """ Generate a prompt for a specific backend and option. + Supports both NVIDIA and AMD GPUs. + Args: ref_arch_src: The reference architecture source code backend: The kernel backend (triton, cuda, cute, tilelang) option: The prompt option (zero_shot, one_shot, few_shot) precision: Optional precision (fp32, fp16, bf16) - defaults to fp32 if not provided include_hardware: When True, append hardware guidance blocks (requires gpu_name) - gpu_name: GPU identifier used when include_hardware is True (e.g., "A100") + gpu_name: GPU identifier used when include_hardware is True (e.g., "A100", "R9700", "W7900D") + vendor: GPU vendor ("nvidia" or "amd") """ return render_prompt_by_option( prompts_toml=PROMPTS_TOML, @@ -347,6 +424,7 @@ def get_prompt_for_backend( include_hardware=include_hardware, gpu_specs_py=GPU_SPECS_PY if include_hardware else None, gpu_name=gpu_name, + vendor=vendor, ) @@ -360,11 +438,25 @@ def get_custom_prompt( include_hardware: bool = False, gpu_name: Optional[str] = None, prompts_toml: str = PROMPTS_TOML, + vendor: str = "nvidia", ) -> str: """ Render a prompt defined under [custom_prompts.] in prompts.toml. Must still provide backend/option/precision settings just like - get_prompt_for_backend. + get_prompt_for_backend. + + Supports both NVIDIA and AMD GPUs. + + Args: + custom_key: The custom prompt key in prompts.toml + ref_arch_src: The reference architecture source code + backend: The kernel backend (triton, cuda, cute, tilelang) + option: The prompt option (zero_shot, one_shot, few_shot) + precision: Optional precision (fp32, fp16, bf16) + include_hardware: When True, include hardware guidance + gpu_name: GPU identifier (e.g., "A100", "R9700", "W7900D") + prompts_toml: Path to prompts.toml file + vendor: GPU vendor ("nvidia" or "amd") """ if not ref_arch_src: raise ValueError(f"Custom prompt '{custom_key}' requires ref_arch_src.") @@ -386,6 +478,7 @@ def get_custom_prompt( gpu_specs_py=GPU_SPECS_PY if include_hardware else None, gpu_name=gpu_name, components_override=components_override, + vendor=vendor, ) __all__ = [ @@ -404,7 +497,7 @@ def log_prompt(prompt: str, dir_path: str, file_name: str): def test_prompt(): """ - Demonstrate baseline, few-shot, DSL, hardware-aware, and custom prompt + Demonstrate baseline, few-shot, DSL, hardware-aware, AMD, and custom prompt generation. Customize the reference architecture or custom_prompt_key if you want to try different inputs. 
""" @@ -413,6 +506,7 @@ def test_prompt(): print("Testing prompt construction...") scratch_dir = os.path.join(REPO_TOP_PATH, "scratch") + # baseline prompt baseline_prompt = get_prompt_for_backend( ref_arch_src=ref_arch_src, @@ -441,7 +535,7 @@ def test_prompt(): ) log_prompt(dsl_prompt, os.path.join(scratch_dir), "dsl_prompt.txt") - # hardware prompt + # NVIDIA hardware prompt hardware_prompt = get_prompt_for_backend( ref_arch_src=ref_arch_src, backend="cute", @@ -449,9 +543,23 @@ def test_prompt(): precision="fp32", include_hardware=True, gpu_name="L40S", + vendor="nvidia", ) log_prompt(hardware_prompt, os.path.join(scratch_dir), "hardware_prompt.txt") + + # AMD hardware prompt (RDNA4 - R9700) + amd_rdna4_prompt = get_prompt_for_backend( + ref_arch_src=ref_arch_src, + backend="triton", + option="one_shot", + precision="fp32", + include_hardware=True, + gpu_name="R9700", + vendor="amd", + ) + + # custom prompt defined in prompts.toml custom_prompt = get_custom_prompt( # the key is whatever you name the prompt in the custom_prompts section of the toml file @@ -463,6 +571,7 @@ def test_prompt(): precision="fp32", include_hardware=True, gpu_name="L40S", + vendor="nvidia", ) log_prompt(custom_prompt, os.path.join(scratch_dir), "custom_prompt.txt") diff --git a/src/kernelbench/prompts/hardware/gpu_specs.py b/src/kernelbench/prompts/hardware/gpu_specs.py index 800f20ef..f9e343d0 100644 --- a/src/kernelbench/prompts/hardware/gpu_specs.py +++ b/src/kernelbench/prompts/hardware/gpu_specs.py @@ -1,9 +1,14 @@ """ A List of GPU Specs to include in the prompt +Supports both NVIDIA and AMD GPUs. """ +# ============================================================================= +# NVIDIA GPU Specifications +# ============================================================================= + GPU_SPEC_INFO = { "L40S": { "GPU Architecture": "Ada", @@ -121,21 +126,129 @@ } } -# Basic GPU concept definitions +# ============================================================================= +# AMD GPU Specifications +# ============================================================================= + +AMD_GPU_SPEC_INFO = { + # Based on provided rocminfo for AMD Radeon 9700 (gfx1201) + "R9700": { + "GPU Name": "AMD Radeon 9700 (gfx1201)", + "GPU Architecture": "AMD RDNA4 (gfx1201)", + "Compute Units": 64, + "SIMDs per CU": 2, + "Shader Engines": 4, + "Shader Arrays per Engine": 2, + "Wavefront Size": "Wave32", + "Max Clock (MHz)": 2350, + "Workgroup Max Size": 1024, + "Max Waves per CU": 32, + "Stream Processors": 4096, + "Ray Accelerators": 64, + "AI Accelerators": 128, + "ROPs": 128, + "Transistors": "53.9 Billion", + "Peak Pixel Fill Rate": "373.76 GP/s", + "L1 Cache": "32 KB", + "L2 Cache": "8 MB", + "L3 Cache": "64 MB", + "Cacheline Size": "256 B", + "LDS (Workgroup Local Memory)": "64 KB", + "VRAM": "32,061,259,776 B (~29.85 GiB)", + "Memory Bandwidth": "Unknown", + "FP32 Vector TFLOPS": "47.8", + "FP16 Vector TFLOPS": "95.7", + "FP16 Matrix TFLOPS": "191 (383 w/ sparsity)", + "FP8 Matrix TFLOPS": "383 (766 w/ sparsity)", + "INT8 Matrix TOPS": "383 (766 w/ sparsity)", + "INT4 Matrix TOPS": "766 (1531 w/ sparsity)", + "Max Registers per Block": 196608, + "Max Shared Memory per Block": 65536, + "Max Threads per Block": 1024, + "Max Threads per CU": 2048, + "Shared Memory per CU": 2097152, + "Warp Size": 32, + "MFMA": "Unknown", + }, + # Based on provided rocminfo + HIP query for AMD Radeon PRO W7900D (gfx1100) + "W7900D": { + "GPU Name": "AMD Radeon PRO W7900D (gfx1100)", + "GPU Architecture": "AMD RDNA3 
(gfx1100)", + "Compute Units": 96, + "SIMDs per CU": 2, + "Shader Engines": 6, + "Shader Arrays per Engine": 2, + "Wavefront Size": "Wave32", + "Max Clock (MHz)": 1760, + "Workgroup Max Size": 1024, + "Max Waves per CU": 32, + "Max Work-item per CU": 1024, + "L1 Cache": "32 KB", + "L2 Cache": "6 MB", + "L3 Cache": "96 MB", + "Cacheline Size": "128 B", + "LDS (Workgroup Local Memory)": "64 KB", + "VRAM": "Unknown", + "Memory Bandwidth": "Unknown", + "Max Registers per Block": 196608, + "Max Shared Memory per Block": 65536, + "Max Threads per Block": 1024, + "Max Threads per CU": 2048, + "Shared Memory per CU": 3145728, + "Warp Size": 32, + "MFMA": "Unknown", + }, +} + +# ============================================================================= +# GPU Concept Definitions +# ============================================================================= + +# Basic GPU concept definitions (NVIDIA-centric) GPU_DEFINITIONS = { "Thread": "A thread is a single execution unit that can run a single instruction at a time.", "Thread Block": "A thread block is a group of threads that can cooperate with each other.", "Warp": "A warp is a group of threads that are scheduled together and execute in parallel.", + "SM": "A Streaming Multiprocessor, the core execution unit on NVIDIA GPUs.", + "Tensor Core": "Specialized units for mixed-precision matrix operations.", + "Occupancy": "The ratio of active warps to the maximum supported on an SM.", "Shared Memory": "Shared memory is a memory space that can be accessed by all threads in a thread block.", + "Shared Memory Bank": "A subdivision of shared memory that can cause bank conflicts.", "Register": "A register is a small memory space that can be accessed by a single thread.", + "Global Memory": "Off-chip DRAM accessible by all threads on the GPU.", + "Constant Memory": "Read-only cached memory optimized for uniform access.", + "Coalesced Access": "Memory access pattern that combines multiple requests into fewer transactions.", + "Divergence": "When threads in the same warp take different control paths.", "Memory Hierarchy": "Memory hierarchy is a pyramid of memory types with different speeds and sizes.", "Memory Bandwidth": "Memory bandwidth is the rate at which data can be read from or stored into memory.", "Cache": "Cache is a small memory space that stores frequently accessed data.", "HBM": "HBM is a high-bandwidth memory technology that uses 3D-stacked DRAM.", } +# AMD GPU concept definitions +AMD_GPU_DEFINITIONS = { + "Wavefront": "AMD's SIMD execution group (Wave32 or Wave64).", + "Wave32": "A 32-lane wavefront, common on RDNA architectures.", + "Wave64": "A 64-lane wavefront, common on CDNA architectures.", + "Compute Unit (CU)": "AMD's equivalent of an NVIDIA SM.", + "Work-item": "A single thread in a kernel execution.", + "Workgroup": "A group of work-items that can synchronize and share LDS.", + "SIMD": "A SIMD unit inside a CU that executes a wavefront.", + "LDS": "Local Data Share, AMD's shared memory.", + "VGPR": "Vector registers allocated per work-item.", + "SGPR": "Scalar registers shared across a wavefront.", + "Occupancy": "Number of active waves per CU, limited by registers and LDS.", + "Infinity Cache": "AMD's last-level cache that reduces DRAM traffic.", + "MFMA": "Matrix Fused Multiply-Add instruction for matrix cores.", + "Barrier": "A workgroup synchronization point.", +} + +# ============================================================================= +# Best Practices +# 
============================================================================= + GPU_BEST_PRACTICES = [ # From https://docs.nvidia.com/cuda/ada-tuning-guide/index.html # CUDA Best Practices Section @@ -145,6 +258,80 @@ "Ensure that global memory accesses are coalesced.", "Minimize redundant accesses to global memory whenever possible.", "Avoid long sequences of diverged execution by threads within the same warp.", + "Use shared memory to cache data that is reused within a block.", + "Avoid shared memory bank conflicts; pad arrays when needed.", + "Balance occupancy against register and shared memory usage.", + "Use vectorized loads/stores when they improve bandwidth.", + "Prefer tensor cores for matrix operations when supported.", + "Use streams to overlap compute and data transfers.", + "Use asynchronous copy features (e.g., cp.async) when available.", # we added this to reference the specific GPU architecture "Use specialized instructions based on the specific GPU architecture", -] \ No newline at end of file +] + +AMD_GPU_BEST_PRACTICES = [ + "Prefer Wave32-friendly configurations on RDNA architectures.", + "Prefer Wave64 on CDNA unless the kernel benefits from Wave32.", + "Choose workgroup sizes as multiples of the wavefront (32 or 64).", + "Start with workgroup sizes in [256, 512, 1024] for 1D kernels.", + "Balance VGPR usage and occupancy; avoid register spilling.", + "Use LDS for data reuse; pad to avoid LDS bank conflicts.", + "Keep global memory access contiguous and aligned (128B where possible).", + "Use vectorized loads/stores when it improves bandwidth utilization.", + "Use MFMA/matrix cores for GEMM-like operations when available.", + "Minimize divergent branches within a wavefront.", + "Avoid fp16 for exp/log; cast to fp32 for numerically sensitive ops.", +] + +# ============================================================================= +# AMD-Specific Prompt Templates-In progress +# ============================================================================= + + + +# ============================================================================= +# Helper Functions for GPU Detection +# ============================================================================= + +def get_gpu_vendor() -> str: + """ + Detect the GPU vendor (nvidia or amd). + Returns: 'nvidia', 'amd', or 'unknown' + """ + try: + import torch + if not torch.cuda.is_available(): + return "unknown" + # Check for HIP version (ROCm indicator) + if hasattr(torch.version, 'hip') and torch.version.hip is not None: + return "amd" + return "nvidia" + except ImportError: + return "unknown" + + +def get_gpu_specs_for_vendor(vendor: str) -> dict: + """ + Get appropriate GPU specs dictionary based on vendor. + """ + if vendor.lower() == "amd": + return AMD_GPU_SPEC_INFO + return GPU_SPEC_INFO + + +def get_gpu_definitions_for_vendor(vendor: str) -> dict: + """ + Get appropriate GPU definitions dictionary based on vendor. + """ + if vendor.lower() == "amd": + return AMD_GPU_DEFINITIONS + return GPU_DEFINITIONS + + +def get_gpu_best_practices_for_vendor(vendor: str) -> list: + """ + Get appropriate best practices list based on vendor. 
+ """ + if vendor.lower() == "amd": + return AMD_GPU_BEST_PRACTICES + return GPU_BEST_PRACTICES \ No newline at end of file diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml index 2768aa11..4e264d80 100644 --- a/src/kernelbench/prompts/prompts.toml +++ b/src/kernelbench/prompts/prompts.toml @@ -142,6 +142,7 @@ other placeholder supported in the shared context. # ------------------------------------------------------------------------- # Hardware Templates: GPU-specific information blocks +# Supports both NVIDIA and AMD GPUs via {gpu_vendor_display} placeholder # ------------------------------------------------------------------------- [templates.hardware] hardware_header = """ @@ -149,7 +150,7 @@ Here is some information about the underlying hardware that you should keep in m """ hardware_specs = """ -The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture. +The GPU that will run the kernel is {gpu_vendor_display} {gpu_name}, {gpu_architecture} architecture. {gpu_specs_bullets} """ @@ -166,6 +167,11 @@ Here are some best practices for writing kernels on GPU: {gpu_best_practices_bullets} """ +# AMD-specific optimization guidance (only included for AMD GPUs) +amd_optimization_guidance = """ +{amd_optimization_guidance} +""" + # ------------------------------------------------------------------------- # Options: Different prompt construction modes # ------------------------------------------------------------------------- diff --git a/src/kernelbench/utils.py b/src/kernelbench/utils.py index cf8b0ad8..c7efbaec 100644 --- a/src/kernelbench/utils.py +++ b/src/kernelbench/utils.py @@ -18,8 +18,10 @@ from importlib.resources import files, as_file # API clients +from together import Together from openai import OpenAI -from litellm import completion +import google.generativeai as genai +import anthropic import numpy as np from contextlib import contextmanager @@ -27,17 +29,41 @@ import time import concurrent from functools import cache - +from transformers import AutoTokenizer from concurrent.futures import ProcessPoolExecutor, as_completed -SGLANG_KEY = os.environ.get("SGLANG_API_KEY") +# Define API key access +TOGETHER_KEY = os.environ.get("TOGETHER_API_KEY") +DEEPSEEK_KEY = os.environ.get("DEEPSEEK_API_KEY") +OPENAI_KEY = os.environ.get("OPENAI_API_KEY") +GEMINI_KEY = os.environ.get("GEMINI_API_KEY") +SGLANG_KEY = os.environ.get("SGLANG_API_KEY") # for Local Deployment +ANTHROPIC_KEY = os.environ.get("ANTHROPIC_API_KEY") +SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY") +FIREWORKS_API_KEY = os.environ.get("FIREWORKS_API_KEY") ######################################################## # Inference Helpers ######################################################## +@cache +def load_deepseek_tokenizer(): + return AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True) + +# Buffer because deepseek totally blocks us if we send stuff that's too long :( +TOO_LONG_FOR_DEEPSEEK = 115_000 + +def is_safe_to_send_to_deepseek(prompt): + tokenizer = load_deepseek_tokenizer() + if type(prompt) == str: + return ( + len(tokenizer(prompt, verbose=False)["input_ids"]) < TOO_LONG_FOR_DEEPSEEK + ) + else: + return len(tokenizer.apply_chat_template(prompt)) < TOO_LONG_FOR_DEEPSEEK + def set_gpu_arch(arch_list: list[str]): """ Set env variable for torch cuda arch list to build kernels for specified architectures @@ -69,18 +95,231 @@ def query_server( ): """ Query various sort of LLM inference API providers - Done through 
liteLLM: - - Local Server (SGLang, vLLM, Tokasaurus) + Supports: + - OpenAI (AMD LLM Gateway) + - Deepseek + - Together + - Sambanova + - Anthropic + - Gemini / Google AI Studio + - Fireworks (OpenAI compatbility) + - SGLang (Local Server) """ - # Local Server (SGLang, vLLM, Tokasaurus) - special handling - if server_type == "local": - url = f"http://{server_address}:{server_port}" - client = OpenAI( - api_key=SGLANG_KEY, base_url=f"{url}/v1", timeout=None, max_retries=0 + # Select model and client based on arguments + match server_type: + case "sglang": + url = f"http://{server_address}:{server_port}" + client = OpenAI( + api_key=SGLANG_KEY, base_url=f"{url}/v1", timeout=None, max_retries=0 + ) + model = "default" + case "deepseek": + client = OpenAI( + api_key=DEEPSEEK_KEY, + base_url="https://api.deepseek.com", + timeout=10000000, + max_retries=3, + ) + model = model_name + assert model in ["deepseek-chat", "deepseek-coder", "deepseek-reasoner"], "Only support deepseek-chat or deepseek-coder for now" + if not is_safe_to_send_to_deepseek(prompt): + raise RuntimeError("Prompt is too long for DeepSeek") + case "fireworks": + client = OpenAI( + api_key=FIREWORKS_API_KEY, + base_url="https://api.fireworks.ai/inference/v1", + timeout=10000000, + max_retries=3, + ) + model = model_name + + case "anthropic": + client = anthropic.Anthropic( + api_key=ANTHROPIC_KEY, + ) + model = model_name + case "google": + genai.configure(api_key=GEMINI_KEY) + model = model_name + case "together": + client = Together(api_key=TOGETHER_KEY) + model = model_name + case "sambanova": + client = OpenAI(api_key=SAMBANOVA_API_KEY, base_url="https://api.sambanova.ai/v1") + model = model_name + + case "openai": + # AMD LLM Gateway + client = OpenAI( + base_url="https://llm-api.amd.com/OpenAI", + api_key="dummy", + default_headers={ + "Ocp-Apim-Subscription-Key": os.environ.get("LLM_GATEWAY_KEY"), + } + ) + model = model_name + case _: + raise NotImplementedError(f"Server type {server_type} not supported") + + if server_type != "google": + assert client is not None, "Client is not set, cannot proceed to generations" + else: + print( + f"Querying {server_type} {model} with temp {temperature} max tokens {max_tokens}" ) - if isinstance(prompt, str): + # Logic to query the LLM + if server_type == "anthropic": + assert type(prompt) == str + + if is_reasoning_model: + # Use beta endpoint with thinking enabled for reasoning models + response = client.beta.messages.create( + model=model, + system=system_prompt, + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + # Claude thinking requires budget_tokens for thinking (reasoning) + thinking={"type": "enabled", "budget_tokens": budget_tokens}, + betas=["output-128k-2025-02-19"], + ) + else: + # Use standard endpoint for normal models + response = client.messages.create( + model=model, + system=system_prompt, + messages=[ + {"role": "user", "content": prompt}, + ], + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_tokens=max_tokens, + ) + outputs = [choice.text for choice in response.content if not hasattr(choice, 'thinking') or not choice.thinking] + + elif server_type == "google": + generation_config = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "max_output_tokens": max_tokens, + "response_mime_type": "text/plain", + } + + model = genai.GenerativeModel( + model_name=model_name, + system_instruction=system_prompt, + generation_config=generation_config, + ) + + response = model.generate_content(prompt) + + return 
response.text + + elif server_type == "deepseek": + + if model in ["deepseek-chat", "deepseek-coder"]: + # regular deepseek model + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + stream=False, + temperature=temperature, + n=num_completions, + max_tokens=max_tokens, + top_p=top_p, + ) + + else: # deepseek reasoner + assert is_reasoning_model, "Only support deepseek-reasoner for now" + assert model == "deepseek-reasoner", "Only support deepseek-reasoner for now" + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + stream=False, + n=num_completions, + max_tokens=max_tokens, + # do not use temperature or top_p + ) + outputs = [choice.message.content for choice in response.choices] + elif server_type == "openai": + if is_reasoning_model: + assert "o1" in model or "o3" in model, "Only support o1 and o3 for now" + print(f"Using OpenAI reasoning model: {model} with reasoning effort {reasoning_effort}") + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": prompt}, + ], + reasoning_effort=reasoning_effort, + ) + else: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + stream=False, + temperature=temperature, + n=num_completions, + max_tokens=max_tokens, + top_p=top_p, + ) + outputs = [choice.message.content for choice in response.choices] + elif server_type == "together": + response = client.chat.completions.create( + model=model, + max_tokens=max_tokens, + temperature=temperature, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + top_p=top_p, + top_k=top_k, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + outputs = [choice.message.content for choice in response.choices] + elif server_type == "fireworks": + response = client.chat.completions.create( + model=model, + max_tokens=max_tokens, + temperature=temperature, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + outputs = [choice.message.content for choice in response.choices] + elif server_type == "sambanova": + response = client.chat.completions.create( + model=model, + max_tokens=max_tokens, + temperature=temperature, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + top_p=top_p, + ) + outputs = [choice.message.content for choice in response.choices] + # for all other kinds of servers, use standard API + else: + if type(prompt) == str: response = client.completions.create( - model="default", + model=model, prompt=prompt, temperature=temperature, n=num_completions, @@ -90,7 +329,7 @@ def query_server( outputs = [choice.text for choice in response.choices] else: response = client.chat.completions.create( - model="default", + model=model, messages=prompt, temperature=temperature, n=num_completions, @@ -98,105 +337,42 @@ def query_server( top_p=top_p, ) outputs = [choice.message.content for choice in response.choices] - - # output processing - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - # All other providers - use LiteLLM unified interface - # Build messages list with system prompt first (if not 
already present) - messages = [] - - # Check if prompt is already a list with a system message - if isinstance(prompt, list) and prompt and prompt[0].get("role") == "system": - # Prompt already has system message, use it directly - messages = prompt + + # output processing + if len(outputs) == 1: + return outputs[0] else: - # Add system prompt first if provided - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - - # Then add the actual prompt - if isinstance(prompt, str): - messages.append({"role": "user", "content": prompt}) - else: - messages.extend(prompt) - - try: - completion_kwargs = { - "model": model_name, - "messages": messages, - "max_tokens": max_tokens, - "n": num_completions, - } - - # Reasoning models (o1, o3, etc.) don't support standard sampling params - if is_reasoning_model: - # Note: o1/o3 models don't support temperature, top_p, top_k - # LiteLLM will pass through reasoning_effort for OpenAI o1/o3 models - if reasoning_effort: - completion_kwargs["reasoning_effort"] = reasoning_effort - # Claude extended thinking uses "thinking" parameter with dict structure - # Format: {"type": "enabled", "budget_tokens": } - if budget_tokens > 0 and "anthropic" in model_name.lower(): - completion_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget_tokens} - else: - # Standard models support temperature and top_p - completion_kwargs["temperature"] = temperature - completion_kwargs["top_p"] = top_p - - # top_k is not supported by OpenAI models - if "openai/" not in model_name.lower() and "gpt" not in model_name.lower(): - completion_kwargs["top_k"] = top_k - - response = completion(**completion_kwargs) - - # output processing - if num_completions == 1: - content = response.choices[0].message.content - if content is None: - raise ValueError(f"LLM returned None content for model {model_name}. 
finish_reason: {response.choices[0].finish_reason}") - return content - else: - contents = [choice.message.content for choice in response.choices] - if any(c is None for c in contents): - raise ValueError(f"LLM returned None content in one or more completions for model {model_name}") - return contents - except Exception as e: - print(f"Error in query_server for model {model_name}: {e}") - raise + return outputs # a list of presets for API server configs SERVER_PRESETS = { "deepseek": { "temperature": 1.6, - "model_name": "deepseek/deepseek-coder", + "model_name": "deepseek-chat", "max_tokens": 4096 }, "google": { - "model_name": "gemini/gemini-2.5-flash", + "model_name": "gemini-1.5-flash-002", "temperature": 0.7, # need to experiment with temperature - "max_tokens": 16384, + "max_tokens": 8192, }, "together": { # mostly for Llama 3.1 - "model_name": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", # "model_name": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", "temperature": 0.7, "max_tokens": 4096, }, - "local": { # this is for running locally (SGLang, vLLM, Tokasaurus), mostly for Llama + "sglang": { # this is for running locally, mostly for Llama "temperature": 0.8, # human eval pass@N temperature "server_port": 10210, "server_address": "matx2.stanford.edu", "max_tokens": 8192, }, - "anthropic": { # for Claude 3.7 Sonnet - "model_name": "anthropic/claude-3-7-sonnet-20250219", + "anthropic": { # for Claude 3.5 Sonnet + "model_name": "claude-3-5-sonnet-20241022", "temperature": 0.8, - "max_tokens": 8192, + "max_tokens": 4096, }, "openai": { "model_name": "gpt-4o-2024-08-06", @@ -204,10 +380,10 @@ def query_server( "temperature": 0.0, "max_tokens": 4096, }, - "fireworks": { - "model_name": "fireworks_ai/llama-v3p1-70b-instruct", - "temperature": 0.7, - "max_tokens": 4096, + "sambanova": { + "model_name": "Meta-Llama-3.1-405B-Instruct", + "temperature": 0.1, + "max_tokens": 8192, }, } @@ -216,7 +392,6 @@ def create_inference_server_from_presets(server_type: str = None, greedy_sample: bool = False, verbose: bool = False, time_generation: bool = False, - model_name: str = None, **kwargs, ) -> callable: """ @@ -224,21 +399,15 @@ def create_inference_server_from_presets(server_type: str = None, """ def _query_llm(prompt: str | list[dict]): server_args = SERVER_PRESETS[server_type].copy() - - if model_name is not None and model_name != "None": - server_args["model_name"] = model_name - + if kwargs: - filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None and v != "None"} - server_args.update(filtered_kwargs) - + server_args.update(kwargs) if greedy_sample: server_args["temperature"] = 0.0 server_args["top_p"] = 1.0 server_args["top_k"] = 1 - if verbose: - print(f"Querying server {server_type} with model {server_args['model_name']} and args: {server_args}") + print(f"Querying server {server_type} with args: {server_args}") if time_generation: start_time = time.time()
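
For context, a rough usage sketch (not part of the patch) of the preset-based helper modified above, assuming the package is importable as `kernelbench` per the pyproject packaging config and that the credentials for the chosen preset are exported (the "openai" preset after this patch routes through the AMD LLM Gateway and reads `LLM_GATEWAY_KEY`); the preset name and override keys come from `SERVER_PRESETS` as defined in this diff.

```python
# Hypothetical driver for the preset-based inference helper after this patch.
from kernelbench.utils import create_inference_server_from_presets

query_llm = create_inference_server_from_presets(
    server_type="openai",   # any key in SERVER_PRESETS, e.g. "anthropic", "sglang"
    greedy_sample=True,     # forces temperature=0.0, top_p=1.0, top_k=1
    verbose=True,
    max_tokens=8192,        # extra kwargs are merged into the preset args
)

prompt = "Write a Triton kernel that adds two fp32 tensors."
print(query_llm(prompt))
```
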