8 changes: 6 additions & 2 deletions README.md
@@ -85,6 +85,10 @@ We have transitioned to using `pyproject.toml` and `uv` for dependency management
# Install base dependencies (works without a local GPU)
uv sync

# Install ROCm-enabled PyTorch (pick the correct ROCm version for your system)
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm7.1

# Install with GPU dependencies (for local GPU evaluation)
uv sync --extra gpu
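
# (Optional) Sanity check, a sketch assuming a ROCm-enabled torch wheel:
# torch.version.hip is set on ROCm builds, and torch.cuda.is_available()
# returns True on AMD GPUs (ROCm reuses the CUDA device API)
uv run python -c "import torch; print(torch.version.hip, torch.cuda.is_available())"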

@@ -115,9 +119,9 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface
```

**What you might need to modify**
* **`gpu_arch`** - Depending on your GPU, you may need to adjust the `gpu_arch` argument to reflect your hardware.
* **`gpu_arch`** - Depending on your GPU, you may need to adjust the `gpu_arch` argument to reflect your hardware. Currently supported: `["gfx1100"]` (W7900D), `["gfx1201"]` (R9700).
* **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` & `bf16`.
* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify, e.g., `backend=triton`. For now we support these DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`.
* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify, e.g., `backend=triton`. For now we support these DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`. Note: ROCm GPUs currently use `backend=triton`. See the example invocation after this list.
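
Putting these flags together, here is a minimal sketch of a single-sample run on an AMD GPU (the flag values are illustrative; adjust `gpu_arch` to your hardware, and other arguments such as problem selection may also be required):

```bash
# Sketch: evaluate one sample on a gfx1100 (W7900D) card via the Triton backend
uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface \
    gpu_arch='["gfx1100"]' precision=fp32 backend=triton
```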


Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, git clone the ThunderKittens repo and set the environment variable `export THUNDERKITTENS_ROOT=<PATH to ThunderKittens folder>` to point to your local ThunderKittens directory. All ThunderKittens programs, as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enables the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now.
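
For reference, a minimal setup sketch (the clone path is illustrative, and the URL assumes the upstream HazyResearch repository):

```bash
git clone https://github.com/HazyResearch/ThunderKittens ~/ThunderKittens
export THUNDERKITTENS_ROOT="$HOME/ThunderKittens"
```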
19 changes: 14 additions & 5 deletions pyproject.toml
@@ -12,9 +12,11 @@ dependencies = [
# Frameworks
"torch==2.9.0",

"pytorch-triton-rocm>=3.4.0",
"transformers",
"datasets",
"modal",
"ruff",

# helper
"tqdm",
@@ -37,12 +39,9 @@ dependencies = [

[project.optional-dependencies]
gpu = [
# GPU-specific dependencies (requires CUDA)
# GPU-specific dependencies (ROCm / AMD Radeon)
"triton",
"nvidia-cutlass-dsl",
"tilelang",
"cupy-cuda12x",
"nsight-python",
]
dev = [
"pytest",
@@ -55,4 +54,14 @@ where = ["src"]
include = ["kernelbench*"]

[tool.setuptools.package-data]
kernelbench = ["prompts/**/*"]
kernelbench = ["prompts/**/*"]

[tool.uv.sources]
torch = [{ index = "pytorch-rocm" }]
torchvision = [{ index = "pytorch-rocm" }]
pytorch-triton-rocm = [{ index = "pytorch-rocm" }]

[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm6.4"
explicit = true
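
# Note: with explicit = true, this index is only used for the packages pinned
# to it in [tool.uv.sources]; everything else resolves from PyPI.
# Quick check (sketch): after `uv sync`, `uv pip show torch` should report a
# version with a +rocm6.4 local suffix.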
53 changes: 28 additions & 25 deletions requirements.txt
@@ -1,35 +1,38 @@
# ARCHIVED: We are transitioning to pyproject.toml and uv-based project management
# However, we provide this as a backup for now

# Frameworks
# we use the latest PyTorch stable release
torch==2.9.*
triton==3.5.*

# torch==2.5.0
# we shall upgrade torch for blackwell when it is stable
transformers>=4.57.3
datasets>=4.4.2
modal>=1.3.0
# AMD ROCm note: install ROCm-enabled torch from the PyTorch ROCm index.
# Current ROCm env:
# torch==2.8.0+rocm7.1.1.gitcba8b9d2
# HIP==7.1.52802-26aae437f6
# ROCm SMI (concise):
# Device IDs: 0x7551 x4
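# e.g. (sketch; pick the index that matches your ROCm version):
#   pip install torch --index-url https://download.pytorch.org/whl/rocm7.1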
transformers
datasets
modal

# DSLs
nvidia-cutlass-dsl
tilelang
# nvidia-cutlass-dsl
# triton (required for AMD ROCm kernels)
# helion (optional, Helion DSL; install separately if needed)

# helper
tqdm>=4.67.1
tqdm
packaging
pydra-config
ninja>=1.13.0
cupy-cuda12x==13.6.0
tomli>=2.3.0
tabulate>=0.9.0
nsight-python
pydra_config
dill>=0.3.7,<0.4
pytest
ninja

# Numerics
einops>=0.8.1
python-dotenv>=1.2.1
numpy==2.4.0
einops
python-dotenv
numpy

# to deprecate with litellm
google-generativeai
together
openai
anthropic
pydantic==2.12.4

# use litellm for cloud providers and openai for local
openai>=2.14.0
litellm[proxy]>=1.80.10