Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ members = [
"crypto/stark",
"crypto/crypto",
"crypto/math",
"crypto/math-cuda",
"bin/cli",
]

Expand Down
10 changes: 9 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
.PHONY: deps deps-linux deps-macos prepare-test-data compile-programs-asm compile-programs-rust compile-bench \
compile-programs clean-asm clean-rust clean-bench clean-shared clean test test-asm test-no-compile \
test-asm-no-compile test-rust test-rust-no-compile test-executor flamegraph-prover \
test-fast test-prover test-prover-all build check clippy fmt lint
test-fast test-prover test-prover-all test-math-cuda bench-math-cuda build check clippy fmt lint

UNAME := $(shell uname)

Expand Down Expand Up @@ -166,6 +166,14 @@ test-prover-all:
test-prover-debug:
cargo test -p lambda-vm-prover --features debug-checks -- --nocapture

# math-cuda parity tests (requires NVIDIA GPU + nvcc)
test-math-cuda:
cargo test -p math-cuda --release

# math-cuda quick microbench (median of 10 runs)
bench-math-cuda:
cargo test -p math-cuda --release --test bench_quick -- --ignored --nocapture

# Build all
build:
cargo build --workspace
Expand Down
23 changes: 23 additions & 0 deletions crypto/math-cuda/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "math-cuda"
description = "CUDA-accelerated FFT/NTT for Goldilocks (base field) used by the lambda-vm STARK prover"
version = "0.1.0"
edition = "2024"

[dependencies]
cudarc = { version = "0.19", default-features = false, features = [
"driver",
"nvrtc",
"std",
"cuda-version-from-build-system",
"fallback-latest",
"dynamic-loading",
] }
math = { path = "../math" }
rayon = "1.7"

[dev-dependencies]
rand = { version = "0.8.5", features = ["std"] }
rand_chacha = "0.3.1"
rayon = "1.7"
sha3 = "0.10.8"
113 changes: 113 additions & 0 deletions crypto/math-cuda/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::env;
use std::fs;
use std::path::PathBuf;
use std::process::Command;

fn cuda_home() -> PathBuf {
env::var_os("CUDA_HOME")
.or_else(|| env::var_os("CUDA_PATH"))
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/usr/local/cuda"))
}

fn nvcc_path() -> PathBuf {
cuda_home().join("bin").join("nvcc")
}

/// Query `nvidia-smi` for the local GPU's compute capability (e.g. "12.0"
/// for Blackwell). Returns a `compute_XX` target on success, falling back
/// to `compute_89` (Ada) when no GPU is visible or the query fails.
fn detect_arch() -> String {
const FALLBACK: &str = "compute_89";
let output = match Command::new("nvidia-smi")
.args(["--query-gpu=compute_cap", "--format=csv,noheader"])
.output()
{
Ok(o) if o.status.success() => o,
_ => return FALLBACK.to_string(),
};
let line = match std::str::from_utf8(&output.stdout) {
Ok(s) => s,
Err(_) => return FALLBACK.to_string(),
};
// First line, first comma-separated value (covers multi-GPU hosts).
let cap = match line.lines().next() {
Some(l) => l.split(',').next().unwrap_or("").trim(),
None => return FALLBACK.to_string(),
};
let (major, minor) = match cap.split_once('.') {
Some((m, n)) => (m.trim(), n.trim()),
None => return FALLBACK.to_string(),
};
if major.chars().all(|c| c.is_ascii_digit()) && minor.chars().all(|c| c.is_ascii_digit()) {
format!("compute_{major}{minor}")
} else {
FALLBACK.to_string()
}
}

fn compile_ptx(src: &str, out_name: &str, have_nvcc: bool) {
let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap());
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let src_path = manifest_dir.join("kernels").join(src);
let out_path = out_dir.join(out_name);

println!("cargo:rerun-if-changed=kernels/{src}");
println!("cargo:rerun-if-env-changed=CUDA_HOME");
println!("cargo:rerun-if-env-changed=CUDA_PATH");
println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH");

// When nvcc is missing from PATH, emit an empty PTX stub so the crate
// still compiles. include_str! in src/device.rs needs the file to exist
// at build time. Any runtime kernel call panics in cudarc when loading
// the empty module. We can't run GPU code without nvcc on the build
// host anyway.
if !have_nvcc {
fs::write(&out_path, "").expect("failed to write empty PTX stub");
return;
}

// Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the
// actual GPU at load time. Override with CUDARC_NVCC_ARCH to pin a specific
// compute capability. If unset, try `nvidia-smi` to match the host GPU
// (avoids JIT failures like nvcc-13.0 PTX rejected on Blackwell drivers);
// fall back to compute_89 (Ada) when detection fails.
let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| detect_arch());

let status = Command::new(nvcc_path())
.args(["--ptx", "-O3", "-std=c++17", "-arch", &arch, "-o"])
.arg(&out_path)
.arg(&src_path)
.status()
.expect("failed to invoke nvcc — is CUDA installed and CUDA_HOME set?");

if !status.success() {
panic!("nvcc failed compiling {}", src_path.display());
}
}

fn main() {
// Headers aren't compiled, so emit rerun-if-changed to rebuild on
// header edits.
println!("cargo:rerun-if-changed=kernels/goldilocks.cuh");
println!("cargo:rerun-if-changed=kernels/ext3.cuh");

// Probe for nvcc once. Workspace consumers (clippy, fmt, CPU-only test
// runners) build math-cuda incidentally without using its kernels. Stub
// out PTX when nvcc is unavailable so those builds succeed.
let have_nvcc = Command::new(nvcc_path())
.arg("--version")
.output()
.map(|o| o.status.success())
.unwrap_or(false);
if !have_nvcc {
println!(
"cargo:warning=math-cuda: nvcc not found at {} — emitting empty PTX stubs. \
Runtime GPU calls will panic. Install CUDA and rebuild for a working backend.",
nvcc_path().display()
);
}

compile_ptx("arith.cu", "arith.ptx", have_nvcc);
compile_ptx("ntt.cu", "ntt.ptx", have_nvcc);
}
97 changes: 97 additions & 0 deletions crypto/math-cuda/kernels/arith.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Element-wise Goldilocks kernels used by the parity tests. These mirror
// the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs
// are bit-identical to the CPU path.

#include "goldilocks.cuh"
#include "ext3.cuh"

using goldilocks::add;
using goldilocks::sub;
using goldilocks::mul;
using goldilocks::neg;

extern "C" __global__ void vector_add_u64(const uint64_t *a,
const uint64_t *b,
uint64_t *c,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) c[tid] = a[tid] + b[tid]; // plain wrapping u64 add — toolchain sanity only.
}

extern "C" __global__ void gl_add_kernel(const uint64_t *a,
const uint64_t *b,
uint64_t *c,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) c[tid] = add(a[tid], b[tid]);
}

extern "C" __global__ void gl_sub_kernel(const uint64_t *a,
const uint64_t *b,
uint64_t *c,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) c[tid] = sub(a[tid], b[tid]);
}

extern "C" __global__ void gl_mul_kernel(const uint64_t *a,
const uint64_t *b,
uint64_t *c,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) c[tid] = mul(a[tid], b[tid]);
}

extern "C" __global__ void gl_neg_kernel(const uint64_t *a,
uint64_t *c,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) c[tid] = neg(a[tid]);
}

// ---------------------------------------------------------------------------
// Ext3 (Goldilocks cubic extension) test kernels.
// Input/output arrays are interleaved [a_0, b_0, c_0, a_1, b_1, c_1, ...].
// ---------------------------------------------------------------------------

extern "C" __global__ void ext3_mul_kernel(const uint64_t *a_int,
const uint64_t *b_int,
uint64_t *c_int,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]);
ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]);
ext3::Fe3 r = ext3::mul(a, b);
c_int[tid*3 + 0] = r.a;
c_int[tid*3 + 1] = r.b;
c_int[tid*3 + 2] = r.c;
}

extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int,
const uint64_t *b_int,
uint64_t *c_int,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]);
ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]);
ext3::Fe3 r = ext3::add(a, b);
c_int[tid*3 + 0] = r.a;
c_int[tid*3 + 1] = r.b;
c_int[tid*3 + 2] = r.c;
}

extern "C" __global__ void ext3_sub_kernel(const uint64_t *a_int,
const uint64_t *b_int,
uint64_t *c_int,
uint64_t n) {
uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= n) return;
ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]);
ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]);
ext3::Fe3 r = ext3::sub(a, b);
c_int[tid*3 + 0] = r.a;
c_int[tid*3 + 1] = r.b;
c_int[tid*3 + 2] = r.c;
}
Loading
Loading