diff --git a/.gitignore b/.gitignore index 9c826f0d9..91cf42244 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,8 @@ executor/program_artifacts/ # Shared cargo target directory for ELF builds executor/shared_target/ + +# Experiment artifacts — never commit +artifacts/ +profiles/ +*.bundle diff --git a/Cargo.lock b/Cargo.lock index f6eea84d6..7b6ed3c62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -803,6 +803,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "cudarc" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" +dependencies = [ + "libloading", +] + [[package]] name = "darling" version = "0.21.3" @@ -1989,6 +1998,16 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.15" @@ -2105,6 +2124,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "math-cuda" +version = "0.1.0" +dependencies = [ + "cudarc", + "math", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rayon", + "sha3", +] + [[package]] name = "memchr" version = "2.7.6" @@ -3172,6 +3203,7 @@ dependencies = [ "itertools 0.11.0", "log", "math", + "math-cuda", "rayon", "serde", "serde-wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 4d10b7c44..e43dc7f0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crypto/stark", "crypto/crypto", "crypto/math", + "crypto/math-cuda", "bin/cli", ] diff --git a/Makefile b/Makefile index c02bffc49..7857c949d 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: deps deps-linux deps-macos prepare-test-data compile-programs-asm compile-programs-rust compile-bench \ compile-programs clean-asm clean-rust clean-bench clean-shared clean test test-asm test-no-compile \ test-asm-no-compile test-rust test-rust-no-compile test-executor flamegraph-prover \ -test-fast test-prover test-prover-all build check clippy fmt lint +test-fast test-prover test-prover-all build check clippy fmt lint test-cuda check-cuda UNAME := $(shell uname) @@ -193,3 +193,17 @@ lint: flamegraph-prover: cd crypto/stark && samply record cargo bench --bench profile_prover --features parallel + +# === CUDA === +# Run math-cuda tests (requires CUDA + a visible GPU). +test-cuda: + cargo test -p math-cuda + +check-cuda: + cargo check -p math-cuda + cargo check -p stark --features cuda + cargo check -p lambda-vm-prover --features cuda + +# Fast test suite with GPU LDE enabled (drop-in replacement for `test-fast`). +test-fast-cuda: + cargo test -p lambda-vm-prover -p stark -p executor -F stark/parallel,stark/cuda diff --git a/README.md b/README.md index df751528d..7137d7a04 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,28 @@ cargo test --release -p lambda-vm-prover --features debug-checks -- --nocapture The feature is defined in `crypto/stark/Cargo.toml` and forwarded through `prover/Cargo.toml`. It has zero overhead when disabled. +## GPU acceleration (experimental) + +A CUDA backend for the per-column coset LDE (the `coset_lde_full_expand` hot path) lives in the `math-cuda` crate and is gated behind the `cuda` feature on `stark` / `lambda-vm-prover`. Requires CUDA 13.x with a visible NVIDIA GPU. Covers the Goldilocks base field only; extension-field columns and small LDEs transparently fall back to the CPU path. + +```sh +# Unit tests for the GPU kernels (parity against CPU, sizes up to 2^20): +make test-cuda + +# Full workspace check including the CUDA feature: +make check-cuda + +# `test-fast` with GPU LDE enabled: +make test-fast-cuda +``` + +Behaviour: +- The GPU path fires only when `buffer.len() * blowup_factor >= 2^19` and the column is `FieldElement`. Tune with `LAMBDA_VM_GPU_LDE_THRESHOLD=` at runtime. +- If the `cuda` feature is enabled and CUDA initialisation fails, the process panics with a clear message — there is no transparent fallback to CPU. +- The CPU-only build (default) is bit-for-bit identical to before; the feature is zero overhead when disabled. + +Status: on a single RTX 5090 with ~46 CPU cores and the current kernel set, end-to-end prove time ties the rayon-parallel CPU path on 1M–4M-instruction proofs. Wins on single-column LDE are ~16× at 2^18 sizes but are swallowed by CPU parallelism and per-call kernel launch overhead. Next steps for a real speedup are kernel fusion across NTT levels, CUDA graphs to amortise launch, keeping LDE on device through Merkle, and moving Keccak/constraint evaluation to GPU. + ## Roadmap for the virtual machine This project is under active development. Our primary objective is to have a first working version for the virtual machine. Priorities and features might change as we continue developing. diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml index fdc8eab8c..5113792a8 100644 --- a/bin/cli/Cargo.toml +++ b/bin/cli/Cargo.toml @@ -15,3 +15,4 @@ tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } [features] jemalloc-stats = ["dep:tikv-jemalloc-ctl"] instruments = ["prover/instruments", "stark/instruments"] +cuda = ["prover/cuda"] diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index 55fa49a83..789adf1b6 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -54,6 +54,30 @@ where Self::build_from_hashed_leaves(hashed_leaves) } + /// Build a `MerkleTree` from an already-filled node vector whose layout + /// matches [`build_from_hashed_leaves`] output: + /// + /// - `nodes.len() == 2 * leaves_len - 1` where `leaves_len` is a power of two + /// - `nodes[0]` is the root + /// - `nodes[leaves_len - 1 .. 2*leaves_len - 1]` are the leaves + /// + /// Useful when the tree was constructed elsewhere (e.g. on a GPU) and + /// the caller just wants to hand the finished layout to the stark prover. + /// Performs no hashing. + pub fn from_precomputed_nodes(nodes: Vec) -> Option { + if nodes.is_empty() { + return None; + } + // Validate (cheap) that (nodes.len() + 1) is a power of two: there + // must be `leaves_len - 1 + leaves_len = 2*leaves_len - 1` entries. + let total = nodes.len(); + if !(total + 1).is_power_of_two() { + return None; + } + let root = nodes[ROOT].clone(); + Some(MerkleTree { root, nodes }) + } + /// Create a Merkle tree from pre-hashed leaf nodes. /// /// This skips the `hash_leaves` step, useful when leaves have already been diff --git a/crypto/math-cuda/Cargo.toml b/crypto/math-cuda/Cargo.toml new file mode 100644 index 000000000..8c22d1110 --- /dev/null +++ b/crypto/math-cuda/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "math-cuda" +description = "CUDA-accelerated FFT/NTT for Goldilocks (base field) used by the lambda-vm STARK prover" +version = "0.1.0" +edition = "2024" + +[dependencies] +cudarc = { version = "0.19", default-features = false, features = [ + "driver", + "nvrtc", + "std", + "cuda-12080", + "dynamic-loading", +] } +math = { path = "../math" } +rayon = "1.7" + +[dev-dependencies] +rand = { version = "0.8.5", features = ["std"] } +rand_chacha = "0.3.1" +rayon = "1.7" +sha3 = "0.10.8" diff --git a/crypto/math-cuda/DESIGN_EXP11.md b/crypto/math-cuda/DESIGN_EXP11.md new file mode 100644 index 000000000..c4246098f --- /dev/null +++ b/crypto/math-cuda/DESIGN_EXP11.md @@ -0,0 +1,149 @@ +# Design: device-resident main trace (exp-11) + +Tracking the biggest remaining single win — eliminate redundant +main-trace host→device copies. Not yet implemented; this doc scopes +the work. Matches the pattern of exp-4 (tier-3 analysis), which +shipped as a checkpoint so the plan was preserved across context +windows. + +## Current state + +Fib_1M wall-time breakdown at exp-9 tip (15-trial mean 10.96 s): + +``` +Trace build 2.48 s 21.9% CPU (user-supervised) +Round 1 Phase A ~1.5 s 13% Main commits (LDE+Merkle on GPU) +Round 1 Phase B ~0 s LogUp challenges +Round 1 Pass 1 ~2.0 s 18% Aux-trace build (LogUp GPU, exp-9) +Round 1 Pass 2 ~1.0 s 9% Aux commits (ext3 LDE+Merkle) +Rounds 2–4 ~4.0 s 36% +``` + +Two places currently H2D the same main-trace data per proof: + +1. **Phase A** — `coset_lde_batch_base_into_with_merkle_tree_inner` + copies each column (total ~240 MB/table) from pinned staging into + a device buffer of size `m * lde_size`, then overwrites in place + with the iNTT result. The pre-LDE main trace is on device for + a few microseconds before the iNTT kernel starts. + +2. **Pass 1** — `logup_gpu::try_compute_table_term_columns` calls + `upload_main_cols` which does the exact same H2D again. Total + wall cost per-table is ~20–40 ms on a 32 GB/s PCIe link; exp-9 + serializes them so the total is ~200–300 ms wall on fib_1M. + +Both uploads carry identical bytes (the table's main columns); the +second one is pure waste. + +## The fix in two steps + +### Step 1 — preserve pre-LDE columns in the fused LDE kernel + +Modify `coset_lde_batch_base_into_with_merkle_tree_inner` to +optionally preserve the uploaded trace before iNTT. In the current +code, after line 769 (`memcpy_htod` loop) the first `n` u64s of each +column-slab hold the trace. A device-to-device copy to a fresh +`m*n` buffer just before the iNTT kernel is basically free (VRAM +bandwidth ≈ 1 TB/s; 240 MB copy takes <0.3 ms). + +Signature sketch: + +```rust +pub fn coset_lde_batch_base_into_with_merkle_tree_keep_main( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<(GpuLdeBase, Arc)> +``` + +The returned `DeviceMainCols` owns a `CudaSlice` sized `m * n` +in column-major order — directly what +`logup::logup_pair_term_column_on_device` already expects. + +### Step 2 — thread the handle to aux-build + +`MainTraceCommitResult` already holds an optional `GpuLdeBase` +(`gpu_main` field, line 172 of prover.rs). Add a sibling +`gpu_main_pre_lde: Option>`. Prover's `multi_prove` +already stashes the main LDE handle per-table; reuse the same +lookup pattern for `gpu_main_pre_lde`. + +Aux-build currently receives `&mut TraceTable` + `&[challenges]`. To +reach the per-table handle without changing trait signatures, add a +module-level `RwLock>>` in +`logup_gpu.rs` keyed by `trace as *const _ as usize`. Prover +populates after Phase A completes; aux-build consults; prover +clears after Pass 1. + +```rust +// in logup_gpu.rs +static PRE_LDE_CACHE: RwLock>> = + RwLock::new(HashMap::new()); + +pub fn store_pre_lde_main(trace_ptr: usize, handle: Arc); +pub fn take_pre_lde_main(trace_ptr: usize) -> Option>; +pub fn clear_pre_lde_cache(); +``` + +Inside `try_compute_table_term_columns`, skip `upload_main_cols` if +the cache has a handle for this trace pointer; drop back to the +existing H2D path otherwise (keeps the function correct for tables +that went through the non-GPU Phase A path). + +## Expected win + +- Per-table H2D saved: ~20–40 ms +- Total saved on fib_1M (12 tables × exp-9 serialized): 200–300 ms wall +- Aux-trace-build wall is currently ~2.0 s, so this lands it at ~1.7 s +- Total fib_1M projected: ~10.6 s (vs 10.96 s today) + +At larger sizes the gain scales: +- fib_4M: estimated 600–800 ms saved (same number of tables but more + rows, so each H2D is bigger and takes longer absolutely) + +## Risks / gotchas + +- **CudaSlice Send/Sync.** `DeviceMainCols` must be `Send + Sync` to + live in an `Arc` across rayon threads. cudarc 0.19 documents + `CudaSlice: Send + Sync where T: Send + Sync`, so u64 works. + Verify at the compile-error level, don't trust docs. +- **Cache key stability.** `trace as *const _ as usize` only works + while the `TraceTable` isn't moved. In Pass 1 the trace is behind + `&mut` and never reallocates, so the key is stable — but if anyone + later refactors the aux-build loop to move traces, the cache will + silently miss or (worse) hit a stale entry. Add a debug-assert on + length in `try_compute_table_term_columns` matching the cache's + stored `n`. +- **Cache lifetime.** The cache must be cleared at the start of each + prove so stale handles don't leak into the next proof. Simplest + location: `multi_prove` preamble. Alternative: a drop guard tied + to the outermost prover scope. +- **Phase-A CPU fallback.** When Phase A falls back to the CPU LDE + path (trace below the GPU threshold), no handle is produced and + aux-build correctly falls back to its existing H2D path. No + special-casing required. +- **Memory pressure on 32 GB VRAM.** Each pre-LDE buffer is + `num_cols * n * 8` bytes. For fib_4M's biggest table (MEMW_R × + 3.1M rows × ~30 cols = 750 MB) multiplied by 3 MEMW_R instances = + 2.25 GB. Plus LDE buffers (4× larger), that's ~11 GB — still fits + comfortably on an RTX 5090. If future work increases table count, + consider a drop-when-aux-build-finishes policy rather than + holding through Round 4. + +## Why this ships as a design, not code + +The plumbing touches: +- `crypto/math-cuda/src/lde.rs` (new fused-path variant) +- `crypto/math-cuda/src/logup.rs` (cache accessors) +- `crypto/stark/src/gpu_lde.rs` (wire through the keep variant) +- `crypto/stark/src/prover.rs` (populate cache, clear at prove start) +- `crypto/stark/src/logup_gpu.rs` (consult cache, fall back) + +~600–900 lines. Doable in a focused day, but not within the time +budget of the current session. Checkpointing the plan so the next +pass can execute cleanly. + +Estimated effort: one focused work session plus a parity + bench +run. Expected landing: fib_1M ~10.6 s, fib_4M ~32 s → ~30 s. diff --git a/crypto/math-cuda/NOTES.md b/crypto/math-cuda/NOTES.md new file mode 100644 index 000000000..5866f8d19 --- /dev/null +++ b/crypto/math-cuda/NOTES.md @@ -0,0 +1,341 @@ +# math-cuda — performance notes + +Running log of attempts, analysis, and what's left. Intended to survive +context loss between sessions. Update as you go. + +## Current state (2026-04-21, 5 commits on branch `cuda/batched-ntt`) + +### End-to-end speedup (fused tree + GPU R4 deep + LDE-resident handles + GPU R3 OOD) + +| Program | CPU rayon (46 cores) | CUDA (mean over 15 trials) | Delta | +|---|---|---|---| +| fib_iterative_1M | **18.269 s** | **11.64 s** | **1.57× (36.3% faster)** | +| fib_iterative_4M | | **28.3 s** | | + +Correctness: all 30 math-cuda parity tests + 121 stark cuda tests pass. + +### What's GPU-accelerated now + +| Hook | What it does | Kernel(s) | +|---|---|---| +| Main trace LDE + Merkle commit | Base-field LDE → leaf-hash → **full Merkle tree** on device, retaining the LDE device buffer as a `GpuLdeBase` handle on `LDETraceTable` | `ntt_*_batched` + `keccak256_leaves_base_batched` + `keccak_merkle_level` | +| Aux trace LDE + Merkle commit | Ext3 LDE via 3× base decomposition → ext3 leaf-hash → **full Merkle tree**, retaining the de-interleaved LDE buffer as a `GpuLdeExt3` handle | `ntt_*_batched` + `keccak256_leaves_ext3_batched` + `keccak_merkle_level` | +| R2 composition-parts LDE | `number_of_parts > 2` branch: batched ext3 evaluate-on-coset | `ntt_*_batched` (no iFFT variant) | +| R2 commit_composition_poly | Row-pair ext3 Keccak leaves + pair-hash inner tree | `keccak_comp_poly_leaves_ext3` + `keccak_merkle_level` | +| R4 DEEP-poly LDE | Standard ext3 LDE with uniform 1/N weights | `ntt_*_batched` via ext3 decomposition | +| **R4 deep_composition_poly_evals** | Per trace-size row, sum ~200 ext3 FMAs over all LDE cols + scalars. Reads main+aux LDE **from device handles** (no re-H2D) | `deep_composition_ext3_row` | +| **R3 OOD evaluation** | Per eval point, batched barycentric over all main (base) + aux (ext3) columns. Reads LDE directly from device handles with `row_stride = blowup_factor` (no slab extraction, no H2D) | `barycentric_{base,ext3}_batched_strided` | +| **R4 FRI commit phase** | Fold + pair-hash Merkle tree per layer, evals + twiddles device-resident across log₂(N) layers. Only the root (32 B) D2Hs per layer to feed the transcript; layer evals + tree D2H at the end for query phase | `fri_fold_ext3` + `fri_update_twiddles` + `keccak_fri_leaves_ext3` + `keccak_merkle_level` | +| R2 `extend_half_to_lde` | Dormant — only hit by tiny tables below threshold | Infrastructure in place | + +### Where time still goes (aggregate across rayon threads, 1M-fib, warm) + +| Phase | Aggregate | On GPU? | +|---|---|---| +| R3 OOD evaluation | 5.94 s | ❌ barycentric point-evals in ext3 | +| R2 evaluate (constraint eval) | 5.00 s | ❌ per-AIR constraint logic | +| R2 decompose_and_extend_d2 (FFT) | 3.06 s | ✅ partial (parts LDE) | +| R4 deep_composition_poly_evals | 2.32 s | ❌ ext3 barycentric | +| R2 commit_composition_poly (Merkle) | 1.92 s | ❌ different leaf-pair pattern, not wired | +| R4 fri::commit_phase | 1.58 s | ❌ in-place folding | +| R4 queries & openings | 1.54 s | ❌ per-query Merkle openings | +| R4 interpolate+evaluate_fft | 0.53 s | ✅ (via DEEP-poly LDE) | + +### What would be needed to reach ~2× (~50%) + +1. **Ext3 arithmetic on GPU**. Full `ext3 × ext3` multiplication (currently + we only use `base × ext3` in the NTT butterflies). Required for OOD and + deep-composition barycentric kernels. ~100 lines of CUDA plus parity tests. + **✅ LANDED** — `kernels/ext3.cuh` has add/sub/neg/mul_base/mul with the + `dot3` helper; parity tested in `tests/ext3.rs`. +2. **Barycentric at a point** kernel. O(N) reduction per column, M columns + in parallel. Addresses OOD (5.94 s) + deep-composition (2.32 s). ~8 s of + aggregate work ≈ ~0.5–1 s wall savings with rayon. + **✅ LANDED (unwired)** — `kernels/barycentric.cu` + + `src/barycentric.rs` + parity test `tests/barycentric.rs` all work. The + R3-OOD wiring in `get_trace_evaluations_from_lde` was **reverted** after + benchmarking: in the current prover the CPU is idle during R3 (the GPU + is busy on LDE/Merkle streams), so routing R3 OOD to the GPU only adds + queue contention without freeing wall time — fib_iterative_1M went + 13.09 s → 14.20 s, and fib_iterative_4M went 33.67 s → 36.03 s, both + regressions. The kernels stay here as a building block for future + workloads where the GPU has idle windows during R3 (single-table or + very-large-trace proofs). +3. **R2 constraint evaluation on GPU**. Per-AIR pointwise kernels over the + LDE domain. Biggest engineering lift (each AIR has its own constraint + logic). Could save 5 s aggregate ≈ 0.3–0.5 s wall. +4. **GPU-resident LDE across rounds**. Currently Rounds 2–4 re-read LDE + columns from the host Vecs that Round 1 produced. Keeping the LDE on + device would remove the next H2D cycle. + +None of these are trivial; individually each is hours to a day. Collectively +they'd probably push the 1M-fib proof under 10 s (matching Zisk/Airbender- +class wins). + +### Lesson from the R3-OOD attempt + +Aggregate CPU time (as reported by the `instruments` feature) overstates +the real wall-time cost of a phase whenever rayon already parallelises +it. R3 OOD's 5.94 s "aggregate" number was misleading: on a 46-core box +with ~7 tables running in parallel, rayon reduces that to ≈0.15 s wall, +which is *less than* one H2D round-trip of the 500 MB of column data the +GPU kernel would need. The GPU-resident LDE refactor (item 4 above) is +the unlock here — without it, the CPU barycentric is already close to a +lower bound for this workload. + +### What's on the GPU but unwired (kernels + parity tests only) + +After benchmarking, these optimisations have the kernel built and parity- +tested but are NOT wired into the prover because the measured wall-time +delta was neutral or negative: + +- **Barycentric OOD** (`kernels/barycentric.cu`, `tests/barycentric.rs`): + R3 trace OOD + composition-parts OOD. CPU path is already idle-side + while GPU is busy on LDE streams, so routing R3 to GPU regresses. +- **FRI layer Merkle tree** (`keccak_fri_leaves_ext3` + + `build_fri_layer_tree_from_evals_ext3`, `tests/fri_layer_tree.rs`): + per-layer H2D of the freshly-folded eval slab (pageable Vec) eats the + tree-build savings. Needs fused fold+leaves+tree staying on device + across layers, which requires item 4 below. +- **Standalone GPU Merkle inner-tree builder** + (`build_merkle_tree_on_device`, `tests/merkle_tree.rs`): superseded by + the fused LDE+leaves+tree pipeline which skips the leaf D2H entirely. + The standalone function remains as a building block. + +### Path to a meaningful next win + +The remaining aggregate targets are dominated by CPU work whose wall-time +cost is small (~0.2–0.5 s each) because rayon already parallelises them. +Moving any one of them to GPU pays a per-call H2D that wipes the gain. +The unlock is **LDE GPU-resident across rounds** — keep the main/aux +LDE buffers alive on device after R1 commits, and let R2 constraint +evaluation, R3 OOD, R4 deep-composition, and R4 FRI-fold read them +without re-H2D. + +That refactor lets three currently-unwired pieces flip from net-negative +to net-positive: + - R3 barycentric OOD (kernels exist) + - FRI commit phase (kernels exist) + - R4 deep composition (kernel not yet written; small, pointwise FMA) + +…and enables the big one: **GPU constraint evaluation** via a +device-side expression-tree interpreter over a compile-time-serialised +AST (keeps the CPU constraints as the single source of truth). + +Scope for the LDE-GPU-resident refactor: add an `Option>` +sidecar to `LDETraceTable`, have the R1 fused path populate it, and +gate each consumer's GPU path on its presence. ~300-500 LoC with +careful CPU-fallback preservation. + +### What's on the GPU now + +Four independent hook points in the stark prover, all behind the `cuda` +feature flag. CPU path unchanged when the feature is off. + +| Hook | Call site | Fires per 1M-fib proof | Notes | +|---|---|---|---| +| Main trace LDE (base-field) | `expand_columns_to_lde`, `prover.rs:479` | ~40 cols × few tables | `coset_lde_batch_base_into` | +| Aux trace LDE (ext3, via 3× base decomposition) | `expand_columns_to_lde`, same call site | ~20 cols × few tables | `coset_lde_batch_ext3_into` | +| R2 composition parts LDE (ext3, `number_of_parts > 2` branch) | `round_2_compute_composition_polynomial`, `prover.rs:948` | ~8 (one per big table) | `evaluate_poly_coset_batch_ext3_into` | +| R4 DEEP-poly extension (ext3) | `round_4_compute_and_commit_fri_layers`, `prover.rs:1107` | ~8 | `coset_lde_batch_ext3_into` with uniform `1/N` weights | +| R2 `extend_half_to_lde` (ext3, 2-halves batch) | `decompose_and_extend_d2`, `prover.rs:832` | **0** — only tiny tables hit that branch in current VM | Infrastructure in place but size gate skips it | + +The ext3 path costs no extra CUDA: an NTT over an ext3 column is +componentwise equivalent to three independent base-field NTTs sharing +the same twiddles, because a DIT butterfly's multiplication is `base * +ext3 = componentwise base*u64`. Stark de-interleaves the 3n u64 slab +into 3 base slabs in the pinned staging buffer, runs the existing +`*_batched` kernels over 3M logical columns, and re-interleaves on the +way out. + +### Backend (`device.rs`) + +- CUDA context, pool of 32 streams (round-robin via AtomicUsize). +- Single shared pinned host staging buffer (`cuMemHostAlloc` with + flags=0: portable, non-write-combined). Grown once per process to the + largest LDE seen; serialised by a Mutex per call so concurrent rayon + workers don't step on each other. Per-stream buffers blew up pinned + memory 32× and forced first-call re-alloc on every new table size. +- Twiddle cache per `log_n` (both fwd and inv), populated on a separate + utility stream. +- Event tracking disabled globally (`disable_event_tracking()`) — cudarc + normally creates two events per `CudaSlice` alloc, which serialised + concurrent callers on the driver context lock and added per-alloc cost. + +### Kernels (`kernels/ntt.cu`) + +- `bit_reverse_permute_batched`, `ntt_dit_level_batched`, + `ntt_dit_8_levels_batched` (shmem fusion of first 8 DIT levels), + `pointwise_mul_batched`, `scalar_mul_batched`. +- Parity-tested against CPU up to `log_n = 20` in `tests/lde_batch*.rs` + and `tests/evaluate_coset_ext3.rs`. + +### Microbenches (RTX 5090, 46-core host, blowup=4, warm) + +| Size | CPU rayon | GPU batched | Ratio | +|---|---|---|---| +| 64 cols, log_n=16 (LDE 2^18) | ~75–100 ms | ~15–20 ms | **5–12×** | +| 20 cols, log_n=20 (LDE 2^22, prover-scale) | ~470 ms | ~220 ms | **~2.0–2.3×** | + +## Where the time goes at prover scale (single LDE call, log_n=20, 20 cols) + +Phase timings (enable with `MATH_CUDA_PHASE_TIMING=1`): + +| Phase | Time | +|---|---| +| host pack into pinned (rayon) | ~8 ms | +| device alloc_zeros (async) | ~0.5 ms | +| H2D (pinned → device) | ~9 ms | +| iNTT body (22 levels total) | ~3 ms | +| pointwise + bit-reverse LDE | ~2 ms | +| forward NTT body (22 levels) | ~13 ms | +| D2H (device → pinned) | ~28 ms | +| copy out (pinned → caller Vecs, rayon) | ~65 ms | +| **total** | **~130 ms** | + +**Compute is only ~15% of GPU wall time.** The other 85% is PCIe and +pageable host memcpy / page faults. No amount of kernel optimisation +alone closes this gap. + +## Things tried and their outcomes + +### ✅ Kept + +1. **Fused 8-level DIT kernel** (`ntt_dit_8_levels_batched`): first 8 + butterfly levels in shared memory. 7× reduction in launches for + levels 0–7; ~8× less DRAM traffic there. +2. **Column batching via `gridDim.y = M`**: single kernel launch handles + all columns at a level instead of M separate launches. +3. **Reusable shared pinned staging buffer** (`PinnedStaging` in + `device.rs`): `cuMemHostAlloc` with flags=0 (portable, non-WC). One + allocation grows as needed; locked on call-entry for exclusive use. +4. **Rayon-parallel host pack**: 27 ms → 8 ms at prover scale. +5. **Median-of-10 microbench** for stable measurement. + +### ❌ Tried and reverted + +1. **4-col register tile in fused 8-level kernel (A1).** Clean port of + Zisk's `br_ntt_8_steps` inner loop — 256 threads × 4 columns each in + a 1024-entry shmem tile. Neutral at prover scale (1.81× vs 1.88× + without); regressed small-n microbench (shmem pressure lowered + occupancy). The fused kernel handles only the first 8 of 22 levels at + prover scale, so even a 2× win there is ~2 ms of the ~20 ms compute + budget. +2. **Per-caller-Vec pinning via `cuMemHostRegister`.** Fast when + isolated (~1.7× on 64-col microbench) but the driver serialises pin + calls globally; under rayon-parallel table dispatch in the prover + this turned GPU slower than CPU. +3. **Per-stream pinned staging (32 buffers).** Each slot paid the + ~1 second `cuMemHostAlloc` cost on first large-table use. Replaced + with a single shared staging buffer. +4. **Pre-fault output Vec pages overlapped with D2H.** Saved ~40 ms of + copy-out, but the prefault itself cost ~60 ms on a parallel rayon + sweep (mm_struct rwsem serialisation). Net neutral. +5. **A lot of single-trial microbenches.** CPU rayon time is 20–50% + noisy; needed median-of-10 to stop chasing phantoms. + +## Why we're stuck at ~2× and the 10× ceiling + +Amdahl: at 1M-fib scale only ~20% of proof wall time is LDE, and inside +the LDE call itself only ~15% is GPU compute. The remaining 85% of a +per-call GPU budget is: + +| Cost | Size @ prover scale | Why it's there | +|---|---|---| +| PCIe D2H (pinned) | 28 ms | LDE result has to come back for Merkle | +| Pinned → pageable Vec copy | 65 ms | Caller expects `Vec>` for Round 2-4 cache; fresh-alloc pages fault on first write, fault path serialises on mm_struct rwsem | +| PCIe H2D (pinned) | 9 ms | Input columns from CPU | +| host pack | 8 ms | Pageable trace Vec → pinned staging | + +Other projects don't pay this because they **keep data GPU-resident +across Rounds 1–4**. Zisk (`pil2-stark/src/goldilocks/src/ntt_goldilocks.cu`) +chains trace → NTT → Merkle → constraint eval → FRI on device; +Airbender (`zksync-airbender/gpu_prover/`) uses a 5-stage on-device +pipeline. In both, host transfer is roughly "witness in, proof out", +nothing in between. + +## The 10× path + +Ranked by expected wall-time impact on 1M-fib (CPU baseline ~17 s): + +1. **C1: GPU Keccak256 + LDE stays on GPU through Merkle commit.** + Addresses the 28 ms D2H + 65 ms copy-out. ~4–6 s saved end-to-end. + Needs: (a) Goldilocks-input Keccak256 kernel (no reference in the + repos we explored — Airbender uses Blake2s, Zisk uses Poseidon2), + (b) a batched "commit over GPU-resident columns" kernel that reads + LDE directly from device memory and produces the 32-byte root, (c) + refactoring `commit_columns_bit_reversed` in stark to accept a GPU + handle instead of `&[Vec>]`. Estimated 1-2 days of + focused work. + +2. **B1: keep LDE buffer on GPU across rounds.** Round 2–4 currently + re-read the cached LDE from host memory (populated by Round 1). + Holding it on device instead avoids repeat H2D. Needs: refactoring + `Round1` to hold either a GPU handle OR the host Vecs, plus a + GPU constraint-eval and/or FFT path for Round 2's `extend_half_to_lde` + (`prover.rs:834`). Estimated 2-3 days. + +3. **D: ext3 NTT via component decomposition.** A single ext3 column is + `[a, b, c]` per element; butterflies use a base-field twiddle + multiplication, and `base × ext3` is componentwise. So NTT over M + ext3 columns = NTT over 3M base columns with the same twiddles and + weights. No new kernels needed — just a de-interleave at pack time + and re-interleave at unpack. This unlocks: + - Aux trace LDE (`expand_columns_to_lde` on ext3, 2.9 s aggregate) + - `extend_half_to_lde` (Round 2 decompose, 6.1 s aggregate, biggest + single FFT chunk in the proof). Needs different weights — + `g^(-k) / N` rather than `g^k / N`. Easy. + +4. **A2: warp-shuffle butterflies for stages 0–5.** Saves maybe 3 ms of + compute. Low priority after (1)–(3). + +5. **A3: vectorised `uint2` `__ldg` loads in per-level kernels.** Saves + maybe 5 ms. Low priority. + +## Key files + +- `crypto/math-cuda/kernels/{goldilocks.cuh,ntt.cu,arith.cu}` +- `crypto/math-cuda/src/{device.rs,ntt.rs,lde.rs,lib.rs}` +- `crypto/math-cuda/tests/{goldilocks.rs,ntt.rs,lde.rs,lde_batch.rs,bench_quick.rs}` +- `crypto/stark/src/gpu_lde.rs` — the stark-level dispatch wrapper +- `crypto/stark/src/prover.rs:479` — `expand_columns_to_lde` call site +- `crypto/stark/src/prover.rs:834` — `extend_half_to_lde`, **not yet + GPU-enabled** (Round 2 quotient extension FFTs) +- `crypto/stark/src/prover.rs:368` — `commit_columns_bit_reversed`, the + Merkle commit that C1 would replace + +## References + +- `/workspace/references/pil2-proofman/pil2-stark/src/goldilocks/src/ntt_goldilocks.cu` + — Zisk's NTT, especially `br_ntt_8_steps:674` (4-col register tile pattern) +- `/workspace/references/zksync-airbender/gpu_prover/native/ntt/` + — Airbender's NTT with warp-shuffle butterflies and `uint2` loads +- `/workspace/references/zksync-airbender/gpu_prover/native/blake2s.cu` + — Template for GPU tree hashing (but Blake2s, not Keccak) +- Research summary in earlier session — see conversation history or the + `vast-squishing-crayon` plan file at `/root/.claude/plans/` if it still + exists. + +## Useful commands + +```sh +# Build with GPU feature +cargo check -p stark --features cuda + +# Parity tests +cargo test -p math-cuda + +# Microbenches (median-of-10) +cargo test -p math-cuda --test bench_quick --release bench_lde_batched -- --ignored --nocapture + +# Per-phase timing within a batched call +MATH_CUDA_PHASE_TIMING=1 cargo test -p math-cuda --test bench_quick --release bench_lde_batched_prover_scale -- --ignored --nocapture + +# End-to-end prove bench +cargo test -p lambda-vm-prover --release --test bench_gpu bench_prove_fib_1m -- --ignored --nocapture +cargo test -p lambda-vm-prover --release --features cuda --test bench_gpu bench_prove_fib_1m -- --ignored --nocapture +cargo test -p lambda-vm-prover --release --features instruments,cuda --test bench_gpu bench_prove_fib_1m -- --ignored --nocapture # adds phase breakdown + +# Threshold override +LAMBDA_VM_GPU_LDE_THRESHOLD=$((1<<18)) cargo test ... +``` diff --git a/crypto/math-cuda/NOTES_LOGUP.md b/crypto/math-cuda/NOTES_LOGUP.md new file mode 100644 index 000000000..7853cbd25 --- /dev/null +++ b/crypto/math-cuda/NOTES_LOGUP.md @@ -0,0 +1,83 @@ +# LogUp aux-trace build on GPU — exp-7 checkpoint status + +## What landed + +End-to-end GPU pipeline for `compute_logup_batched_term_column`: + +- `crypto/math-cuda/kernels/logup.cu` + - `logup_pair_fingerprint` / `logup_single_fingerprint` — evaluates a + `BusInteraction`'s fingerprint row-by-row from a bytecode descriptor + supporting every `Packing` variant (Direct, Word2L, Word4L, DWordWL, + DWordHHW, DWordWHH, DWordHL, DWordBL, QuadHL, QuadWL) plus `OP_LINEAR` + for arbitrary linear combinations. + - `logup_pair_term_assembly` / `logup_single_term_assembly` — + evaluates `Multiplicity` (One/Column/Sum/Negated/Diff/Sum3/Linear) + and combines with the inverted fingerprints into the term column. +- `crypto/math-cuda/src/logup.rs` — host-side wrappers + a + `DeviceMainCols` handle so `build_auxiliary_trace` uploads the + main-segment columns once per table instead of once per pair. +- `crypto/stark/src/logup_gpu.rs` — serializer from the native + `BusValue` / `Multiplicity` / `LinearTerm` enums into the shared + `FingerprintOp` / `LinearTerm` / `MultiplicityDesc` wire format, plus + dispatch that turns an entire table's interaction list into committed + + virtual term columns in one H2D of main_cols. + +Coefficient handling: all `i64` / `u64` constants are canonicalized into +`[0, p)` on the Rust side, so the kernel never branches on sign. + +## Parity + +121 stark prove+verify tests pass with `LAMBDA_VM_GPU_LOGUP_THRESHOLD=0` +(forces the GPU path for every table). Verifier is untouched. + +## Perf on fib_iterative_1M (46-core CPU + RTX 5090, 15-trial mean) + +| Path | avg | aux-build wall | +|-----------------------------------------|--------|----------------| +| exp-7 CPU (threshold=MAX, default) | 11.17s | — | +| exp-7 GPU table-batched (threshold=0) | 11.81s | 2.66s | +| exp-7 GPU per-pair (earlier iteration) | 16.06s | 5.09s | + +The per-pair version regressed badly because each pair re-uploaded the +~240 MB main trace. The table-batched version eliminates that redundant +H2D (upload once per table, dispatch all pairs against the shared +device buffer), which recovers 4s. It's still ~640 ms behind the +rayon-parallel CPU path — the 46-core CPU reads main_cols from RAM for +free, while the GPU must pay PCIe for it. + +## Why it isn't a win yet + +- **Nested parallelism → stream contention.** The prover already runs + `build_auxiliary_trace` in parallel across ~12 tables. Each GPU-path + table runs its pair kernels serially on one stream, so we have ~12 + concurrent streams competing for the device. That contention eats + most of the per-table speedup. +- **H2D-dominated for large tables.** For the MEMW_R × 3.1M-row tables + each H2D is ~750 MB — a sizeable fraction of the 70-100 ms budget + per table, before any kernel fires. +- **CPU baseline is genuinely fast.** 46 rayon threads chewing through + fingerprints + batch inverse + term assembly is hard to beat when + the data is already in RAM. + +## Default posture + +Gated off by default via `LAMBDA_VM_GPU_LOGUP_THRESHOLD` (default +`usize::MAX`). Set the env var to `0` (or a `trace_len` threshold) to +force-enable for experiments. CPU-only build and `--features cuda` +without the env var both keep the old rayon path — zero regression. + +## Where to go next + +Plausible paths to turn this into a win: + +1. **Cross-table batching.** Upload main_cols for all tables at once + (or in a few fat batches) and let one stream chew through pairs + without concurrent-stream contention. Requires restructuring the + prover's table-parallel loop. +2. **Fused multi-pair kernel.** One kernel launch per table that walks + all pairs using a batched bytecode layout, so per-pair CPU + orchestration disappears. +3. **Keep the trace resident on device.** If the main LDE already + lives on the GPU (as in the experimental-lde-resident checkpoint), + the H2D vanishes and this path starts winning. That's a bigger + architectural move, not a logup-local tweak. diff --git a/crypto/math-cuda/NOTES_SCALE.md b/crypto/math-cuda/NOTES_SCALE.md new file mode 100644 index 000000000..59b7ec599 --- /dev/null +++ b/crypto/math-cuda/NOTES_SCALE.md @@ -0,0 +1,80 @@ +# Scale + profile snapshot — exp-8 + +Profiling + scale benchmark run on top of `cuda/exp-7-logup-gpu` +(LogUp GPU path opt-in, threshold set to 1M rows). All numbers below +are **mean over 5 trials** (fib_4M is 3 trials) with +`LAMBDA_VM_GPU_LOGUP_THRESHOLD=1048576` exported. Bench binary built +via `cargo test -p lambda-vm-prover --release --features cuda,instruments`. + +## Scale + +| trace size | fib_iterative | GPU mean | Wall ratio | Per-row cost vs 1M | +|---|---|---|---|---| +| 1M rows | fib_iterative_1M | 12.52 s | 1.00× | 1.00 | +| 2M rows | fib_iterative_2M | 20.33 s | 1.62× | 0.81 | +| 4M rows | fib_iterative_4M | 32.30 s | 2.58× | 0.65 | + +Doubling the trace size does **not** double the wall time — fixed +costs (GPU warm-up, kernel-launch overhead, transcript work) amortize. +Going from 1M to 4M is only 2.58× wall for 4× data, i.e. 35% cheaper +per row at the larger size. This is the **GPU-favored regime**: every +optimization that pays for per-table overhead compounds as tables get +bigger, and future work (exp-9 through exp-11) should be benched at +fib_4M in addition to fib_1M. + +## Wall-time breakdown (fib_1M, representative trial @ 12.11 s) + +``` +Trace build 2.39 s 19.7% CPU (user-supervised area) +Round 1 4.79 s 39.6% + Main trace commits 1.50 s GPU LDE + Merkle + expand_columns_to_lde 1.45 s (agg) + commit (Merkle) 0.54 s (agg) + Aux trace build 2.15 s LogUp, GPU when >1M rows + Aux trace commit 1.15 s GPU LDE + Merkle +Rounds 2–4 4.52 s 37.4% mixed GPU/CPU +Other 0.19 s +``` + +## Where GPU vs CPU is at (fib_1M Rounds 2–4 aggregates) + +Aggregates are summed across rayon threads; wall is a fraction of each. + +``` +R2 evaluate 5.24 s agg quotient eval, GPU (partial) +R4 queries & openings 3.88 s agg CPU — ← remaining bottleneck +R2 decompose_and_extend_d2 2.88 s agg GPU LDE on device handles +R3 OOD evaluation 1.77 s agg GPU barycentric +R2 commit_composition_poly 1.74 s agg GPU (R2 commit fuse) +R4 deep_composition_poly_evals 1.39 s agg GPU R4 deep +R4 fri::commit_phase 1.11 s agg GPU (device-resident) +R4 interpolate+evaluate_fft 0.51 s agg small +``` + +## What's actionable + +Ranked by expected wall-time yield on fib_1M: + +1. **Aux trace build (2.15 s wall).** Today's LogUp path is neutral + — 12 tables each run build_auxiliary_trace in rayon, each firing + its own GPU stream. Fix: serialize GPU dispatch so streams don't + contend on H2D / compute. Expected: 500–1000 ms. + Checkpointed as `cuda/exp-9-logup-cross-table`. + +2. **R4 queries & openings (~300 ms wall).** CPU today. pil2-proofman + has this on GPU via `getTreeTracePols` + `genMerkleProof`; kernels + are simple. Requires keeping the main-trace Merkle tree + device-resident past R1. Expected: 200–300 ms. + Checkpointed as `cuda/exp-10-fri-queries-on-gpu`. + +3. **Device-resident main trace (1–2 s wall, architectural).** + Eliminate the per-phase H2D of the main trace by building it + straight into GPU memory (or uploading once post-build). Touches + trace build (previously off-limits; now green-lit). Biggest single + move. Checkpointed as `cuda/exp-11-device-trace`. + +Profiling note: `nsys profile -t cuda,nvtx` on this box adds ≥10× +overhead on this workload (12-trial bench ran >12 min before we +killed it). Stick to `--features instruments` for wall-time +measurements; use `nsys` only on a single-trial run with `--sample=none +--cpuctxsw=none` and accept the slowdown. diff --git a/crypto/math-cuda/PROFILE.md b/crypto/math-cuda/PROFILE.md new file mode 100644 index 000000000..300ee335a --- /dev/null +++ b/crypto/math-cuda/PROFILE.md @@ -0,0 +1,124 @@ +# nsys profile of fib_iterative_1M (2 proves: 1 warmup + 1 measured) + +## TL;DR + +The GPU is **not** the bottleneck. Out of ~12 s wall-clock per proof, +only ~2.6 s is *any* CUDA activity (kernels + memcpy combined). The +remaining ~9.4 s is CPU work that we can't meaningfully shrink +without porting program logic (trace build, aux trace build, +constraint eval, query-phase openings). + +Tile-based NTT layout — the optimisation that was on the tier-2/3 +shortlist — would land at most ~100 ms wall because the NTT is only +243 ms of GPU time and much of that already overlaps with CPU / +other-table compute. + +## CUDA activity breakdown (2 proves worth) + +| Operation | Time (ms) | % CUDA | Invocations | Total MB | +|----------------------------------------|-----------|--------|-------------|----------| +| `[CUDA memcpy Device-to-Host]` | 1275.1 | 49.9 % | 690 | 16336 | +| `[CUDA memcpy Host-to-Device]` | 638.7 | 25.0 % | 1674 | 10311 | +| `ntt_dit_level_batched` | 243.1 | 9.5 % | 1176 | — | +| `barycentric_ext3_batched_strided` | 74.4 | 2.9 % | 28 | — | +| `keccak_merkle_level` | 65.5 | 2.6 % | 3312 | — | +| `bit_reverse_permute_batched` | 56.1 | 2.2 % | 98 | — | +| `keccak256_leaves_ext3_batched` | 53.0 | 2.1 % | 14 | — | +| `keccak256_leaves_base_batched` | 35.1 | 1.4 % | 12 | — | +| `barycentric_base_batched_strided` | 33.8 | 1.3 % | 24 | — | +| `ntt_dit_8_levels_batched` | 25.0 | 1.0 % | 98 | — | +| `keccak_comp_poly_leaves_ext3` | 20.7 | 0.8 % | 14 | — | +| `deep_composition_ext3_row` | 12.3 | 0.5 % | 12 | — | +| `keccak_fri_leaves_ext3` | 8.0 | 0.3 % | 258 | — | +| `[CUDA memset]` | 6.9 | 0.3 % | 134 | — | +| `pointwise_mul_batched` | 6.7 | 0.3 % | 56 | — | +| `fri_fold_ext3` | 1.0 | — | 272 | — | +| `fri_update_twiddles` | 0.3 | — | 258 | — | +| **TOTAL CUDA** | **2555.6**| | | | +| — of which kernel compute | 634.9 | 24.8 % | | | +| — of which memcpy / memset | 1920.7 | 75.2 % | | | + +## What this tells us + +1. **Kernel compute total is 635 ms across 2 proves** (so ~320 ms per + proof). The GPU is not under-utilised — this is what it takes to + do the actual field arithmetic + hashing. + +2. **Memcpy totals ~1.9 s across 2 proves** (~950 ms per proof). Most + of this is overlapped with compute on parallel streams. The + memcpy wall-time contribution is only partially additive. + +3. **16.3 GB of D2H** per 2 proves = ~8 GB per proof. Largest single + D2H is 856 MB (pinned-staging flush for the biggest table). + +4. **1176 invocations of `ntt_dit_level_batched`** — the per-level + non-fused kernel used for levels outside the shared-memory fusion + window. 207 μs average. The 8-level fused kernel fires 98 times. + +5. **Memcpy is 3× the kernel time.** Most of it is D2H of the LDE + back to host (for query-phase openings that happen on CPU). + +## Where the 12 s wall time actually goes + +The instrument dump earlier in the session gave us: + +- Trace build (CPU, program-specific): **~2.4 s wall** +- Aux trace build (CPU, per-AIR): **~2.4 s wall** +- Round 1 LDE + Merkle (GPU-bound): ~1.5 s wall +- Rounds 2–4 (mostly GPU, some CPU): ~4.8 s wall +- Misc CPU prelude / setup / finalize: ~0.9 s wall + +The ~2.6 s of CUDA activity from this profile sits *inside* Rounds +1 + 2–4 — mostly overlapped with CPU work. + +## Implications for the remaining optimisation list + +### Tile-based NTT layout (previously the candidate for tier 3) + +**Reject.** Even a perfect 2× speedup on every NTT kernel would save +(243 + 25 + 56) / 2 = 162 ms of GPU kernel time. Most of that is +hidden behind memcpy / CPU work, so the wall-time saving is well +under 100 ms. A 1700 LoC NTT rewrite for <1 % wall is the wrong +call. + +### GPU Montgomery batch inverse (Blelloch scan) + +**Still viable** at ~50–100 ms wall savings, but confirmed marginal. +Only worth doing if done opportunistically (e.g. as part of a larger +Round 3/4 CPU-prelude port). + +### Reducing D2H traffic + +**Real lever.** 16.3 GB D2H per 2 proves includes data that the CPU +path needs for query-phase openings. But some D2H is redundant: +- LDE D2H for tables/rounds where the device handle was already used +- Full tree D2H when queries only touch log(N) path nodes + +Quantifying this needs per-call tracing; skipped for this session. + +### Constraint eval interpreter (item 5a) + +**Biggest lever remaining.** CPU constraint eval is ~0.5–0.8 s wall. +Moving to GPU needs a per-AIR AST → bytecode serializer + a device +interpreter (pil2-proofman's pattern, ~800+ LoC). Touches constraint +code, which is the reason we flagged the memory rule. + +### Aux trace build / trace build on GPU + +**Biggest two levers overall** (~4.8 s wall combined) but these are +per-AIR / per-VM-executor logic. Multi-day porting work, plus the +risk of diverging from the CPU reference (which remains the +verifier-authoritative path). + +## Conclusion + +The profile confirms what the aggregate instruments measurements +already suggested but more precisely: + +> **GPU-side kernel compute is ~320 ms per proof. Any further +> optimisation confined to the GPU side has a hard ceiling there.** + +The remaining ~9+ seconds of wall time is on the CPU (trace build, +aux trace build, constraint eval, query phase openings). Pushing +past 1.6× on fib_1M requires porting one of those, not further GPU +tuning. diff --git a/crypto/math-cuda/TIER_3_ANALYSIS.md b/crypto/math-cuda/TIER_3_ANALYSIS.md new file mode 100644 index 000000000..8f526d5ff --- /dev/null +++ b/crypto/math-cuda/TIER_3_ANALYSIS.md @@ -0,0 +1,64 @@ +# Tier 3 analysis + +This branch (`cuda/exp-4-tier3`) was opened to pursue the tier-3 +micro-optimisations identified at the end of tier 2, but after analysis +each item turned out to be too small relative to run-to-run variance +(≈ 0.4 s over 15 trials on fib_1M) to land safely. Starting state is +unchanged from the tier 2 end (`cuda/exp-3-tier2`, `2ba3af77`). + +## Items investigated + +### Stream overlap with `cudaEvent` dependencies (item 40) +The existing round-robin stream pool already gives per-table +concurrency. Within a single table, R2 can't usefully start until R1's +transcript root appends, and R3/R4 depend on R2's challenges — the +transcript is the serialisation point, not a stream barrier. Possible +saving: <50 ms wall. Deferred. + +### Warp-level barycentric reduction (item 41) +Current `block_reduce_ext3` uses 3 × 256 u64 shmem + tree reduction +across 256 threads. A warp-shuffle-based approach would cut shmem to +3 × 32 u64 and save a few `__syncthreads` per block. Each barycentric +kernel call is already <5 ms on fib_1M's trace sizes, so the payoff +is well under 20 ms wall. Not shipped. + +### GPU batch inverse for R4 DEEP denoms (item 42) +R4 DEEP computes `num_denoms = n × (1 + num_eval_points) ≈ 1M` ext3 +elements on CPU (sequential `push` loop + `inplace_batch_inverse`). +Tried two approaches: + +1. **Parallel `push` via rayon `par_iter`**: one ext3 subtract per + task is finer-grained than rayon's overhead. Measured neutral to + slightly slower. Reverted. + +2. **Single-thread GPU Montgomery batch inverse**: 2M serial ext3 + muls on a single SM ≈ 20 ms per call. 7 tables running in + parallel on GPU serialise on stream pool → ≈ 140 ms total GPU + busy-time. Today's CPU version runs in ~20–30 ms *wall* thanks to + 7-way rayon parallelism across tables. **Net regression**, not + shipped. + + A proper parallel Blelloch scan over ext3 would flip this + (~5 ms GPU per call), but the implementation is ~300+ LoC with + a delicate ext3-over-blocks primitive — too big for tier 3 + scope. Listed as tier-1 follow-up. + +### Zisk's compact TILE layout for NTT (from item 31) +Their 256×4 tile layout for `batched_steps_blocks_par_dif_noBR_compact` +is a good trick, but we'd need to profile current NTT occupancy with +nsight-compute to know whether we're memory-bound enough to benefit. +Without that profile, re-writing 1700+ LoC of NTT kernels for +unclear gain is speculative. + +## What would actually move the needle from here + +See `NOTES.md`. The only remaining items with ≥0.3 s wall savings +require touching program-specific code (trace build, aux trace build, +constraint eval) or are architectural unlocks (constraint AST → +device bytecode interpreter). All tier-1 scope. + +## Branch outcome + +No code changes land on this branch. Performance stays at tier 2's +1.57× on fib_1M. Leaving `cuda/exp-4-tier3` pinned here so the +investigation is traceable. diff --git a/crypto/math-cuda/ZISK_COMPARISON.md b/crypto/math-cuda/ZISK_COMPARISON.md new file mode 100644 index 000000000..ecaf45820 --- /dev/null +++ b/crypto/math-cuda/ZISK_COMPARISON.md @@ -0,0 +1,113 @@ +# pil2-proofman (Zisk) vs lambda-vm CUDA kernels + +Zisk reuses pil2-proofman's GPU prover, so this is a pil2-proofman +comparison. Same field (Goldilocks + deg-3 ext), same field +representations, similar STARK protocol shape. Kernel listings below +from `/workspace/references/pil2-proofman/pil2-stark/src/**/*.cu` +(HEAD sampled for this report). + +## Kernels they have, mapped to ours + +| Phase | pil2-proofman kernel | Our equivalent | Status | +|---|---|---|---| +| Goldilocks base arith | `gl64_tooling.cu` | `kernels/goldilocks.cuh` | ✓ parity | +| Cubic-ext arith | `goldilocks_cubic_extension.cuh` | `kernels/ext3.cuh` | ✓ parity | +| Base-field NTT | `ntt_goldilocks.cu` | `kernels/ntt.cu` (batched) | ✓ | +| Coset LDE | via NTT + `computeX_kernel` + `buildZHInv_kernel` | `lde::coset_lde_base` / `ext3` | ✓ | +| Hash (keccak) | external — uses rapidsnark path | `kernels/keccak.cu` | ✓ | +| Hash (poseidon2) | `poseidon2_goldilocks.cu` | **not ported** | — (we don't use poseidon2) | +| Merkle tree build | inside `proveQueries_inplace` | `kernels/fri.cu::fri_merkle_tree_*` | ✓ parity | +| LDE → leaf-hash → tree fuse | inline in `starks_gpu.cu` | `kernels/fri.cu::fri_fused_*` | ✓ | +| FRI fold | `fold` (starks_gpu.cu:604) | `kernels/fri.cu::fri_fold_ext3` | ✓ | +| FRI transpose | `transposeFRI` | n/a (different layout) | — | +| FRI proximity expression | `computeFRIExpression` (:1191) | part of our R4 `deep_composition_poly_evals` | ≈ | +| OOD eval (Lagrange) | `fillLEv_2d` + `computeEvals_v2` | `kernels/barycentric.cu` + `deep.cu` | ✓ | +| OOD reduction | `computeEvalsReduction` | barycentric kernel tail | ✓ | +| Constraint / expression evaluator | `computeExpressions_` (unified bytecode) | **partial: LogUp bytecode only (exp-7)** | ✗ | +| Insert trace col → aux_trace buffer | `insertTracePol` | **we D2H and re-allocate each time** | ✗ | +| Query trace extraction | `getTreeTracePols` / `getTreeTracePolsBlocks` | **CPU: `open_deep_composition_poly`** | ✗ | +| Merkle proof generation | `genMerkleProof` (:817) | **CPU: `fri::query_phase`** | ✗ | +| Query-position computation | `moduleQueries` | CPU (fast) | ≈ | +| Zerofier / domain `X` setup | `buildZHInv_kernel` / `computeX_kernel` | CPU, on host path | ≈ | +| Airgroup value reduce | `opAirgroupValue_` | n/a (different architecture) | — | +| Parallel scan | `prescan` / `prescan_correction` | `kernels/inverse.cu` (chunk scan) | ✓ parity | +| Trace unpack | `unpack` | `kernels/lde.cu` via extract | ✓ | +| Poseidon commit (BN128) | `poseidon_bn128.cu` | not used | — | + +Legend: ✓ = we have it, ≈ = near-equivalent, ✗ = gap, — = intentionally +skipped (different architecture / not applicable). + +## The three real gaps + +### 1. Unified expression/constraint evaluator on device + +pil2-proofman compiles every algebraic expression in the AIR — boundary +constraints, transition constraints, bus expressions, the whole lot — +into a single `(ops[], args[])` bytecode that `computeExpressions_` +interprets on GPU. One kernel launch evaluates arbitrarily many +expressions over the domain. All inputs (trace, public inputs, +airgroup values, challenges, evaluations) are pointers into a single +device-side `aux_trace` buffer. + +We have a narrower version of this idea (exp-7: LogUp-only bytecode), +and our constraint evaluation stays on CPU. R2 evaluate aggregate is +~5.2 s, roughly ~250 ms wall — a unified bytecode evaluator would +eliminate both that wall time and the R1 aux-build H2D. + +**Scope:** large. The PIL compiler emits their bytecode at build time; +we'd need a constraint → bytecode pass over our `Constraint` trait, +or a hand-written evaluator per AIR. + +### 2. Query phase on device (`genMerkleProof` + `getTreeTracePols`) + +Our `R4 queries & openings` is 3.88 s aggregate, wall 200–500 ms. +pil2-proofman's `proveQueries_inplace` launches two kernels per tree: +`getTreeTracePolsBlocks` (reads trace columns at query rows) and +`genMerkleProof` (walks the tree to build authentication paths). Both +are trivially parallel across queries. + +**Scope:** small. We already keep the Merkle trees device-resident in +`FriCommitState` and elsewhere, so the authentication-path kernel is a +few hundred lines. The main-trace Merkle tree for the deep-poly +openings would need to stay on device too (currently it D2Hs after +R1). This is **the next obvious win** — ports cleanly, no architectural +rethink. + +### 3. Device-resident trace with in-place writes (`insertTracePol`) + +pil2-proofman keeps the whole trace layout as one contiguous device +buffer (`aux_trace`) with per-column offsets recorded in the AIR +metadata. Operations write into slots of that buffer in-place. We +allocate per-column `Vec>` on host, D2H results, then +re-upload for the next operation. + +This is the same idea as our exp-7 `DeviceMainCols` plus our +`LdeHandle` from experimental-lde-resident, but extended to the aux +trace too. Biggest latent win — eliminates nearly every H2D in Round 1 +aux-build AND Round 2 composition-poly construction. + +**Scope:** large. Touches `TraceTable`, the AIR builder, and Round 2. +This is "task E" in the current plan, tracked on cuda/exp-11. + +## Minor differences + +- **FRI arity.** They default to arity-4 folding; we default to arity-2 + but the stark crate has arity-4 commits landed (`3c03f1e6`). Their + `fold` kernel handles both. Ours (`fri_fold_ext3`) is arity-2 only. + Low priority — arity doesn't change the critical path much. +- **Query batching.** They process queries in `nQueries` parallelism + per tree with a 32×32 thread tile. Whatever we port for query phase + should mirror this layout. +- **`airgroupValues` / `airValues`.** Different architectural concept + (their airgroup = our "table", partly). Not a direct port. + +## Conclusion + +Three gaps, two of them already surfaced in our own planning (E = +device-resident aux-trace, B = unified logup batch). The fresh one is +**query phase on device** — worth a checkpoint on its own. My hunch is +the order by yield/effort should be: + +1. Query phase on device (task new — call it exp-13 or fold into exp-11) +2. Device-resident aux trace (task E / exp-11) +3. Unified expression evaluator (large, only worth it once ①+② land) diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs new file mode 100644 index 000000000..4ae66ef6c --- /dev/null +++ b/crypto/math-cuda/build.rs @@ -0,0 +1,63 @@ +use std::env; +use std::path::PathBuf; +use std::process::Command; + +fn cuda_home() -> PathBuf { + env::var_os("CUDA_HOME") + .or_else(|| env::var_os("CUDA_PATH")) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/usr/local/cuda")) +} + +fn nvcc_path() -> PathBuf { + cuda_home().join("bin").join("nvcc") +} + +fn compile_ptx(src: &str, out_name: &str) { + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); + let src_path = manifest_dir.join("kernels").join(src); + let out_path = out_dir.join(out_name); + + println!("cargo:rerun-if-changed=kernels/{src}"); + println!("cargo:rerun-if-env-changed=CUDA_HOME"); + println!("cargo:rerun-if-env-changed=CUDA_PATH"); + println!("cargo:rerun-if-env-changed=CUDARC_NVCC_ARCH"); + + // Emit PTX for a virtual architecture; the CUDA driver JIT-compiles it for the + // actual GPU at load time, so one PTX works across Ada/Hopper/Blackwell. Override + // with CUDARC_NVCC_ARCH to pin a specific compute capability. + let arch = env::var("CUDARC_NVCC_ARCH").unwrap_or_else(|_| "compute_89".to_string()); + + let status = Command::new(nvcc_path()) + .args([ + "--ptx", + "-O3", + "-std=c++17", + "-arch", + &arch, + "-o", + ]) + .arg(&out_path) + .arg(&src_path) + .status() + .expect("failed to invoke nvcc — is CUDA installed and CUDA_HOME set?"); + + if !status.success() { + panic!("nvcc failed compiling {}", src_path.display()); + } +} + +fn main() { + // Headers are not compiled; emit rerun-if-changed so edits trigger rebuilds. + println!("cargo:rerun-if-changed=kernels/goldilocks.cuh"); + println!("cargo:rerun-if-changed=kernels/ext3.cuh"); + compile_ptx("arith.cu", "arith.ptx"); + compile_ptx("ntt.cu", "ntt.ptx"); + compile_ptx("keccak.cu", "keccak.ptx"); + compile_ptx("barycentric.cu", "barycentric.ptx"); + compile_ptx("deep.cu", "deep.ptx"); + compile_ptx("fri.cu", "fri.ptx"); + compile_ptx("inverse.cu", "inverse.ptx"); + compile_ptx("logup.cu", "logup.ptx"); +} diff --git a/crypto/math-cuda/kernels/arith.cu b/crypto/math-cuda/kernels/arith.cu new file mode 100644 index 000000000..4bee9b8bb --- /dev/null +++ b/crypto/math-cuda/kernels/arith.cu @@ -0,0 +1,83 @@ +// Element-wise Goldilocks kernels used by the Phase-2 parity tests. These mirror +// the CPU reference in `crypto/math/src/field/goldilocks.rs` so raw u64 outputs +// are bit-identical to the CPU path. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; +using goldilocks::neg; + +extern "C" __global__ void vector_add_u64(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = a[tid] + b[tid]; // plain wrapping u64 add — toolchain sanity only. +} + +extern "C" __global__ void gl_add_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = add(a[tid], b[tid]); +} + +extern "C" __global__ void gl_sub_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = sub(a[tid], b[tid]); +} + +extern "C" __global__ void gl_mul_kernel(const uint64_t *a, + const uint64_t *b, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = mul(a[tid], b[tid]); +} + +extern "C" __global__ void gl_neg_kernel(const uint64_t *a, + uint64_t *c, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) c[tid] = neg(a[tid]); +} + +// --------------------------------------------------------------------------- +// Ext3 (Goldilocks cubic extension) test kernels. +// Input/output arrays are interleaved [a_0, b_0, c_0, a_1, b_1, c_1, ...]. +// --------------------------------------------------------------------------- + +extern "C" __global__ void ext3_mul_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::mul(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} + +extern "C" __global__ void ext3_add_kernel(const uint64_t *a_int, + const uint64_t *b_int, + uint64_t *c_int, + uint64_t n) { + uint64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + ext3::Fe3 a = ext3::make(a_int[tid*3 + 0], a_int[tid*3 + 1], a_int[tid*3 + 2]); + ext3::Fe3 b = ext3::make(b_int[tid*3 + 0], b_int[tid*3 + 1], b_int[tid*3 + 2]); + ext3::Fe3 r = ext3::add(a, b); + c_int[tid*3 + 0] = r.a; + c_int[tid*3 + 1] = r.b; + c_int[tid*3 + 2] = r.c; +} diff --git a/crypto/math-cuda/kernels/barycentric.cu b/crypto/math-cuda/kernels/barycentric.cu new file mode 100644 index 000000000..01e20f9a0 --- /dev/null +++ b/crypto/math-cuda/kernels/barycentric.cu @@ -0,0 +1,190 @@ +// Barycentric evaluation of a polynomial (given as evaluations on a coset) at +// a single out-of-domain point. Matches the CPU +// `math::polynomial::interpolate_coset_eval_*_with_g_n_inv` pair. +// +// Per column, the barycentric sum is +// S = Σ_i point_i * eval_i * inv_denom_i +// where `point_i` is a base-field coset point, `eval_i` is the polynomial's +// value at that point (base for main-trace columns, ext3 for aux / composition +// columns), and `inv_denom_i = 1 / (z - point_i)` is an ext3 scalar (same for +// every column sharing the evaluation point `z`). +// +// These kernels compute only S. The caller multiplies by the ext3 scalar +// `vanishing * n_inv * g_n_inv` once per column on the host — cheap, and +// keeping it out of the kernel means we don't need to carry yet another +// ext3 constant argument. +// +// Launch: grid = (num_cols, 1, 1), block = (BARY_BLOCK_DIM, 1, 1). + +#include "goldilocks.cuh" +#include "ext3.cuh" + +// 256 threads/block — one ext3 accumulator per thread in shmem ⇒ 6 KiB. +#define BARY_BLOCK_DIM 256 + +__device__ __forceinline__ ext3::Fe3 block_reduce_ext3(ext3::Fe3 my) { + __shared__ uint64_t shm_a[BARY_BLOCK_DIM]; + __shared__ uint64_t shm_b[BARY_BLOCK_DIM]; + __shared__ uint64_t shm_c[BARY_BLOCK_DIM]; + uint32_t tid = threadIdx.x; + shm_a[tid] = my.a; + shm_b[tid] = my.b; + shm_c[tid] = my.c; + __syncthreads(); + for (uint32_t s = BARY_BLOCK_DIM / 2; s > 0; s >>= 1) { + if (tid < s) { + shm_a[tid] = goldilocks::add(shm_a[tid], shm_a[tid + s]); + shm_b[tid] = goldilocks::add(shm_b[tid], shm_b[tid + s]); + shm_c[tid] = goldilocks::add(shm_c[tid], shm_c[tid + s]); + } + __syncthreads(); + } + return ext3::make(shm_a[0], shm_b[0], shm_c[0]); +} + +/// Base-column variant: M base-field columns, each `col_stride` u64 apart. +/// `inv_denoms` is a flat 3N u64 buffer (ext3, interleaved `[a0,b0,c0,...]`). +extern "C" __global__ void barycentric_base_batched( + const uint64_t *columns, + uint64_t col_stride, + const uint64_t *coset_points, + const uint64_t *inv_denoms, + uint64_t n, + uint64_t *out_ext3_int // 3M u64, interleaved per column +) { + uint64_t col = blockIdx.x; + const uint64_t *col_data = columns + col * col_stride; + + ext3::Fe3 acc = ext3::zero(); + for (uint64_t i = threadIdx.x; i < n; i += BARY_BLOCK_DIM) { + uint64_t eval = col_data[i]; + uint64_t point = coset_points[i]; + uint64_t pe = goldilocks::mul(point, eval); // F × F → F + ext3::Fe3 inv_d = ext3::make( + inv_denoms[i * 3 + 0], + inv_denoms[i * 3 + 1], + inv_denoms[i * 3 + 2]); + ext3::Fe3 term = ext3::mul_base(inv_d, pe); // E × F → E + acc = ext3::add(acc, term); + } + + ext3::Fe3 sum = block_reduce_ext3(acc); + if (threadIdx.x == 0) { + out_ext3_int[col * 3 + 0] = sum.a; + out_ext3_int[col * 3 + 1] = sum.b; + out_ext3_int[col * 3 + 2] = sum.c; + } +} + +/// Same as `barycentric_base_batched` but reads rows at stride `row_stride` +/// within each column — i.e. treats the column as an LDE of length +/// `n * row_stride` and sums over the trace-size coset (every `row_stride`-th +/// row). Lets R3 OOD run directly against the LDE device handle from R1 +/// without materialising a trace-size slab. +extern "C" __global__ void barycentric_base_batched_strided( + const uint64_t *columns, + uint64_t col_stride, + uint64_t row_stride, + const uint64_t *coset_points, + const uint64_t *inv_denoms, + uint64_t n, + uint64_t *out_ext3_int +) { + uint64_t col = blockIdx.x; + const uint64_t *col_data = columns + col * col_stride; + + ext3::Fe3 acc = ext3::zero(); + for (uint64_t i = threadIdx.x; i < n; i += BARY_BLOCK_DIM) { + uint64_t eval = col_data[i * row_stride]; + uint64_t point = coset_points[i]; + uint64_t pe = goldilocks::mul(point, eval); + ext3::Fe3 inv_d = ext3::make( + inv_denoms[i * 3 + 0], + inv_denoms[i * 3 + 1], + inv_denoms[i * 3 + 2]); + ext3::Fe3 term = ext3::mul_base(inv_d, pe); + acc = ext3::add(acc, term); + } + + ext3::Fe3 sum = block_reduce_ext3(acc); + if (threadIdx.x == 0) { + out_ext3_int[col * 3 + 0] = sum.a; + out_ext3_int[col * 3 + 1] = sum.b; + out_ext3_int[col * 3 + 2] = sum.c; + } +} + +/// Ext3-column variant: M ext3 columns stored as 3M base slabs. Column `c` +/// lives at `columns[(c*3+k)*col_stride + i]` for component `k ∈ 0..3`. +extern "C" __global__ void barycentric_ext3_batched( + const uint64_t *columns, + uint64_t col_stride, + const uint64_t *coset_points, + const uint64_t *inv_denoms, + uint64_t n, + uint64_t *out_ext3_int +) { + uint64_t col = blockIdx.x; + const uint64_t *slab_a = columns + (col * 3 + 0) * col_stride; + const uint64_t *slab_b = columns + (col * 3 + 1) * col_stride; + const uint64_t *slab_c = columns + (col * 3 + 2) * col_stride; + + ext3::Fe3 acc = ext3::zero(); + for (uint64_t i = threadIdx.x; i < n; i += BARY_BLOCK_DIM) { + ext3::Fe3 eval = ext3::make(slab_a[i], slab_b[i], slab_c[i]); + uint64_t point = coset_points[i]; + // F × E → E (point times eval, componentwise on the 3 base components) + ext3::Fe3 pe = ext3::mul_base(eval, point); + // E × E → E + ext3::Fe3 inv_d = ext3::make( + inv_denoms[i * 3 + 0], + inv_denoms[i * 3 + 1], + inv_denoms[i * 3 + 2]); + ext3::Fe3 term = ext3::mul(pe, inv_d); + acc = ext3::add(acc, term); + } + + ext3::Fe3 sum = block_reduce_ext3(acc); + if (threadIdx.x == 0) { + out_ext3_int[col * 3 + 0] = sum.a; + out_ext3_int[col * 3 + 1] = sum.b; + out_ext3_int[col * 3 + 2] = sum.c; + } +} + +/// Strided ext3 variant for R3 OOD of aux LDE. +extern "C" __global__ void barycentric_ext3_batched_strided( + const uint64_t *columns, + uint64_t col_stride, + uint64_t row_stride, + const uint64_t *coset_points, + const uint64_t *inv_denoms, + uint64_t n, + uint64_t *out_ext3_int +) { + uint64_t col = blockIdx.x; + const uint64_t *slab_a = columns + (col * 3 + 0) * col_stride; + const uint64_t *slab_b = columns + (col * 3 + 1) * col_stride; + const uint64_t *slab_c = columns + (col * 3 + 2) * col_stride; + + ext3::Fe3 acc = ext3::zero(); + for (uint64_t i = threadIdx.x; i < n; i += BARY_BLOCK_DIM) { + uint64_t lde_i = i * row_stride; + ext3::Fe3 eval = ext3::make(slab_a[lde_i], slab_b[lde_i], slab_c[lde_i]); + uint64_t point = coset_points[i]; + ext3::Fe3 pe = ext3::mul_base(eval, point); + ext3::Fe3 inv_d = ext3::make( + inv_denoms[i * 3 + 0], + inv_denoms[i * 3 + 1], + inv_denoms[i * 3 + 2]); + ext3::Fe3 term = ext3::mul(pe, inv_d); + acc = ext3::add(acc, term); + } + + ext3::Fe3 sum = block_reduce_ext3(acc); + if (threadIdx.x == 0) { + out_ext3_int[col * 3 + 0] = sum.a; + out_ext3_int[col * 3 + 1] = sum.b; + out_ext3_int[col * 3 + 2] = sum.c; + } +} diff --git a/crypto/math-cuda/kernels/deep.cu b/crypto/math-cuda/kernels/deep.cu new file mode 100644 index 000000000..b723d17bf --- /dev/null +++ b/crypto/math-cuda/kernels/deep.cu @@ -0,0 +1,117 @@ +// R4 deep composition polynomial evaluations. +// +// For each trace-size row i in 0..domain_size, accumulate: +// result_i = Σ_j γ_j · (H_j(x_i) − H_j(z^K)) · inv_h[i] (H terms) +// + Σ_j Σ_k γ'_{j,k} · (t_j(x_i) − t_j(z·w^k)) · inv_t[k,i] (trace) +// +// where x_i = LDE coset point at stride `blowup_factor` (so the kernel +// reads LDE column data at `i * blowup_factor`). `j` ranges over +// num_parts for H-terms and num_total_cols (= num_main + num_aux) for +// trace terms. `k` ranges over num_eval_points. +// +// Buffer layouts (ALL on device): +// main_lde base, row-major per column: main_lde[c * lde_stride + r] +// aux_lde ext3 de-interleaved: aux_lde[(c*3 + k) * lde_stride + r] +// h_lde ext3 de-interleaved: h_lde[(p*3 + k) * lde_stride + r] +// h_ood num_parts * 3 (ext3 interleaved) +// trace_ood num_total_cols * num_eval_points * 3 (ext3 interleaved, +// indexed as (col_idx * num_eval_points + k) * 3 + comp) +// gammas_h num_parts * 3 +// gammas_tr num_total_cols * num_eval_points * 3 +// inv_h domain_size * 3 +// inv_t num_eval_points * domain_size * 3 +// deep_out domain_size * 3 (ext3 interleaved; caller reinterprets) + +#include "goldilocks.cuh" +#include "ext3.cuh" + +extern "C" __global__ void deep_composition_ext3_row( + const uint64_t *main_lde, + const uint64_t *aux_lde, + const uint64_t *h_lde, + uint64_t lde_stride, + uint64_t num_main, + uint64_t num_aux, + uint64_t num_parts, + uint64_t num_eval_points, + uint64_t blowup_factor, + uint64_t domain_size, + const uint64_t *h_ood, + const uint64_t *trace_ood, + const uint64_t *gammas_h, + const uint64_t *gammas_tr, + const uint64_t *inv_h, + const uint64_t *inv_t, + uint64_t *deep_out) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i >= domain_size) return; + uint64_t row = i * blowup_factor; + + ext3::Fe3 result = ext3::zero(); + ext3::Fe3 inv_h_i = {inv_h[i * 3], inv_h[i * 3 + 1], inv_h[i * 3 + 2]}; + + // H-terms + for (uint64_t j = 0; j < num_parts; ++j) { + ext3::Fe3 h_val = { + h_lde[(j * 3 + 0) * lde_stride + row], + h_lde[(j * 3 + 1) * lde_stride + row], + h_lde[(j * 3 + 2) * lde_stride + row], + }; + ext3::Fe3 h_ood_j = {h_ood[j * 3], h_ood[j * 3 + 1], h_ood[j * 3 + 2]}; + ext3::Fe3 num = ext3::sub(h_val, h_ood_j); + ext3::Fe3 gamma = {gammas_h[j * 3], gammas_h[j * 3 + 1], gammas_h[j * 3 + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_h_i); + result = ext3::add(result, tmp); + } + + uint64_t num_total_cols = num_main + num_aux; + + // Main trace terms (base column - ext3 OOD) + for (uint64_t j = 0; j < num_main; ++j) { + uint64_t t_val = main_lde[j * lde_stride + row]; + for (uint64_t k = 0; k < num_eval_points; ++k) { + uint64_t idx = (j * num_eval_points + k) * 3; + ext3::Fe3 t_ood = {trace_ood[idx], trace_ood[idx + 1], trace_ood[idx + 2]}; + ext3::Fe3 num = { + goldilocks::sub(t_val, t_ood.a), + goldilocks::neg(t_ood.b), + goldilocks::neg(t_ood.c), + }; + ext3::Fe3 gamma = {gammas_tr[idx], gammas_tr[idx + 1], gammas_tr[idx + 2]}; + uint64_t inv_t_idx = (k * domain_size + i) * 3; + ext3::Fe3 inv_t_ki = {inv_t[inv_t_idx], inv_t[inv_t_idx + 1], inv_t[inv_t_idx + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_t_ki); + result = ext3::add(result, tmp); + } + } + + // Aux trace terms (ext3 column - ext3 OOD) + for (uint64_t j = 0; j < num_aux; ++j) { + ext3::Fe3 t_val = { + aux_lde[(j * 3 + 0) * lde_stride + row], + aux_lde[(j * 3 + 1) * lde_stride + row], + aux_lde[(j * 3 + 2) * lde_stride + row], + }; + uint64_t trace_j = num_main + j; + for (uint64_t k = 0; k < num_eval_points; ++k) { + uint64_t idx = (trace_j * num_eval_points + k) * 3; + ext3::Fe3 t_ood = {trace_ood[idx], trace_ood[idx + 1], trace_ood[idx + 2]}; + ext3::Fe3 num = ext3::sub(t_val, t_ood); + ext3::Fe3 gamma = {gammas_tr[idx], gammas_tr[idx + 1], gammas_tr[idx + 2]}; + uint64_t inv_t_idx = (k * domain_size + i) * 3; + ext3::Fe3 inv_t_ki = {inv_t[inv_t_idx], inv_t[inv_t_idx + 1], inv_t[inv_t_idx + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_t_ki); + result = ext3::add(result, tmp); + } + } + + uint64_t out_idx = i * 3; + deep_out[out_idx + 0] = result.a; + deep_out[out_idx + 1] = result.b; + deep_out[out_idx + 2] = result.c; + // Suppress unused param warning when num_total_cols not referenced. + (void)num_total_cols; +} diff --git a/crypto/math-cuda/kernels/ext3.cuh b/crypto/math-cuda/kernels/ext3.cuh new file mode 100644 index 000000000..2f4040714 --- /dev/null +++ b/crypto/math-cuda/kernels/ext3.cuh @@ -0,0 +1,121 @@ +// Goldilocks cubic extension on device: Fp3 = Fp[w] / (w^3 - 2) +// where Fp is Goldilocks (2^64 - 2^32 + 1). +// +// Layout matches the CPU `Degree3GoldilocksExtensionField` (see +// `crypto/math/src/field/extensions_goldilocks.rs`): an element is a +// 3-tuple `(a, b, c)` representing `a + b*w + c*w^2`. +// +// The reducible `w^3 = 2` means cross-term products get a factor of 2: +// (a0 + a1*w + a2*w^2) * (b0 + b1*w + b2*w^2) +// = (a0*b0 + 2*(a1*b2 + a2*b1)) +// + (a0*b1 + a1*b0 + 2*a2*b2) * w +// + (a0*b2 + a1*b1 + a2*b0) * w^2 +// +// We use the same dot-product-of-three folding as the CPU (which saves +// reductions by summing u128 products before `reduce128`). CUDA has +// `__umul64hi` so we implement `dot_product_3` inline. + +#pragma once +#include "goldilocks.cuh" + +namespace ext3 { + +struct Fe3 { + uint64_t a, b, c; +}; + +__device__ __forceinline__ Fe3 make(uint64_t a, uint64_t b, uint64_t c) { + Fe3 r = {a, b, c}; + return r; +} + +__device__ __forceinline__ Fe3 zero() { return make(0, 0, 0); } +__device__ __forceinline__ Fe3 one() { return make(1, 0, 0); } + +__device__ __forceinline__ Fe3 add(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::add(x.a, y.a), + goldilocks::add(x.b, y.b), + goldilocks::add(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 sub(const Fe3 &x, const Fe3 &y) { + return make(goldilocks::sub(x.a, y.a), + goldilocks::sub(x.b, y.b), + goldilocks::sub(x.c, y.c)); +} + +__device__ __forceinline__ Fe3 neg(const Fe3 &x) { + return make(goldilocks::neg(x.a), + goldilocks::neg(x.b), + goldilocks::neg(x.c)); +} + +/// Mixed: base * ext3 → ext3 (componentwise). +__device__ __forceinline__ Fe3 mul_base(const Fe3 &x, uint64_t s) { + return make(goldilocks::mul(x.a, s), + goldilocks::mul(x.b, s), + goldilocks::mul(x.c, s)); +} + +/// Dot-product of three (a0*b0 + a1*b1 + a2*b2) mod p, with one reduce128 +/// on the sum of three u128 products. Matches CPU `dot_product_3`. +__device__ __forceinline__ uint64_t dot3(uint64_t a0, uint64_t b0, + uint64_t a1, uint64_t b1, + uint64_t a2, uint64_t b2) { + // Split the sum of three u128 products into hi/lo u128 halves, then + // reduce once. We track overflow-count (at most 2) and add EPSILON^2 + // per overflow, matching the CPU path. + // prod_i = a_i * b_i (u128) + uint64_t lo0 = a0 * b0, hi0 = __umul64hi(a0, b0); + uint64_t lo1 = a1 * b1, hi1 = __umul64hi(a1, b1); + uint64_t lo2 = a2 * b2, hi2 = __umul64hi(a2, b2); + + // sum01 = prod0 + prod1 (in u128 lanes) + uint64_t s01_lo = lo0 + lo1; + uint64_t carry01 = (s01_lo < lo0) ? 1ULL : 0ULL; + uint64_t s01_hi = hi0 + hi1 + carry01; + uint32_t over1 = (s01_hi < hi0 + carry01) ? 1u : 0u; // low-pass overflow + + // sum012 = sum01 + prod2 + uint64_t s012_lo = s01_lo + lo2; + uint64_t carry012 = (s012_lo < s01_lo) ? 1ULL : 0ULL; + uint64_t s012_hi = s01_hi + hi2 + carry012; + uint32_t over2 = (s012_hi < hi2 + carry012) ? 1u : 0u; + + uint64_t reduced = goldilocks::reduce128(s012_lo, s012_hi); + + uint32_t overflow_count = over1 + over2; + if (overflow_count > 0) { + // 2^128 mod p = EPSILON^2 (= (2^32 - 1)^2). + uint64_t eps = goldilocks::EPSILON; + uint64_t eps_sq = eps * eps; + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + if (overflow_count > 1) { + reduced = goldilocks::add_no_canonicalize(reduced, eps_sq); + } + } + return reduced; +} + +/// Full ext3 × ext3 multiplication (matches CPU +/// `Degree3GoldilocksExtensionField::mul`). +__device__ __forceinline__ Fe3 mul(const Fe3 &x, const Fe3 &y) { + // c0 = x.a*y.a + x.b*(2*y.c) + x.c*(2*y.b) + // c1 = x.a*y.b + x.b*y.a + x.c*(2*y.c) + // c2 = x.a*y.c + x.b*y.b + x.c*y.a + uint64_t b1_2 = goldilocks::add(y.b, y.b); + uint64_t b2_2 = goldilocks::add(y.c, y.c); + + uint64_t c0 = dot3(x.a, y.a, x.b, b2_2, x.c, b1_2); + uint64_t c1 = dot3(x.a, y.b, x.b, y.a, x.c, b2_2); + uint64_t c2 = dot3(x.a, y.c, x.b, y.b, x.c, y.a); + return make(c0, c1, c2); +} + +__device__ __forceinline__ Fe3 canonical(const Fe3 &x) { + return make(goldilocks::canonical(x.a), + goldilocks::canonical(x.b), + goldilocks::canonical(x.c)); +} + +} // namespace ext3 diff --git a/crypto/math-cuda/kernels/fri.cu b/crypto/math-cuda/kernels/fri.cu new file mode 100644 index 000000000..2307711cf --- /dev/null +++ b/crypto/math-cuda/kernels/fri.cu @@ -0,0 +1,59 @@ +// R4 FRI fold + twiddle-update kernels on device. The host orchestrator +// loops log₂(N) times: sample zeta on host → fold on device → keccak leaves +// + tree on device → D2H the root → transcript-append on host → update +// twiddles on device. +// +// Layout: ext3 evaluations are stored INTERLEAVED as +// `[a0,b0,c0, a1,b1,c1, ...]` — same layout the deep-poly LDE output +// already produces. Twiddles are base-field, one u64 per entry. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +// fold_evaluations_in_place: +// out[j] = (lo + hi) + inv_tw[j] * zeta * (lo - hi) +// where lo = evals[2j], hi = evals[2j+1]. Both lo/hi and zeta are ext3. +// inv_tw[j] is a base-field twiddle (F × E → E). +// +// Writes N/2 ext3 outputs (3 * n_out u64 total) into `out`. `in` is the +// previous layer of 2 * n_out ext3 values (6 * n_out u64 total). +extern "C" __global__ void fri_fold_ext3( + const uint64_t *in, // 3 * 2*n_out u64 (ext3 interleaved) + uint64_t n_out, // number of output ext3 elements (= N/2) + const uint64_t *inv_tw, // n_out base-field twiddles + const uint64_t *zeta, // 3 u64 (ext3) + uint64_t *out) { // 3 * n_out u64 (ext3 interleaved) + uint64_t j = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (j >= n_out) return; + + const uint64_t *lo_p = in + 2 * j * 3; + const uint64_t *hi_p = lo_p + 3; + + ext3::Fe3 lo = ext3::make(lo_p[0], lo_p[1], lo_p[2]); + ext3::Fe3 hi = ext3::make(hi_p[0], hi_p[1], hi_p[2]); + ext3::Fe3 sum = ext3::add(lo, hi); + ext3::Fe3 diff = ext3::sub(lo, hi); + + ext3::Fe3 z = ext3::make(zeta[0], zeta[1], zeta[2]); + ext3::Fe3 zd = ext3::mul(z, diff); // ext3 × ext3 = ext3 + uint64_t tw = inv_tw[j]; + ext3::Fe3 tzd = ext3::mul_base(zd, tw); // base × ext3 = ext3 (componentwise) + ext3::Fe3 res = ext3::add(sum, tzd); + + uint64_t *out_p = out + j * 3; + out_p[0] = res.a; + out_p[1] = res.b; + out_p[2] = res.c; +} + +// update_twiddles_in_place: new[j] = old[2j]². Writes in-place — caller +// must ensure the kernel is not reading the same index concurrently. Since +// we read `old[2j]` and write `new[j]` with j < 2j, there's no aliasing. +extern "C" __global__ void fri_update_twiddles( + uint64_t *tw, + uint64_t n_out) { + uint64_t j = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (j >= n_out) return; + uint64_t old = tw[2 * j]; + tw[j] = goldilocks::mul(old, old); +} diff --git a/crypto/math-cuda/kernels/goldilocks.cuh b/crypto/math-cuda/kernels/goldilocks.cuh new file mode 100644 index 000000000..5e296a390 --- /dev/null +++ b/crypto/math-cuda/kernels/goldilocks.cuh @@ -0,0 +1,69 @@ +// Goldilocks field on device. Ports `crypto/math/src/field/goldilocks.rs` one-to-one: +// - Representation: non-canonical u64 in [0, 2^64). Canonicalise only at boundaries. +// - Prime: 2^64 - 2^32 + 1. +// - Reduction: exploits 2^64 ≡ EPSILON (mod p) and 2^96 ≡ -1 (mod p). +// +// The arithmetic here must produce bit-identical u64 outputs to the CPU path so +// LDE parity tests can assert raw equality. + +#pragma once +#include + +namespace goldilocks { + +__device__ constexpr uint64_t PRIME = 0xFFFFFFFF00000001ULL; +__device__ constexpr uint64_t EPSILON = 0xFFFFFFFFULL; // 2^32 - 1 + +__device__ __forceinline__ uint64_t add_no_canonicalize(uint64_t x, uint64_t y) { + // Mirror of `add_no_canonicalize_trashing_input`: one add, one EPSILON bump on carry. + uint64_t sum = x + y; + return sum + (sum < x ? EPSILON : 0ULL); +} + +__device__ __forceinline__ uint64_t add(uint64_t a, uint64_t b) { + uint64_t sum = a + b; + uint64_t over1 = (sum < a) ? EPSILON : 0ULL; + uint64_t sum2 = sum + over1; + uint64_t over2 = (sum2 < sum) ? EPSILON : 0ULL; + return sum2 + over2; +} + +__device__ __forceinline__ uint64_t sub(uint64_t a, uint64_t b) { + uint64_t diff = a - b; + uint64_t under1 = (a < b) ? EPSILON : 0ULL; + uint64_t diff2 = diff - under1; + uint64_t under2 = (diff2 > diff) ? EPSILON : 0ULL; + return diff2 - under2; +} + +__device__ __forceinline__ uint64_t reduce128(uint64_t lo, uint64_t hi) { + uint64_t x_hi_hi = hi >> 32; + uint64_t x_hi_lo = hi & EPSILON; + + // 2^96 ≡ -1 (mod p): subtract x_hi_hi from lo, EPSILON-correct on borrow. + uint64_t t0 = lo - x_hi_hi; + if (lo < x_hi_hi) t0 -= EPSILON; + + // 2^64 ≡ EPSILON (mod p): x_hi_lo * EPSILON = (x_hi_lo << 32) - x_hi_lo. + uint64_t t1 = (x_hi_lo << 32) - x_hi_lo; + + return add_no_canonicalize(t0, t1); +} + +__device__ __forceinline__ uint64_t mul(uint64_t a, uint64_t b) { + uint64_t lo = a * b; + uint64_t hi = __umul64hi(a, b); + return reduce128(lo, hi); +} + +__device__ __forceinline__ uint64_t neg(uint64_t a) { + // `a` may be non-canonical. Canonicalise first, then p - a (or 0). + uint64_t canon = (a >= PRIME) ? (a - PRIME) : a; + return canon == 0 ? 0 : (PRIME - canon); +} + +__device__ __forceinline__ uint64_t canonical(uint64_t a) { + return (a >= PRIME) ? (a - PRIME) : a; +} + +} // namespace goldilocks diff --git a/crypto/math-cuda/kernels/inverse.cu b/crypto/math-cuda/kernels/inverse.cu new file mode 100644 index 000000000..65d04c5d3 --- /dev/null +++ b/crypto/math-cuda/kernels/inverse.cu @@ -0,0 +1,296 @@ +// Parallel Montgomery batch inverse over ext3, plus a compute-denoms +// helper for R3 OOD / R4 DEEP preludes. +// +// Batch inverse strategy (chunk-based parallel scan): +// +// 1. Chunk-local forward scan: each thread serially computes the +// prefix product of its chunk of `C = ceil(N / K)` ext3 values; +// writes the chunk output in place and posts its chunk total to +// `chunk_totals[thread_id]`. +// 2. Single-block scan of `chunk_totals` (K ≤ 1024 for our shapes, +// fits one block). +// 3. Chunk-local apply: each thread multiplies its chunk's local +// prefix by the exclusive-scan offset from step 2, producing the +// global forward prefix. +// 4. Mirror (1-3) in reverse for the suffix. +// 5. Single-thread kernel inverts total = prefix[N-1]. +// 6. Pointwise combine: `inv[i] = prefix[i-1] * suffix[i+1] * inv_total` +// (with prefix[-1] = suffix[N] = 1). One thread per element. +// +// Ext3 multiply is commutative in the field (it's a field, not just a +// ring), so prefix-product scans are well-defined. Layout is ext3 +// INTERLEAVED: one u64 triple per element, 3*N u64s total. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +#define INV_BLOCK 256 + +// --------------------------------------------------------------------------- +// B.1: compute denoms for R4 DEEP and R3 OOD. +// +// denoms[k*n + i] = x[i * stride] - z[k] +// where `x` is a base-field coset (read at stride `stride`), `z` is an +// ext3 array of `k_scalars` entries (z^K and/or z·ω^k), and `n` is the +// trace-size count. Output is flat ext3 interleaved. +// --------------------------------------------------------------------------- +extern "C" __global__ void compute_denoms_ext3( + const uint64_t *x_base, // base-field LDE coset points + uint64_t stride, // read stride (blowup_factor for R4) + const uint64_t *z_scalars, // k_scalars * 3 u64 (ext3 interleaved) + uint64_t k_scalars, + uint64_t n, + uint64_t *denoms_out) { // k_scalars * n * 3 u64 + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t total = k_scalars * n; + if (tid >= total) return; + + uint64_t k = tid / n; + uint64_t i = tid - k * n; + + uint64_t x_i = x_base[i * stride]; + uint64_t z_a = z_scalars[k * 3 + 0]; + uint64_t z_b = z_scalars[k * 3 + 1]; + uint64_t z_c = z_scalars[k * 3 + 2]; + + // base - ext3 = ext3 ( (x_i - z_a), -z_b, -z_c ) + uint64_t out_a = goldilocks::sub(x_i, z_a); + uint64_t out_b = goldilocks::neg(z_b); + uint64_t out_c = goldilocks::neg(z_c); + + uint64_t out_idx = tid * 3; + denoms_out[out_idx + 0] = out_a; + denoms_out[out_idx + 1] = out_b; + denoms_out[out_idx + 2] = out_c; +} + +// --------------------------------------------------------------------------- +// B.2 chunk-scan primitives for batch inverse. +// +// `a_in` is the input array of N ext3 elements (3*N u64, interleaved). +// `prefix_out` receives prefix[i] = prod(a[0..=i]) for all i. +// `chunk_totals` receives the per-chunk total (one ext3 per chunk). +// +// Each thread owns a contiguous chunk of C elements. With K=256 threads +// per block and a single block, we can handle up to 256*C elements. +// For N up to ~1M, C ≈ 4096, so one thread does ~4k ext3 multiplies +// serially in shmem-free fashion. Depth = O(C) + O(K) + O(C); with +// K=256 threads running in parallel, the `O(C)` phases parallelise +// perfectly across threads. +// +// For cleanliness, we launch as grid=1, block=K=256. For N up to 2^20 +// that's fine; if we ever need N > 256 * C_max, we'd recurse. +// --------------------------------------------------------------------------- + +// Phase 1 & 3 fused into one kernel would require shmem across phases. +// Splitting makes each kernel simpler. + +// Phase 1: chunk-local forward scan. Also emits chunk_totals. +extern "C" __global__ void chunk_prefix_scan_ext3( + const uint64_t *a_in, // 3 * n u64 (ext3 interleaved) + uint64_t n, + uint64_t c_per_thread, // C = ceil(n / K) + uint64_t *prefix_out, // 3 * n u64 + uint64_t *chunk_totals) { // 3 * K u64 + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 acc = ext3::one(); + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 e = {a_in[i * 3 + 0], a_in[i * 3 + 1], a_in[i * 3 + 2]}; + acc = ext3::mul(acc, e); + prefix_out[i * 3 + 0] = acc.a; + prefix_out[i * 3 + 1] = acc.b; + prefix_out[i * 3 + 2] = acc.c; + } + chunk_totals[tid * 3 + 0] = acc.a; + chunk_totals[tid * 3 + 1] = acc.b; + chunk_totals[tid * 3 + 2] = acc.c; +} + +// Phase 2: exclusive prefix scan of chunk_totals, single-threaded. +// scan_out[0] = 1, scan_out[i] = prod(chunk_totals[0..i]). +extern "C" __global__ void exclusive_scan_of_totals_ext3( + const uint64_t *chunk_totals, // 3 * K u64 + uint64_t k, + uint64_t *scan_out) { // 3 * K u64 + if (threadIdx.x != 0 || blockIdx.x != 0) return; + ext3::Fe3 acc = ext3::one(); + scan_out[0] = acc.a; + scan_out[1] = acc.b; + scan_out[2] = acc.c; + for (uint64_t i = 1; i < k; ++i) { + ext3::Fe3 ct = { + chunk_totals[(i - 1) * 3 + 0], + chunk_totals[(i - 1) * 3 + 1], + chunk_totals[(i - 1) * 3 + 2], + }; + acc = ext3::mul(acc, ct); + scan_out[i * 3 + 0] = acc.a; + scan_out[i * 3 + 1] = acc.b; + scan_out[i * 3 + 2] = acc.c; + } +} + +// Phase 3: apply per-chunk offset to local scan result. +// global_prefix[i] = offsets[thread] * local_prefix[i] +extern "C" __global__ void apply_scan_offsets_ext3( + uint64_t *prefix_inout, // 3 * n u64 (written in phase 1, rewritten here) + uint64_t n, + uint64_t c_per_thread, + const uint64_t *offsets) { // 3 * K u64 + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 off = { + offsets[tid * 3 + 0], + offsets[tid * 3 + 1], + offsets[tid * 3 + 2], + }; + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 local = { + prefix_inout[i * 3 + 0], + prefix_inout[i * 3 + 1], + prefix_inout[i * 3 + 2], + }; + ext3::Fe3 g = ext3::mul(off, local); + prefix_inout[i * 3 + 0] = g.a; + prefix_inout[i * 3 + 1] = g.b; + prefix_inout[i * 3 + 2] = g.c; + } +} + +// Reverse-scan phase 1: chunk-local reverse prefix. +// suffix_out[i] = prod(a[i..chunk_end]) (within chunk only) +// chunk_totals[tid] = suffix_out[chunk_start] (= full chunk product) +extern "C" __global__ void chunk_suffix_scan_ext3( + const uint64_t *a_in, + uint64_t n, + uint64_t c_per_thread, + uint64_t *suffix_out, + uint64_t *chunk_totals) { + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + // Walk backward; acc starts at 1 and accumulates a[end-1], a[end-2], ... + // Empty chunks (start >= n) fall through with acc = 1 so that + // chunk_totals receives the identity, matching the prefix-scan kernel. + ext3::Fe3 acc = ext3::one(); + if (start < n) { + uint64_t end = min(start + c_per_thread, n); + for (uint64_t ri = end; ri > start; --ri) { + uint64_t i = ri - 1; + ext3::Fe3 e = {a_in[i * 3 + 0], a_in[i * 3 + 1], a_in[i * 3 + 2]}; + acc = ext3::mul(acc, e); + suffix_out[i * 3 + 0] = acc.a; + suffix_out[i * 3 + 1] = acc.b; + suffix_out[i * 3 + 2] = acc.c; + } + } + chunk_totals[tid * 3 + 0] = acc.a; + chunk_totals[tid * 3 + 1] = acc.b; + chunk_totals[tid * 3 + 2] = acc.c; +} + +// Exclusive reverse scan of chunk totals. +// scan_out[K-1] = 1 +// scan_out[k] = prod(chunk_totals[k+1..K]) +extern "C" __global__ void exclusive_reverse_scan_of_totals_ext3( + const uint64_t *chunk_totals, + uint64_t k, + uint64_t *scan_out) { + if (threadIdx.x != 0 || blockIdx.x != 0) return; + ext3::Fe3 acc = ext3::one(); + if (k == 0) return; + scan_out[(k - 1) * 3 + 0] = acc.a; + scan_out[(k - 1) * 3 + 1] = acc.b; + scan_out[(k - 1) * 3 + 2] = acc.c; + for (int64_t i = (int64_t)k - 2; i >= 0; --i) { + ext3::Fe3 ct = { + chunk_totals[(i + 1) * 3 + 0], + chunk_totals[(i + 1) * 3 + 1], + chunk_totals[(i + 1) * 3 + 2], + }; + acc = ext3::mul(acc, ct); + scan_out[i * 3 + 0] = acc.a; + scan_out[i * 3 + 1] = acc.b; + scan_out[i * 3 + 2] = acc.c; + } +} + +// Apply reverse offsets. +extern "C" __global__ void apply_reverse_scan_offsets_ext3( + uint64_t *suffix_inout, + uint64_t n, + uint64_t c_per_thread, + const uint64_t *offsets) { + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + if (start >= n) return; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 off = { + offsets[tid * 3 + 0], + offsets[tid * 3 + 1], + offsets[tid * 3 + 2], + }; + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 local = { + suffix_inout[i * 3 + 0], + suffix_inout[i * 3 + 1], + suffix_inout[i * 3 + 2], + }; + ext3::Fe3 g = ext3::mul(off, local); + suffix_inout[i * 3 + 0] = g.a; + suffix_inout[i * 3 + 1] = g.b; + suffix_inout[i * 3 + 2] = g.c; + } +} + +// Same fix for the forward apply_scan_offsets: threads whose chunks are +// empty must not write past end-of-array. (chunk_prefix_scan already +// behaves correctly because the start..end range is empty; apply just +// needs to handle start >= n gracefully — it already does by the same +// empty-range logic. No change needed there, just documenting.) + +// Final combine: inv[i] = pre_excl[i] * suf_excl[i] * inv_total +// where pre_excl[i] = prefix[i-1] (with prefix[-1] = 1) and +// suf_excl[i] = suffix[i+1] (with suffix[N] = 1). +// +// Instead of creating separate pre_excl / suf_excl arrays, we pass the +// inclusive prefix / suffix arrays and shift the index here. +extern "C" __global__ void batch_inverse_combine_ext3( + const uint64_t *prefix_incl, // 3 * n u64; prefix_incl[i] = prod(a[0..=i]) + const uint64_t *suffix_incl, // 3 * n u64; suffix_incl[i] = prod(a[i..n-1]) + const uint64_t *inv_total_ptr, // 3 u64 + uint64_t n, + uint64_t *inv_out) { // 3 * n u64 + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + + ext3::Fe3 pre; + if (i == 0) { + pre = ext3::one(); + } else { + pre.a = prefix_incl[(i - 1) * 3 + 0]; + pre.b = prefix_incl[(i - 1) * 3 + 1]; + pre.c = prefix_incl[(i - 1) * 3 + 2]; + } + ext3::Fe3 suf; + if (i + 1 >= n) { + suf = ext3::one(); + } else { + suf.a = suffix_incl[(i + 1) * 3 + 0]; + suf.b = suffix_incl[(i + 1) * 3 + 1]; + suf.c = suffix_incl[(i + 1) * 3 + 2]; + } + ext3::Fe3 inv_tot = {inv_total_ptr[0], inv_total_ptr[1], inv_total_ptr[2]}; + + ext3::Fe3 r = ext3::mul(pre, suf); + r = ext3::mul(r, inv_tot); + + inv_out[i * 3 + 0] = r.a; + inv_out[i * 3 + 1] = r.b; + inv_out[i * 3 + 2] = r.c; +} diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu new file mode 100644 index 000000000..68ddce3b4 --- /dev/null +++ b/crypto/math-cuda/kernels/keccak.cu @@ -0,0 +1,347 @@ +// CUDA Keccak-256 (original Keccak, NOT SHA3-256 — uses 0x01 padding delimiter). +// +// Used by the lambda-vm prover's Merkle commit: +// leaf = Keccak-256(concat(col_0[br_idx].to_be_bytes(), col_1[br_idx].to_be_bytes(), …)) +// where `br_idx = bit_reverse(row_idx, log_num_rows)` and each element is +// written in BIG-ENDIAN canonical form (per `FieldElement::write_bytes_be`). +// +// Keccak state is 5x5 lanes of u64, interpreted little-endian. Rate = 136 B +// (17 lanes) for 256-bit output, capacity = 64 B (8 lanes). +// +// Since every input byte is u64-aligned (each field element is 8 or 24 bytes), +// we can absorb lane-by-lane instead of byte-by-byte. Canonicalise + byte-swap +// each u64 on read to turn a BE-serialised element into its LE-interpreted +// lane value. + +#include +#include "goldilocks.cuh" + +__device__ __constant__ uint64_t KECCAK_RC[24] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, + 0x8000000080008000ULL, 0x000000000000808bULL, 0x0000000080000001ULL, + 0x8000000080008081ULL, 0x8000000000008009ULL, 0x000000000000008aULL, + 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, + 0x8000000000008003ULL, 0x8000000000008002ULL, 0x8000000000000080ULL, + 0x000000000000800aULL, 0x800000008000000aULL, 0x8000000080008081ULL, + 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, +}; + +// Rotation offsets indexed by lane position x + 5*y. Standard Keccak rho. +__device__ __constant__ uint32_t KECCAK_RHO_OFFSETS[25] = { + 0, 1, 62, 28, 27, // y=0: x=0..4 + 36, 44, 6, 55, 20, // y=1 + 3, 10, 43, 25, 39, // y=2 + 41, 45, 15, 21, 8, // y=3 + 18, 2, 61, 56, 14, // y=4 +}; + +__device__ __forceinline__ uint64_t rotl64(uint64_t x, uint32_t n) { + return (n == 0) ? x : ((x << n) | (x >> (64 - n))); +} + +__device__ __forceinline__ uint64_t bswap64(uint64_t x) { + // Reverse byte order: turns a BE-serialised u64 into its LE-read lane. + x = ((x & 0x00ff00ff00ff00ffULL) << 8) | ((x & 0xff00ff00ff00ff00ULL) >> 8); + x = ((x & 0x0000ffff0000ffffULL) << 16) | ((x & 0xffff0000ffff0000ULL) >> 16); + return (x << 32) | (x >> 32); +} + +__device__ __forceinline__ void keccak_f1600(uint64_t st[25]) { + uint64_t C[5], D[5], B[25]; + #pragma unroll + for (int r = 0; r < 24; ++r) { + // Theta + #pragma unroll + for (int x = 0; x < 5; ++x) { + C[x] = st[x] ^ st[x + 5] ^ st[x + 10] ^ st[x + 15] ^ st[x + 20]; + } + #pragma unroll + for (int x = 0; x < 5; ++x) { + D[x] = C[(x + 4) % 5] ^ rotl64(C[(x + 1) % 5], 1); + } + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] ^= D[x]; + } + } + + // Rho + Pi: B[pi(x,y)] = rotl(st[x,y], rho(x,y)) + // pi: (x', y') = (y, (2x + 3y) mod 5) + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + int nx = y; + int ny = (2 * x + 3 * y) % 5; + B[nx + 5 * ny] = rotl64(st[x + 5 * y], KECCAK_RHO_OFFSETS[x + 5 * y]); + } + } + + // Chi + #pragma unroll + for (int y = 0; y < 5; ++y) { + #pragma unroll + for (int x = 0; x < 5; ++x) { + st[x + 5 * y] = + B[x + 5 * y] ^ ((~B[((x + 1) % 5) + 5 * y]) & B[((x + 2) % 5) + 5 * y]); + } + } + + // Iota + st[0] ^= KECCAK_RC[r]; + } +} + +// --------------------------------------------------------------------------- +// Helper: absorb one 8-byte lane (already in lane form — i.e. LE interpretation +// of the BE-serialised u64) into the sponge at `rate_pos` (in bytes). Permutes +// when a full 136-byte block has been absorbed. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void absorb_lane(uint64_t st[25], + uint32_t &rate_pos, + uint64_t lane) { + st[rate_pos / 8] ^= lane; + rate_pos += 8; + if (rate_pos == 136) { + keccak_f1600(st); + rate_pos = 0; + } +} + +// --------------------------------------------------------------------------- +// After all data lanes absorbed, apply Keccak (pre-SHA-3) padding: a single +// 0x01 byte at the current position, then bit 0x80 on the last rate byte +// (byte 135 = last byte of lane 16). Then permute and squeeze 32 bytes from +// the first four lanes in LE order. +// --------------------------------------------------------------------------- +__device__ __forceinline__ void finalize_keccak256(uint64_t st[25], + uint32_t rate_pos, + uint8_t *out32) { + // 0x01 at rate_pos + st[rate_pos / 8] ^= ((uint64_t)0x01) << ((rate_pos & 7) * 8); + // 0x80 at byte 135 (last byte of lane 16) + st[16] ^= ((uint64_t)0x80) << 56; + keccak_f1600(st); + + // Squeeze 32 bytes: 4 lanes, each LE-serialised. + #pragma unroll + for (int i = 0; i < 4; ++i) { + uint64_t lane = st[i]; + #pragma unroll + for (int b = 0; b < 8; ++b) { + out32[i * 8 + b] = (uint8_t)((lane >> (b * 8)) & 0xff); + } + } +} + +// --------------------------------------------------------------------------- +// Goldilocks BASE-FIELD leaf hashing. +// +// For output row `row_idx` (natural order), the leaf hashes the canonical BE +// byte representation of `columns[c][bit_reverse(row_idx, log_num_rows)]` for +// `c` in `[0, num_cols)`, concatenated in column order. Writes 32 bytes to +// `hashed_leaves_out[row_idx * 32 ..]`. +// +// `columns_base_ptr` points to a `num_cols * col_stride * u64` buffer; column +// `c` is the contiguous slab `[c*col_stride .. c*col_stride + num_rows]`. The +// remaining `col_stride - num_rows` entries (if any) are ignored. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_base_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + + // Bit-reverse the row index so we read columns at `br` but write the + // hashed leaf at `tid` — matching the CPU `commit_columns_bit_reversed`. + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + uint64_t v = columns_base_ptr[c * col_stride + br]; + // Canonicalise to match `canonical_u64().to_be_bytes()` on host. + uint64_t canon = goldilocks::canonical(v); + // The on-disk leaf bytes are canon.to_be_bytes(); Keccak reads those + // as a LE lane, which equals bswap64(canon). + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Goldilocks EXT3 leaf hashing (3 base-field components per ext3 element). +// +// Components live in three separate base-field slabs (our de-interleaved +// layout). Column `c` component `k` is at `columns_base_ptr[(c*3 + k)*col_stride +// + br]`. Per-element BE bytes are `[comp0, comp1, comp2]` each 8 BE bytes +// (matches `FieldElement::::write_bytes_be`). +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_ext3_batched( + const uint64_t *columns_base_ptr, + uint64_t col_stride, + uint64_t num_cols, // number of ext3 columns (NOT slabs) + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < num_cols; ++c) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = columns_base_ptr[(c * 3 + (uint64_t)k) * col_stride + br]; + uint64_t canon = goldilocks::canonical(v); + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + } + + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// R2 composition-polynomial leaf hashing. +// +// Each leaf hashes `2 * num_parts` ext3 values taken from bit-reversed rows +// `br_0 = reverse_index(2*leaf_idx)` and `br_1 = reverse_index(2*leaf_idx+1)` +// across all `num_parts` parts, in (br_0 row: part 0..K-1) then (br_1 row: +// part 0..K-1) order. Each ext3 value is 3 base components × 8 BE bytes. +// +// Columns arrive in the de-interleaved 3-slab layout: part `p` component +// `k` is at `parts_base_ptr[(p*3 + k) * col_stride + row]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_comp_poly_leaves_ext3( + const uint64_t *parts_base_ptr, + uint64_t col_stride, + uint64_t num_parts, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t num_leaves = num_rows >> 1; + if (tid >= num_leaves) return; + + uint64_t br_0 = __brevll(2 * tid) >> (64 - log_num_rows); + uint64_t br_1 = __brevll(2 * tid + 1) >> (64 - log_num_rows); + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + // First row (br_0): part 0..K-1 × 3 components each. + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_0]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + // Second row (br_1). + for (uint64_t p = 0; p < num_parts; ++p) { + #pragma unroll + for (int k = 0; k < 3; ++k) { + uint64_t v = parts_base_ptr[(p * 3 + (uint64_t)k) * col_stride + br_1]; + uint64_t canon = goldilocks::canonical(v); + absorb_lane(st, rate_pos, bswap64(canon)); + } + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// FRI layer leaf hashing. +// +// Each leaf hashes 2 consecutive ext3 values: Keccak256 over +// evals[2j].to_bytes_be() ++ evals[2j+1].to_bytes_be() +// = 48 BE bytes = 6 u64 BE lanes. No bit reversal; no column slab layout — +// the input is a single interleaved ext3 eval vector `[a0,a1,a2,b0,b1,b2,...]`. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_fri_leaves_ext3( + const uint64_t *evals_interleaved, // 3 * num_evals u64s (ext3 interleaved) + uint64_t num_leaves, // = num_evals / 2 + uint8_t *leaves_out) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_leaves) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + uint32_t rate_pos = 0; + + const uint64_t *left = evals_interleaved + 2 * tid * 3; // 3 u64s + const uint64_t *right = left + 3; + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(left[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + #pragma unroll + for (int i = 0; i < 3; ++i) { + uint64_t canon = goldilocks::canonical(right[i]); + absorb_lane(st, rate_pos, bswap64(canon)); + } + + finalize_keccak256(st, rate_pos, leaves_out + tid * 32); +} + +// --------------------------------------------------------------------------- +// Merkle inner-tree pair hash: one level of the inner Merkle tree. +// +// `nodes` is the full Merkle node buffer (length `2*leaves_len - 1`, each +// element 32 bytes). `parent_begin` is the node-index offset of the first +// parent slot in this level; children live at `parent_begin + n_pairs`. +// The layout mirrors `crypto/crypto/src/merkle_tree/merkle.rs`: +// +// children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] +// parents: nodes[parent_begin .. parent_begin + n_pairs] +// +// Each thread hashes one child pair → one parent. Keccak-256 of the +// concatenation of two 32-byte siblings; identical to +// `FieldElementVectorBackend::hash_new_parent` on host. +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak_merkle_level( + uint8_t *nodes, + uint64_t parent_begin, // node index (counted in 32-byte nodes) + uint64_t n_pairs) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n_pairs) return; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + const uint64_t *left = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, left[i]); + + const uint64_t *right = reinterpret_cast( + nodes + (parent_begin + n_pairs + 2 * tid + 1) * 32); + #pragma unroll + for (int i = 0; i < 4; ++i) absorb_lane(st, rate_pos, right[i]); + + finalize_keccak256(st, rate_pos, nodes + (parent_begin + tid) * 32); +} diff --git a/crypto/math-cuda/kernels/logup.cu b/crypto/math-cuda/kernels/logup.cu new file mode 100644 index 000000000..681b704ce --- /dev/null +++ b/crypto/math-cuda/kernels/logup.cu @@ -0,0 +1,457 @@ +// LogUp aux-trace-build: one pair's fingerprint compute + term assembly +// in two kernels. The host orchestrates: +// main_cols (already on device from R1) +// + bytecode descriptor for the pair (uploaded once per pair) +// → fingerprint kernel writes 2n ext3 fingerprints +// → batch_inverse_ext3 (existing) inverts in place (or into output) +// → term-assembly kernel writes n ext3 term values +// +// Wire format (packed C structs, shared with Rust serializer): +// +// struct FingerprintOp { +// uint8_t kind; // OP_PACK_* / OP_LINEAR +// uint8_t pad0[3]; +// uint32_t alpha_offset; // where to multiply by α into lc +// uint32_t start_col; // for Pack ops: first main-trace column +// uint32_t num_linear_terms; // for OP_LINEAR: count of terms that follow +// uint32_t linear_term_offset; // for OP_LINEAR: start in linear_terms[] +// uint32_t pad1[2]; // align to 32 bytes +// }; +// +// struct LinearTerm { +// uint8_t kind; // 0 = Column signed, 1 = Column unsigned, 2 = Constant +// uint8_t pad[3]; +// uint32_t column; +// int64_t value; // signed coefficient or signed constant +// }; +// +// struct MultiplicityDesc { +// uint8_t kind; // 0..6 mapping to Rust's Multiplicity variants +// uint8_t pad[3]; +// uint32_t cols[3]; // up to 3 columns (Sum3) +// uint32_t num_linear_terms; +// uint32_t linear_term_offset; +// }; +// +// All ops reference the same main_cols buffer and the same shared +// linear_terms buffer. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +// Must match Rust-side `LogupOpKind`. +#define OP_PACK_DIRECT 0 +#define OP_PACK_WORD2L 1 +#define OP_PACK_WORD4L 2 +#define OP_PACK_DWORDWL 3 +#define OP_PACK_DWORDHHW 4 +#define OP_PACK_DWORDWHH 5 +#define OP_PACK_DWORDHL 6 +#define OP_PACK_DWORDBL 7 +#define OP_PACK_QUADHL 8 +#define OP_PACK_QUADWL 9 +#define OP_LINEAR 10 + +// PackingShifts (base field). +#define SHIFT_8 ((uint64_t)(1ULL << 8)) +#define SHIFT_16 ((uint64_t)(1ULL << 16)) +#define SHIFT_24 ((uint64_t)(1ULL << 24)) + +struct FingerprintOp { + uint8_t kind; + uint8_t pad0[3]; + uint32_t alpha_offset; + uint32_t start_col; + uint32_t num_linear_terms; + uint32_t linear_term_offset; + uint32_t pad1[2]; +}; + +struct LinearTerm { + uint8_t kind; // 0=Column, 2=Constant (Rust canonicalizes both into `value`) + uint8_t pad[3]; + uint32_t column; + uint64_t value; // canonical Goldilocks field element in [0, p) +}; + +struct MultiplicityDesc { + uint8_t kind; + uint8_t pad[3]; + uint32_t cols[3]; + uint32_t num_linear_terms; + uint32_t linear_term_offset; +}; + +#define MULT_ONE 0 +#define MULT_COLUMN 1 +#define MULT_SUM 2 +#define MULT_NEGATED 3 +#define MULT_DIFF 4 +#define MULT_SUM3 5 +#define MULT_LINEAR 6 + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +__device__ __forceinline__ uint64_t read_main(const uint64_t *main_cols, + uint64_t col_stride, + uint32_t col, + uint64_t row) { + // Column-major: col * col_stride + row. + return main_cols[(uint64_t)col * col_stride + row]; +} + +/// Evaluate a Linear term list at `row` → base-field element. +/// `lt.value` is already a canonical Goldilocks field element in [0, p); +/// the Rust serializer is responsible for the canonicalization so the +/// device can skip sign handling entirely. +__device__ __forceinline__ uint64_t eval_linear( + const uint64_t *main_cols, + uint64_t col_stride, + const LinearTerm *linear_terms, + uint32_t num_terms, + uint32_t offset, + uint64_t row) { + uint64_t result = 0; + for (uint32_t t = 0; t < num_terms; ++t) { + const LinearTerm < = linear_terms[offset + t]; + if (lt.kind == 2) { + // Constant. + result = goldilocks::add(result, lt.value); + } else { + // Column (signed or unsigned — canonical coefficient). + uint64_t v = read_main(main_cols, col_stride, lt.column, row); + uint64_t prod = goldilocks::mul(v, lt.value); + result = goldilocks::add(result, prod); + } + } + return result; +} + +/// Apply one fingerprint op: reads main_cols[*], multiplies by the +/// appropriate alpha power(s), accumulates into `acc`. +__device__ __forceinline__ void apply_fingerprint_op( + const uint64_t *main_cols, + uint64_t col_stride, + const LinearTerm *linear_terms, + const uint64_t *alpha_powers, // ext3 interleaved: 3*max_bus_elements u64 + const FingerprintOp &op, + uint64_t row, + ext3::Fe3 &acc) { + uint32_t ao = op.alpha_offset; + #define ALPHA(i) ext3::make( \ + alpha_powers[((ao) + (i)) * 3 + 0], \ + alpha_powers[((ao) + (i)) * 3 + 1], \ + alpha_powers[((ao) + (i)) * 3 + 2]) + uint32_t c = op.start_col; + + switch (op.kind) { + case OP_PACK_DIRECT: { + uint64_t v = read_main(main_cols, col_stride, c, row); + ext3::Fe3 ap = ALPHA(0); + acc = ext3::add(acc, ext3::mul_base(ap, v)); + break; + } + case OP_PACK_WORD2L: { + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + uint64_t combined = goldilocks::add(v0, goldilocks::mul(v1, SHIFT_16)); + ext3::Fe3 ap = ALPHA(0); + acc = ext3::add(acc, ext3::mul_base(ap, combined)); + break; + } + case OP_PACK_WORD4L: { + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + uint64_t v2 = read_main(main_cols, col_stride, c + 2, row); + uint64_t v3 = read_main(main_cols, col_stride, c + 3, row); + uint64_t t1 = goldilocks::mul(v1, SHIFT_8); + uint64_t t2 = goldilocks::mul(v2, SHIFT_16); + uint64_t t3 = goldilocks::mul(v3, SHIFT_24); + uint64_t combined = goldilocks::add(goldilocks::add(v0, t1), + goldilocks::add(t2, t3)); + ext3::Fe3 ap = ALPHA(0); + acc = ext3::add(acc, ext3::mul_base(ap, combined)); + break; + } + case OP_PACK_DWORDWL: { + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + ext3::Fe3 ap0 = ALPHA(0), ap1 = ALPHA(1); + acc = ext3::add(acc, ext3::mul_base(ap0, v0)); + acc = ext3::add(acc, ext3::mul_base(ap1, v1)); + break; + } + case OP_PACK_DWORDHHW: { + // Direct + Word2L: col, col+1 -> word, col+2 -> half? No — spec: Direct + Word2L + // columns: [direct c0, word2l c1 c2] → (c0)*α0 + (c1 + c2 << 16)*α1 + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + uint64_t v2 = read_main(main_cols, col_stride, c + 2, row); + ext3::Fe3 ap0 = ALPHA(0), ap1 = ALPHA(1); + acc = ext3::add(acc, ext3::mul_base(ap0, v0)); + uint64_t w = goldilocks::add(v1, goldilocks::mul(v2, SHIFT_16)); + acc = ext3::add(acc, ext3::mul_base(ap1, w)); + break; + } + case OP_PACK_DWORDWHH: { + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + uint64_t v2 = read_main(main_cols, col_stride, c + 2, row); + ext3::Fe3 ap0 = ALPHA(0), ap1 = ALPHA(1); + uint64_t w = goldilocks::add(v0, goldilocks::mul(v1, SHIFT_16)); + acc = ext3::add(acc, ext3::mul_base(ap0, w)); + acc = ext3::add(acc, ext3::mul_base(ap1, v2)); + break; + } + case OP_PACK_DWORDHL: { + uint64_t v0 = read_main(main_cols, col_stride, c, row); + uint64_t v1 = read_main(main_cols, col_stride, c + 1, row); + uint64_t v2 = read_main(main_cols, col_stride, c + 2, row); + uint64_t v3 = read_main(main_cols, col_stride, c + 3, row); + ext3::Fe3 ap0 = ALPHA(0), ap1 = ALPHA(1); + uint64_t w0 = goldilocks::add(v0, goldilocks::mul(v1, SHIFT_16)); + uint64_t w1 = goldilocks::add(v2, goldilocks::mul(v3, SHIFT_16)); + acc = ext3::add(acc, ext3::mul_base(ap0, w0)); + acc = ext3::add(acc, ext3::mul_base(ap1, w1)); + break; + } + case OP_PACK_DWORDBL: { + // 2× Word4L at start_col and start_col+4 + ext3::Fe3 ap0 = ALPHA(0), ap1 = ALPHA(1); + for (int hi = 0; hi < 2; ++hi) { + uint32_t base = c + hi * 4; + uint64_t v0 = read_main(main_cols, col_stride, base, row); + uint64_t v1 = read_main(main_cols, col_stride, base + 1, row); + uint64_t v2 = read_main(main_cols, col_stride, base + 2, row); + uint64_t v3 = read_main(main_cols, col_stride, base + 3, row); + uint64_t t1 = goldilocks::mul(v1, SHIFT_8); + uint64_t t2 = goldilocks::mul(v2, SHIFT_16); + uint64_t t3 = goldilocks::mul(v3, SHIFT_24); + uint64_t w = goldilocks::add(goldilocks::add(v0, t1), + goldilocks::add(t2, t3)); + ext3::Fe3 ap = (hi == 0) ? ap0 : ap1; + acc = ext3::add(acc, ext3::mul_base(ap, w)); + } + break; + } + case OP_PACK_QUADHL: { + // 4× Word2L at start_col, start_col+2, ..., start_col+6 + for (int k = 0; k < 4; ++k) { + uint32_t base = c + k * 2; + uint64_t v0 = read_main(main_cols, col_stride, base, row); + uint64_t v1 = read_main(main_cols, col_stride, base + 1, row); + uint64_t w = goldilocks::add(v0, goldilocks::mul(v1, SHIFT_16)); + ext3::Fe3 ap = ALPHA(k); + acc = ext3::add(acc, ext3::mul_base(ap, w)); + } + break; + } + case OP_PACK_QUADWL: { + for (int k = 0; k < 4; ++k) { + uint64_t v = read_main(main_cols, col_stride, c + k, row); + ext3::Fe3 ap = ALPHA(k); + acc = ext3::add(acc, ext3::mul_base(ap, v)); + } + break; + } + case OP_LINEAR: { + uint64_t r = eval_linear(main_cols, col_stride, linear_terms, + op.num_linear_terms, op.linear_term_offset, row); + ext3::Fe3 ap = ALPHA(0); + acc = ext3::add(acc, ext3::mul_base(ap, r)); + break; + } + default: + break; + } + #undef ALPHA +} + +/// Compute one interaction pair's fingerprints: 2n ext3 values +/// `fp[0..n] = z - lc_a(row)`, `fp[n..2n] = z - lc_b(row)`. +extern "C" __global__ void logup_pair_fingerprint( + const uint64_t *main_cols, // main LDE, column-major, col_stride u64 per column + uint64_t col_stride, + uint64_t n, // trace_len + uint64_t bus_id_a, // base field + uint64_t bus_id_b, + const FingerprintOp *ops_a, // pair A ops + uint32_t ops_a_count, + const FingerprintOp *ops_b, // pair B ops + uint32_t ops_b_count, + const LinearTerm *linear_terms, + const uint64_t *alpha_powers, // 3 * max_bus_elements u64 + const uint64_t *z, // 3 u64 (ext3) + uint64_t *fp_out) { // 2n * 3 u64 (ext3 interleaved) + uint64_t row = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + ext3::Fe3 zf = ext3::make(z[0], z[1], z[2]); + + // Pair A + { + ext3::Fe3 alpha0 = ext3::make( + alpha_powers[0], alpha_powers[1], alpha_powers[2]); + ext3::Fe3 lc = ext3::mul_base(alpha0, bus_id_a); + for (uint32_t k = 0; k < ops_a_count; ++k) { + apply_fingerprint_op(main_cols, col_stride, linear_terms, + alpha_powers, ops_a[k], row, lc); + } + ext3::Fe3 fp = ext3::sub(zf, lc); + fp_out[row * 3 + 0] = fp.a; + fp_out[row * 3 + 1] = fp.b; + fp_out[row * 3 + 2] = fp.c; + } + + // Pair B (output at row + n) + { + ext3::Fe3 alpha0 = ext3::make( + alpha_powers[0], alpha_powers[1], alpha_powers[2]); + ext3::Fe3 lc = ext3::mul_base(alpha0, bus_id_b); + for (uint32_t k = 0; k < ops_b_count; ++k) { + apply_fingerprint_op(main_cols, col_stride, linear_terms, + alpha_powers, ops_b[k], row, lc); + } + ext3::Fe3 fp = ext3::sub(zf, lc); + uint64_t out_row = n + row; + fp_out[out_row * 3 + 0] = fp.a; + fp_out[out_row * 3 + 1] = fp.b; + fp_out[out_row * 3 + 2] = fp.c; + } +} + +/// Evaluate a Multiplicity descriptor at `row` → base-field value. +__device__ __forceinline__ uint64_t eval_multiplicity( + const uint64_t *main_cols, + uint64_t col_stride, + const LinearTerm *linear_terms, + const MultiplicityDesc &m, + uint64_t row) { + switch (m.kind) { + case MULT_ONE: + return 1; + case MULT_COLUMN: + return read_main(main_cols, col_stride, m.cols[0], row); + case MULT_SUM: { + uint64_t a = read_main(main_cols, col_stride, m.cols[0], row); + uint64_t b = read_main(main_cols, col_stride, m.cols[1], row); + return goldilocks::add(a, b); + } + case MULT_NEGATED: { + uint64_t v = read_main(main_cols, col_stride, m.cols[0], row); + return goldilocks::sub(1, v); + } + case MULT_DIFF: { + uint64_t a = read_main(main_cols, col_stride, m.cols[0], row); + uint64_t b = read_main(main_cols, col_stride, m.cols[1], row); + return goldilocks::sub(a, b); + } + case MULT_SUM3: { + uint64_t a = read_main(main_cols, col_stride, m.cols[0], row); + uint64_t b = read_main(main_cols, col_stride, m.cols[1], row); + uint64_t c = read_main(main_cols, col_stride, m.cols[2], row); + return goldilocks::add(goldilocks::add(a, b), c); + } + case MULT_LINEAR: + return eval_linear(main_cols, col_stride, linear_terms, + m.num_linear_terms, m.linear_term_offset, row); + default: + return 0; + } +} + +/// Term-assembly: reads inverted fingerprints + multiplicities, +/// produces the term column. +/// term[row] = (neg_a ? -1 : 1) * mult_a(row) * inv_fp_a[row] +/// + (neg_b ? -1 : 1) * mult_b(row) * inv_fp_b[row] +/// Multiplicities are base-field, inv_fp are ext3; result is ext3. +extern "C" __global__ void logup_pair_term_assembly( + const uint64_t *inv_fp, // 2n * 3 u64 (ext3 interleaved) + const uint64_t *main_cols, // main LDE + uint64_t col_stride, + uint64_t n, + const LinearTerm *linear_terms, + const MultiplicityDesc *mult_a, // device pointer to descriptor (1 struct) + const MultiplicityDesc *mult_b, + uint8_t negate_a, + uint8_t negate_b, + uint64_t *term_out) { // n * 3 u64 (ext3) + uint64_t row = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + uint64_t m_a_val = eval_multiplicity(main_cols, col_stride, linear_terms, + *mult_a, row); + uint64_t m_b_val = eval_multiplicity(main_cols, col_stride, linear_terms, + *mult_b, row); + + ext3::Fe3 inv_a = ext3::make( + inv_fp[row * 3 + 0], inv_fp[row * 3 + 1], inv_fp[row * 3 + 2]); + uint64_t row_b = n + row; + ext3::Fe3 inv_b = ext3::make( + inv_fp[row_b * 3 + 0], inv_fp[row_b * 3 + 1], inv_fp[row_b * 3 + 2]); + + ext3::Fe3 ta = ext3::mul_base(inv_a, m_a_val); + if (negate_a) ta = ext3::neg(ta); + ext3::Fe3 tb = ext3::mul_base(inv_b, m_b_val); + if (negate_b) tb = ext3::neg(tb); + + ext3::Fe3 t = ext3::add(ta, tb); + term_out[row * 3 + 0] = t.a; + term_out[row * 3 + 1] = t.b; + term_out[row * 3 + 2] = t.c; +} + +/// Single-pair variant (for the "absorbed" case with 1 interaction). +/// Computes fingerprints and term for a single interaction. +extern "C" __global__ void logup_single_fingerprint( + const uint64_t *main_cols, + uint64_t col_stride, + uint64_t n, + uint64_t bus_id, + const FingerprintOp *ops, + uint32_t ops_count, + const LinearTerm *linear_terms, + const uint64_t *alpha_powers, + const uint64_t *z, + uint64_t *fp_out) { // n * 3 u64 + uint64_t row = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + ext3::Fe3 zf = ext3::make(z[0], z[1], z[2]); + ext3::Fe3 alpha0 = ext3::make( + alpha_powers[0], alpha_powers[1], alpha_powers[2]); + ext3::Fe3 lc = ext3::mul_base(alpha0, bus_id); + for (uint32_t k = 0; k < ops_count; ++k) { + apply_fingerprint_op(main_cols, col_stride, linear_terms, + alpha_powers, ops[k], row, lc); + } + ext3::Fe3 fp = ext3::sub(zf, lc); + fp_out[row * 3 + 0] = fp.a; + fp_out[row * 3 + 1] = fp.b; + fp_out[row * 3 + 2] = fp.c; +} + +extern "C" __global__ void logup_single_term_assembly( + const uint64_t *inv_fp, // n * 3 u64 + const uint64_t *main_cols, + uint64_t col_stride, + uint64_t n, + const LinearTerm *linear_terms, + const MultiplicityDesc *mult, + uint8_t negate, + uint64_t *term_out) { + uint64_t row = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (row >= n) return; + + uint64_t m = eval_multiplicity(main_cols, col_stride, linear_terms, + *mult, row); + ext3::Fe3 inv = ext3::make( + inv_fp[row * 3 + 0], inv_fp[row * 3 + 1], inv_fp[row * 3 + 2]); + ext3::Fe3 t = ext3::mul_base(inv, m); + if (negate) t = ext3::neg(t); + term_out[row * 3 + 0] = t.a; + term_out[row * 3 + 1] = t.b; + term_out[row * 3 + 2] = t.c; +} diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu new file mode 100644 index 000000000..2a5c8c786 --- /dev/null +++ b/crypto/math-cuda/kernels/ntt.cu @@ -0,0 +1,285 @@ +// Radix-2 DIT NTT over Goldilocks. One kernel per butterfly level; the caller +// runs `bit_reverse_permute` once before the first level. +// +// Input layout: bit-reversed-order coefficients (after `bit_reverse_permute`). +// Output layout: natural-order evaluations — matches the CPU `evaluate_fft` contract. +// +// Twiddle table: `tw[i] = ω^i` for i in [0, n/2). Stride-indexed per level. + +#include "goldilocks.cuh" + +using goldilocks::add; +using goldilocks::sub; +using goldilocks::mul; + +/// Reverse the low `log_n` bits of each index and swap x[i] ↔ x[rev(i)]. +/// One thread per index; guarded by `tid < rev` to avoid double-swap. +extern "C" __global__ void bit_reverse_permute(uint64_t *x, + uint64_t n, + uint64_t log_n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + + // __brevll reverses all 64 bits; shift right so result lives in [0, n). + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +/// Pointwise multiply: x[i] *= w[i]. Used for coset scaling (w = g^i weights). +extern "C" __global__ void pointwise_mul(uint64_t *x, + const uint64_t *w, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], w[tid]); +} + +/// Broadcast scalar multiply: x[i] *= c. Used for the 1/n factor at the end of iNTT. +extern "C" __global__ void scalar_mul(uint64_t *x, + uint64_t c, + uint64_t n) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) x[tid] = mul(x[tid], c); +} + +// ============================================================================ +// BATCHED KERNELS +// +// One launch processes M columns at once. The device buffer holds M columns +// back-to-back; column `c` starts at `data + c * col_stride`. gridDim.y is +// the column index, so each block handles one (column, butterfly-window) pair. +// +// The same twiddle table is shared across all columns of a batch (they all +// NTT on the same domain). The coset weights are also shared. +// ============================================================================ + +extern "C" __global__ void bit_reverse_permute_batched(uint64_t *data, + uint64_t n, + uint64_t log_n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t rev = __brevll(tid) >> (64 - log_n); + if (tid < rev) { + uint64_t tmp = x[tid]; + x[tid] = x[rev]; + x[rev] = tmp; + } +} + +extern "C" __global__ void ntt_dit_level_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint64_t half = 1ULL << level; + uint64_t block_size = half << 1; + uint64_t block_idx = tid >> level; + uint64_t k = tid & (half - 1); + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +extern "C" __global__ void ntt_dit_8_levels_batched(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step, + uint64_t col_stride) { + __shared__ uint64_t tile[256]; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + tile[threadIdx.x] = x[row]; + __syncthreads(); + + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + x[row] = tile[threadIdx.x]; +} + + +/// Batched pointwise multiply: first n elements of each column multiplied by +/// the SHARED weight vector `w` (size n). Used for coset scaling — every +/// column of a table sees the same `g^i / N` weights. +extern "C" __global__ void pointwise_mul_batched(uint64_t *data, + const uint64_t *w, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], w[tid]); +} + +/// Batched broadcast scalar multiply — one scalar c applied to the first n +/// elements of every column. +extern "C" __global__ void scalar_mul_batched(uint64_t *data, + uint64_t c, + uint64_t n, + uint64_t col_stride) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) return; + uint64_t *x = data + (uint64_t)blockIdx.y * col_stride; + x[tid] = mul(x[tid], c); +} + +/// One DIT butterfly level. Thread `tid` (of n/2 total) owns exactly one +/// butterfly pair (i0, i1 = i0 + half). Twiddle picked from the shared full +/// `tw` table at stride `n / block_size`. Kept for log_n < 8 where shmem +/// fusion is overkill. +extern "C" __global__ void ntt_dit_level(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level) { + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (tid >= n_half) return; + + uint64_t half = 1ULL << level; // 2^ℓ + uint64_t block_size = half << 1; // 2^{ℓ+1} + uint64_t block_idx = tid >> level; // floor(tid / half) + uint64_t k = tid & (half - 1); // tid mod half + + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + // Stride = n / block_size = n >> (level + 1). + uint64_t tw_index = k << (log_n - level - 1); + uint64_t w = tw[tw_index]; + + uint64_t u = x[i0]; + uint64_t v = mul(w, x[i1]); + x[i0] = add(u, v); + x[i1] = sub(u, v); +} + +/// Up to 8 DIT butterfly levels fused in one kernel using shared memory. +/// +/// Ported from Zisk's `br_ntt_8_steps` (`pil2-stark/src/goldilocks/src/ntt_goldilocks.cu`), +/// simplified to single-column. Each block of 256 threads processes 256 +/// elements in on-chip shared memory, running up to 8 butterfly levels +/// without writing to global memory between them — cuts DRAM traffic by up +/// to 8× vs the per-level kernel. +/// +/// `base_step` selects which 8-level window this launch handles (0, 8, 16…). +/// For levels 0–7 the implicit DIT element layout already places all pair +/// mates inside the same 256-block; for higher base_step we remap the loaded +/// row so pair mates land in consecutive shared-memory slots. +/// +/// Expects bit-reversed input (the caller runs `bit_reverse_permute` once +/// before the first kernel launch). +/// +/// Assumes `n` is a multiple of 256, i.e. `log_n >= 8`. +extern "C" __global__ void ntt_dit_8_levels(uint64_t *x, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t base_step) { + __shared__ uint64_t tile[256]; + + uint32_t n_loc_steps = (uint32_t)min((uint64_t)8, log_n - base_step); + + // tid is the *unpermuted* flat index the block/thread would own. + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + + // Row remap: for base_step > 0, gather elements that pair at levels + // `base_step..base_step+7` so they land consecutively in the block. + uint64_t group_size = 1ULL << base_step; + uint64_t n_groups = n >> base_step; // = n / group_size + uint64_t low_bits = tid / n_groups; + uint64_t high_bits = tid & (n_groups - 1); + uint64_t row = high_bits * group_size + low_bits; + + // Load one element per thread. + tile[threadIdx.x] = x[row]; + __syncthreads(); + + // Each butterfly level uses half the threads (128 butterflies per block). + // The global butterfly index `ggp` is recovered from blockIdx + threadIdx + // and reshaped by the same row-remap to find the right twiddle. + uint32_t remaining_high_bits = (uint32_t)(log_n - base_step - 1); // log2(n_groups / 2) + uint32_t high_mask = (1u << remaining_high_bits) - 1u; + + for (uint32_t loc_step = 0; loc_step < n_loc_steps; ++loc_step) { + if (threadIdx.x < 128) { + uint32_t i = threadIdx.x; + uint32_t half = 1u << loc_step; + uint32_t grp = i >> loc_step; + uint32_t grp_pos = i & (half - 1); + uint32_t idx1 = (grp << (loc_step + 1)) + grp_pos; + uint32_t idx2 = idx1 + half; + + // Global step and butterfly position for twiddle lookup. + uint32_t gs = (uint32_t)base_step + loc_step; + uint32_t ggp = (blockIdx.x << 7) + i; // blockIdx * 128 + i + // Un-remap ggp to find its position in the natural ordering. + ggp = ((ggp & high_mask) << (uint32_t)base_step) + (ggp >> remaining_high_bits); + ggp = ggp & ((1u << gs) - 1u); + uint64_t factor = tw[(uint64_t)ggp * (n >> (gs + 1))]; + + uint64_t u = tile[idx1]; + uint64_t v = mul(tile[idx2], factor); + tile[idx1] = add(u, v); + tile[idx2] = sub(u, v); + } + __syncthreads(); + } + + // Store back to the remapped row. + x[row] = tile[threadIdx.x]; +} diff --git a/crypto/math-cuda/src/barycentric.rs b/crypto/math-cuda/src/barycentric.rs new file mode 100644 index 000000000..d9dbb659c --- /dev/null +++ b/crypto/math-cuda/src/barycentric.rs @@ -0,0 +1,215 @@ +//! Barycentric evaluation on device — matches +//! `math::polynomial::interpolate_coset_eval_*_with_g_n_inv`. +//! +//! The kernels compute only the unscaled barycentric sum +//! S = Σ_i point_i * eval_i * inv_denom_i +//! per column. The caller multiplies each `S` by the ext3 scalar +//! `(z^N - g^N) * 1/N * 1/g^N` to get the final OOD value; that scaling is +//! one ext3 mul per column and stays on host. + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::lde::{GpuLdeBase, GpuLdeExt3}; + +const BLOCK_DIM: u32 = 256; + +/// Barycentric sums over M base-field columns, each of length `n`, laid out +/// with stride `col_stride` (so column `c` is at `columns[c*col_stride .. +/// c*col_stride + n]`). `inv_denoms` is 3N u64 (ext3 interleaved). +/// Returns 3M u64 (ext3 interleaved), one per column. +pub fn barycentric_base( + columns: &[u64], + col_stride: usize, + coset_points: &[u64], + inv_denoms_ext3: &[u64], + n: usize, + num_cols: usize, +) -> Result> { + assert_eq!(coset_points.len(), n); + assert_eq!(inv_denoms_ext3.len(), 3 * n); + assert!(columns.len() >= num_cols * col_stride); + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + + let be = backend(); + let stream = be.next_stream(); + + let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let points_dev = stream.clone_htod(coset_points)?; + let inv_dev = stream.clone_htod(inv_denoms_ext3)?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + + let col_stride_u64 = col_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_base_batched) + .arg(&cols_dev) + .arg(&col_stride_u64) + .arg(&points_dev) + .arg(&inv_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Same as [`barycentric_base`] but `columns` holds M ext3 columns in the +/// de-interleaved layout: slab `c*3 + k` at offset `(c*3+k)*col_stride`. +/// `columns.len() >= num_cols * 3 * col_stride`. +pub fn barycentric_ext3( + columns: &[u64], + col_stride: usize, + coset_points: &[u64], + inv_denoms_ext3: &[u64], + n: usize, + num_cols: usize, +) -> Result> { + assert_eq!(coset_points.len(), n); + assert_eq!(inv_denoms_ext3.len(), 3 * n); + assert!(columns.len() >= num_cols * 3 * col_stride); + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + + let be = backend(); + let stream = be.next_stream(); + + let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let points_dev = stream.clone_htod(coset_points)?; + let inv_dev = stream.clone_htod(inv_denoms_ext3)?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + + let col_stride_u64 = col_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_ext3_batched) + .arg(&cols_dev) + .arg(&col_stride_u64) + .arg(&points_dev) + .arg(&inv_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Run `barycentric_base_batched_strided` over the base LDE already on +/// device (`main_handle`), summing over the trace-size coset (every +/// `row_stride = blowup_factor`-th row). H2Ds only the coset points and +/// inv_denoms; the column data never crosses PCIe. +pub fn barycentric_base_on_device( + main_handle: &GpuLdeBase, + row_stride: usize, + coset_points: &[u64], + inv_denoms_ext3: &[u64], + n: usize, +) -> Result> { + assert_eq!(coset_points.len(), n); + assert_eq!(inv_denoms_ext3.len(), 3 * n); + let num_cols = main_handle.m; + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + let col_stride = main_handle.lde_size; + + let be = backend(); + let stream = be.next_stream(); + + let points_dev = stream.clone_htod(coset_points)?; + let inv_dev = stream.clone_htod(inv_denoms_ext3)?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + + let col_stride_u64 = col_stride as u64; + let row_stride_u64 = row_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_base_batched_strided) + .arg(main_handle.buf.as_ref()) + .arg(&col_stride_u64) + .arg(&row_stride_u64) + .arg(&points_dev) + .arg(&inv_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Ext3 counterpart of [`barycentric_base_on_device`]. Reads the aux LDE +/// from the de-interleaved device handle. +pub fn barycentric_ext3_on_device( + aux_handle: &GpuLdeExt3, + row_stride: usize, + coset_points: &[u64], + inv_denoms_ext3: &[u64], + n: usize, +) -> Result> { + assert_eq!(coset_points.len(), n); + assert_eq!(inv_denoms_ext3.len(), 3 * n); + let num_cols = aux_handle.m; + if num_cols == 0 || n == 0 { + return Ok(vec![0; 3 * num_cols]); + } + let col_stride = aux_handle.lde_size; + + let be = backend(); + let stream = be.next_stream(); + + let points_dev = stream.clone_htod(coset_points)?; + let inv_dev = stream.clone_htod(inv_denoms_ext3)?; + let mut out_dev = stream.alloc_zeros::(3 * num_cols)?; + + let col_stride_u64 = col_stride as u64; + let row_stride_u64 = row_stride as u64; + let n_u64 = n as u64; + let cfg = LaunchConfig { + grid_dim: (num_cols as u32, 1, 1), + block_dim: (BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.barycentric_ext3_batched_strided) + .arg(aux_handle.buf.as_ref()) + .arg(&col_stride_u64) + .arg(&row_stride_u64) + .arg(&points_dev) + .arg(&inv_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/deep.rs b/crypto/math-cuda/src/deep.rs new file mode 100644 index 000000000..484970e33 --- /dev/null +++ b/crypto/math-cuda/src/deep.rs @@ -0,0 +1,217 @@ +//! R4 deep-composition polynomial evaluations on GPU. +//! +//! Mirrors `Self::compute_deep_composition_poly_evaluations` in +//! `crypto/stark/src/prover.rs`. Accepts the main/aux LDEs as device +//! handles (populated by the R1 fused path in `LDETraceTable`) and +//! takes every other tensor (composition parts LDE, OOD evals, +//! gammas, inv-denoms) from host. Returns a `Vec` of +//! `domain_size * 3` u64s, ext3 interleaved (ready to `transmute` to +//! `FieldElement` when the caller promises layout compatibility). + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::lde::{GpuLdeBase, GpuLdeExt3}; + +/// Compute deep-composition evaluations on device. +/// +/// `num_eval_points = trace_terms_gammas_interleaved.len() / ((num_main + +/// num_aux) * 3)`. The caller is responsible for packing each Vec +/// into interleaved u64 slices (`[a0, a1, a2, b0, b1, b2, ...]`). +#[allow(clippy::too_many_arguments)] +pub fn deep_composition_ext3( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + // Host-side inputs (H2D'd internally) + h_parts_deinterleaved: &[u64], // num_parts * 3 * lde_stride u64 + h_ood: &[u64], // num_parts * 3 + trace_ood: &[u64], // num_total_cols * num_eval_points * 3 + gammas_h: &[u64], // num_parts * 3 + gammas_tr: &[u64], // num_total_cols * num_eval_points * 3 + inv_h: &[u64], // domain_size * 3 + inv_t: &[u64], // num_eval_points * domain_size * 3 + // Shape params + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + deep_composition_ext3_impl( + main_lde, + aux_lde, + None, + h_parts_deinterleaved, + h_ood, + trace_ood, + gammas_h, + gammas_tr, + inv_h, + inv_t, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) +} + +/// Same as [`deep_composition_ext3`] but reads the composition-parts LDE +/// from a device handle (`GpuLdeExt3`) populated by the R2 fused path, +/// skipping the `num_parts * 3 * lde_size * 8` byte H2D of +/// `h_parts_deinterleaved`. +#[allow(clippy::too_many_arguments)] +pub fn deep_composition_ext3_with_dev_parts( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + h_parts_dev: &GpuLdeExt3, + h_ood: &[u64], + trace_ood: &[u64], + gammas_h: &[u64], + gammas_tr: &[u64], + inv_h: &[u64], + inv_t: &[u64], + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + deep_composition_ext3_impl( + main_lde, + aux_lde, + Some(h_parts_dev), + &[], + h_ood, + trace_ood, + gammas_h, + gammas_tr, + inv_h, + inv_t, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) +} + +#[allow(clippy::too_many_arguments)] +fn deep_composition_ext3_impl( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + h_parts_dev: Option<&GpuLdeExt3>, + h_parts_host: &[u64], + h_ood: &[u64], + trace_ood: &[u64], + gammas_h: &[u64], + gammas_tr: &[u64], + inv_h: &[u64], + inv_t: &[u64], + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + assert_eq!(main_lde.m, num_main); + if let Some(a) = aux_lde { + assert_eq!(a.m, num_aux); + assert_eq!(a.lde_size, main_lde.lde_size); + } else { + assert_eq!(num_aux, 0); + } + if let Some(h) = h_parts_dev { + assert_eq!(h.m, num_parts); + assert_eq!(h.lde_size, main_lde.lde_size); + } else { + assert_eq!(h_parts_host.len(), num_parts * 3 * main_lde.lde_size); + } + assert_eq!(h_ood.len(), num_parts * 3); + let num_total_cols = num_main + num_aux; + assert_eq!(trace_ood.len(), num_total_cols * num_eval_points * 3); + assert_eq!(gammas_h.len(), num_parts * 3); + assert_eq!(gammas_tr.len(), num_total_cols * num_eval_points * 3); + assert_eq!(inv_h.len(), domain_size * 3); + assert_eq!(inv_t.len(), num_eval_points * domain_size * 3); + + let be = backend(); + let stream = be.next_stream(); + + // H2D only the scalar arrays — h_parts comes from a device handle + // when available. + let h_ood_dev = stream.clone_htod(h_ood)?; + let trace_ood_dev = stream.clone_htod(trace_ood)?; + let gammas_h_dev = stream.clone_htod(gammas_h)?; + let gammas_tr_dev = stream.clone_htod(gammas_tr)?; + let inv_h_dev = stream.clone_htod(inv_h)?; + let inv_t_dev = stream.clone_htod(inv_t)?; + + // Keep the owned H2D of h_lde alive until kernel completes. Only + // populated in the host-parts path. + let h_lde_host_dev; + + let mut deep_out = stream.alloc_zeros::(domain_size * 3)?; + + let dummy_aux; + let aux_slice = if let Some(a) = aux_lde { + a.buf.as_ref() + } else { + dummy_aux = stream.alloc_zeros::(1)?; + &dummy_aux + }; + + let h_lde_slice = if let Some(h) = h_parts_dev { + h.buf.as_ref() + } else { + h_lde_host_dev = stream.clone_htod(h_parts_host)?; + &h_lde_host_dev + }; + + let lde_stride = main_lde.lde_size as u64; + let num_main_u = num_main as u64; + let num_aux_u = num_aux as u64; + let num_parts_u = num_parts as u64; + let num_eval_points_u = num_eval_points as u64; + let blowup_u = blowup_factor as u64; + let domain_size_u = domain_size as u64; + + let grid = ((domain_size as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.deep_composition_ext3_row) + .arg(main_lde.buf.as_ref()) + .arg(aux_slice) + .arg(h_lde_slice) + .arg(&lde_stride) + .arg(&num_main_u) + .arg(&num_aux_u) + .arg(&num_parts_u) + .arg(&num_eval_points_u) + .arg(&blowup_u) + .arg(&domain_size_u) + .arg(&h_ood_dev) + .arg(&trace_ood_dev) + .arg(&gammas_h_dev) + .arg(&gammas_tr_dev) + .arg(&inv_h_dev) + .arg(&inv_t_dev) + .arg(&mut deep_out) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&deep_out)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs new file mode 100644 index 000000000..481cb8d9f --- /dev/null +++ b/crypto/math-cuda/src/device.rs @@ -0,0 +1,339 @@ +//! CUDA device context, stream pool, kernel handles, and twiddle cache. +//! +//! One process-wide backend — lazy-initialised on first use. All kernels live +//! on a single CUDA context; a pool of streams lets rayon-parallel callers +//! overlap H2D / compute / D2H. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; + +use cudarc::driver::{CudaContext, CudaFunction, CudaSlice, CudaStream}; +use cudarc::nvrtc::Ptx; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsFFTField; + +use crate::Result; +use crate::ntt::{twiddles_forward, twiddles_inverse}; + +/// Reusable pinned host staging buffer. One per stream; the stream's LDE call +/// holds its buffer's lock across the D2H + memcpy-to-user-Vecs window. +/// +/// Allocated with `cuMemHostAlloc(flags=0)` — portable, non-write-combined, +/// so both DMA writes from device and CPU reads into user Vecs run at full +/// speed. Grows power-of-two; never shrinks. +pub struct PinnedStaging { + ptr: *mut u64, + capacity_elems: usize, +} + +// SAFETY: the raw pointer aliases host memory allocated via cuMemHostAlloc. +// We guard concurrent access with a Mutex; the pointer is valid for the +// lifetime of this struct and is freed on drop. +unsafe impl Send for PinnedStaging {} +unsafe impl Sync for PinnedStaging {} + +impl PinnedStaging { + const fn empty() -> Self { + Self { + ptr: std::ptr::null_mut(), + capacity_elems: 0, + } + } + + pub fn ensure_capacity( + &mut self, + min_elems: usize, + ctx: &CudaContext, + ) -> Result<()> { + if self.capacity_elems >= min_elems { + return Ok(()); + } + // cuMemHostAlloc requires the context to be current on this thread. + ctx.bind_to_thread()?; + // Free old (if any) before allocating the new one. + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + self.ptr = std::ptr::null_mut(); + self.capacity_elems = 0; + } + let new_cap = min_elems.next_power_of_two().max(1 << 20); // at least 8 MB + let bytes = new_cap * std::mem::size_of::(); + let ptr = unsafe { + cudarc::driver::result::malloc_host(bytes, 0 /* flags: non-WC */)? + } as *mut u64; + self.ptr = ptr; + self.capacity_elems = new_cap; + Ok(()) + } + + /// View of the first `len` elements. Caller must hold this `PinnedStaging` + /// locked while using the slice; the slice aliases the internal pointer. + /// + /// # Safety + /// Caller must not outlive the `PinnedStaging` and must not race with + /// concurrent uses. + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [u64] { + assert!(len <= self.capacity_elems); + if len == 0 { + return &mut []; + } + unsafe { std::slice::from_raw_parts_mut(self.ptr, len) } + } +} + +impl Drop for PinnedStaging { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { + let _ = cudarc::driver::sys::cuMemFreeHost(self.ptr as *mut _); + } + } + } +} + +const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); +const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); +const KECCAK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/keccak.ptx")); +const BARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/barycentric.ptx")); +const DEEP_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/deep.ptx")); +const FRI_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/fri.ptx")); +const INVERSE_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/inverse.ptx")); +const LOGUP_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/logup.ptx")); +/// Number of CUDA streams in the pool. Larger pools let many rayon-parallel +/// callers overlap on the GPU without serializing on stream ownership. The +/// default stream is deliberately excluded because it synchronises with all +/// other streams, defeating the point of the pool. +const STREAM_POOL_SIZE: usize = 32; + +pub struct Backend { + pub ctx: Arc, + streams: Vec>, + /// Single shared pinned staging buffer, grown to the biggest LDE size + /// seen. Concurrent batched LDE calls serialise on it; in exchange the + /// process keeps only ONE gigabyte-sized pinned allocation (per-stream + /// buffers 32×-inflated memory use and multiplied the one-time pinning + /// cost for every first use of a new table size). + pinned_staging: Mutex, + /// Separate pinned staging for Merkle leaf hashes. Sized `num_rows * 32` + /// bytes; lives alongside the LDE staging so the GPU→host D2H for + /// hashed leaves runs at full PCIe line-rate instead of the pageable + /// ~1.3 GB/s path that would otherwise eat ~100 ms per main-trace commit. + pinned_hashes: Mutex, + util_stream: Arc, + next: AtomicUsize, + + // arith.ptx + pub vector_add_u64: CudaFunction, + pub gl_add: CudaFunction, + pub gl_sub: CudaFunction, + pub gl_mul: CudaFunction, + pub gl_neg: CudaFunction, + pub ext3_mul: CudaFunction, + pub ext3_add: CudaFunction, + + // ntt.ptx + pub bit_reverse_permute: CudaFunction, + pub ntt_dit_level: CudaFunction, + pub ntt_dit_8_levels: CudaFunction, + pub pointwise_mul: CudaFunction, + pub scalar_mul: CudaFunction, + pub bit_reverse_permute_batched: CudaFunction, + pub ntt_dit_level_batched: CudaFunction, + pub ntt_dit_8_levels_batched: CudaFunction, + pub pointwise_mul_batched: CudaFunction, + pub scalar_mul_batched: CudaFunction, + + // keccak.ptx + pub keccak256_leaves_base_batched: CudaFunction, + pub keccak256_leaves_ext3_batched: CudaFunction, + pub keccak_comp_poly_leaves_ext3: CudaFunction, + pub keccak_fri_leaves_ext3: CudaFunction, + pub keccak_merkle_level: CudaFunction, + + // barycentric.ptx + pub barycentric_base_batched: CudaFunction, + pub barycentric_ext3_batched: CudaFunction, + pub barycentric_base_batched_strided: CudaFunction, + pub barycentric_ext3_batched_strided: CudaFunction, + + // deep.ptx + pub deep_composition_ext3_row: CudaFunction, + + // fri.ptx + pub fri_fold_ext3: CudaFunction, + pub fri_update_twiddles: CudaFunction, + + // inverse.ptx + pub compute_denoms_ext3: CudaFunction, + pub chunk_prefix_scan_ext3: CudaFunction, + pub exclusive_scan_of_totals_ext3: CudaFunction, + pub apply_scan_offsets_ext3: CudaFunction, + pub chunk_suffix_scan_ext3: CudaFunction, + pub exclusive_reverse_scan_of_totals_ext3: CudaFunction, + pub apply_reverse_scan_offsets_ext3: CudaFunction, + pub batch_inverse_combine_ext3: CudaFunction, + + // logup.ptx + pub logup_pair_fingerprint: CudaFunction, + pub logup_pair_term_assembly: CudaFunction, + pub logup_single_fingerprint: CudaFunction, + pub logup_single_term_assembly: CudaFunction, + + // Twiddle caches keyed by log_n. + fwd_twiddles: Mutex>>>>, + inv_twiddles: Mutex>>>>, +} + +impl Backend { + fn init() -> Result { + let ctx = CudaContext::new(0)?; + // cudarc's default per-slice CudaEvent tracking adds two driver calls + // per alloc and serialises under the context lock. We never share + // slices across streams (every call scopes its own buffers and syncs + // before returning), so the tracking is pure overhead. Disable it. + unsafe { ctx.disable_event_tracking() }; + + let arith = ctx.load_module(Ptx::from_src(ARITH_PTX))?; + let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; + let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; + let bary = ctx.load_module(Ptx::from_src(BARY_PTX))?; + let deep = ctx.load_module(Ptx::from_src(DEEP_PTX))?; + let fri = ctx.load_module(Ptx::from_src(FRI_PTX))?; + let inverse = ctx.load_module(Ptx::from_src(INVERSE_PTX))?; + let logup = ctx.load_module(Ptx::from_src(LOGUP_PTX))?; + + let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); + for _ in 0..STREAM_POOL_SIZE { + streams.push(ctx.new_stream()?); + } + let pinned_staging = Mutex::new(PinnedStaging::empty()); + let pinned_hashes = Mutex::new(PinnedStaging::empty()); + // Separate "utility" stream for twiddle uploads and other bookkeeping; + // not part of the pool that callers rotate through. + let util_stream = ctx.new_stream()?; + + // Goldilocks TWO_ADICITY is 32, so log_n ≤ 32 covers every LDE size + // the prover can produce. Overshoot by one for safety. + let max_log = GoldilocksField::TWO_ADICITY as usize + 1; + + Ok(Self { + vector_add_u64: arith.load_function("vector_add_u64")?, + gl_add: arith.load_function("gl_add_kernel")?, + gl_sub: arith.load_function("gl_sub_kernel")?, + gl_mul: arith.load_function("gl_mul_kernel")?, + gl_neg: arith.load_function("gl_neg_kernel")?, + ext3_mul: arith.load_function("ext3_mul_kernel")?, + ext3_add: arith.load_function("ext3_add_kernel")?, + bit_reverse_permute: ntt.load_function("bit_reverse_permute")?, + ntt_dit_level: ntt.load_function("ntt_dit_level")?, + ntt_dit_8_levels: ntt.load_function("ntt_dit_8_levels")?, + pointwise_mul: ntt.load_function("pointwise_mul")?, + scalar_mul: ntt.load_function("scalar_mul")?, + bit_reverse_permute_batched: ntt.load_function("bit_reverse_permute_batched")?, + ntt_dit_level_batched: ntt.load_function("ntt_dit_level_batched")?, + ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, + pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, + scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?, + keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?, + keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, + keccak_fri_leaves_ext3: keccak.load_function("keccak_fri_leaves_ext3")?, + keccak_merkle_level: keccak.load_function("keccak_merkle_level")?, + barycentric_base_batched: bary.load_function("barycentric_base_batched")?, + barycentric_ext3_batched: bary.load_function("barycentric_ext3_batched")?, + barycentric_base_batched_strided: bary.load_function("barycentric_base_batched_strided")?, + barycentric_ext3_batched_strided: bary.load_function("barycentric_ext3_batched_strided")?, + deep_composition_ext3_row: deep.load_function("deep_composition_ext3_row")?, + fri_fold_ext3: fri.load_function("fri_fold_ext3")?, + fri_update_twiddles: fri.load_function("fri_update_twiddles")?, + compute_denoms_ext3: inverse.load_function("compute_denoms_ext3")?, + chunk_prefix_scan_ext3: inverse.load_function("chunk_prefix_scan_ext3")?, + exclusive_scan_of_totals_ext3: inverse + .load_function("exclusive_scan_of_totals_ext3")?, + apply_scan_offsets_ext3: inverse.load_function("apply_scan_offsets_ext3")?, + chunk_suffix_scan_ext3: inverse.load_function("chunk_suffix_scan_ext3")?, + exclusive_reverse_scan_of_totals_ext3: inverse + .load_function("exclusive_reverse_scan_of_totals_ext3")?, + apply_reverse_scan_offsets_ext3: inverse + .load_function("apply_reverse_scan_offsets_ext3")?, + batch_inverse_combine_ext3: inverse.load_function("batch_inverse_combine_ext3")?, + logup_pair_fingerprint: logup.load_function("logup_pair_fingerprint")?, + logup_pair_term_assembly: logup.load_function("logup_pair_term_assembly")?, + logup_single_fingerprint: logup.load_function("logup_single_fingerprint")?, + logup_single_term_assembly: logup.load_function("logup_single_term_assembly")?, + fwd_twiddles: Mutex::new(vec![None; max_log]), + inv_twiddles: Mutex::new(vec![None; max_log]), + ctx, + streams, + pinned_staging, + pinned_hashes, + util_stream, + next: AtomicUsize::new(0), + }) + } + + /// Round-robin over the stream pool. Concurrent callers get different + /// streams so their kernel launches overlap on the GPU. + pub fn next_stream(&self) -> Arc { + let idx = self.next.fetch_add(1, Ordering::Relaxed) % self.streams.len(); + self.streams[idx].clone() + } + + /// Shared pinned staging buffer. Grows to the largest LDE the process + /// has seen so far. Concurrent callers serialise on the mutex. + pub fn pinned_staging(&self) -> &Mutex { + &self.pinned_staging + } + + /// Separate pinned staging for Merkle leaf hash output. Sized in u64 + /// units; caller should reserve `(num_rows * 32 + 7) / 8` u64s. + pub fn pinned_hashes(&self) -> &Mutex { + &self.pinned_hashes + } + + pub fn fwd_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, true) + } + + pub fn inv_twiddles_for(&self, log_n: u64) -> Result>> { + self.cached_twiddles(log_n, false) + } + + fn cached_twiddles(&self, log_n: u64, forward: bool) -> Result>> { + let idx = log_n as usize; + let cache = if forward { + &self.fwd_twiddles + } else { + &self.inv_twiddles + }; + { + let guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + return Ok(t.clone()); + } + } + // Compute on host, upload on the utility stream. Another thread may + // have populated the cache in the meantime; prefer that entry. + let host = if forward { + twiddles_forward(log_n) + } else { + twiddles_inverse(log_n) + }; + let dev = Arc::new(self.util_stream.clone_htod(&host)?); + self.util_stream.synchronize()?; + let mut guard = cache.lock().unwrap(); + if let Some(t) = &guard[idx] { + Ok(t.clone()) + } else { + guard[idx] = Some(dev.clone()); + Ok(dev) + } + } +} + +pub fn backend() -> &'static Backend { + static BACKEND: OnceLock = OnceLock::new(); + BACKEND.get_or_init(|| Backend::init().expect("failed to initialise CUDA backend")) +} diff --git a/crypto/math-cuda/src/fri.rs b/crypto/math-cuda/src/fri.rs new file mode 100644 index 000000000..a3fa7a2b6 --- /dev/null +++ b/crypto/math-cuda/src/fri.rs @@ -0,0 +1,289 @@ +//! Fully-device-resident FRI commit phase orchestration. +//! +//! The host loop (in the stark crate) samples each layer's `zeta` from the +//! transcript and feeds it in; this module keeps the folded evaluations, +//! twiddles, and per-layer Merkle trees on device, only D2H'ing each +//! layer's root (to append to the transcript), plus its full evals and +//! tree nodes (to plug into `FriLayer` for the query phase). +//! +//! Mirrors `commit_phase_from_evaluations` at +//! `crypto/stark/src/fri/mod.rs`. + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; +use std::sync::Arc; + +use crate::Result; +use crate::device::backend; + +/// Device-side state across FRI commit iterations. Owns two ext3 eval +/// buffers (flip-flopped as layer input / output) and the inv_twiddles +/// buffer. Freed when dropped. +pub struct FriCommitState { + pub stream: Arc, + // Ping-pong evaluation buffers. Both sized `3 * n0` u64 at init; each + // successive fold uses half the space. Cheap to pre-allocate vs. per- + // layer alloc. + evals_a: CudaSlice, + evals_b: CudaSlice, + /// Base-field inv_twiddles; `n0 / 2` u64 at init, halved each layer. + inv_tw: CudaSlice, + /// Number of ext3 elements currently in the "input" buffer. + pub current_n: usize, + /// Which buffer holds the current layer's input. Toggles each fold. + a_is_input: bool, +} + +impl FriCommitState { + /// H2D the starting evals (ext3 interleaved, 3 * n0 u64) and the + /// initial inv_twiddles (base field, n0/2 u64). `n0` must be a power of + /// two and ≥ 2. + pub fn new( + evals_host: &[u64], + inv_tw_host: &[u64], + n0: usize, + ) -> Result { + assert!(n0 >= 2 && n0.is_power_of_two()); + assert_eq!(evals_host.len(), 3 * n0); + assert_eq!(inv_tw_host.len(), n0 / 2); + + let be = backend(); + let stream = be.next_stream(); + + // SAFETY: every byte of evals_a is overwritten by the H2D below. + // evals_b is written by the first fold before it is read. + let mut evals_a = unsafe { stream.alloc::(3 * n0) }?; + let evals_b = unsafe { stream.alloc::(3 * n0) }?; + stream.memcpy_htod(evals_host, &mut evals_a)?; + let inv_tw = stream.clone_htod(inv_tw_host)?; + + Ok(Self { + stream, + evals_a, + evals_b, + inv_tw, + current_n: n0, + a_is_input: true, + }) + } + + /// Fold the current layer using `zeta`, run the row-pair Keccak leaves + /// + pair-hash Merkle tree kernels on the result, and D2H: + /// - the new root (32 bytes) + /// - the new layer's evals (3 * (current_n / 2) u64s) + /// - the new layer's Merkle tree nodes (standard layout, byte-packed) + /// + /// Also updates `inv_twiddles` in place to shrink for the next layer. + pub fn fold_and_commit_layer( + &mut self, + zeta_raw: [u64; 3], + ) -> Result<(Vec, Vec, Vec)> { + let be = backend(); + let n_in = self.current_n; + let n_out = n_in / 2; + assert!(n_out >= 1, "FRI fold_layer called with current_n = 1"); + + // Allocate the tree buffer. num_leaves = n_out / 2 (row-pair leaves). + let num_leaves = n_out / 2; + let tight_total_nodes = if num_leaves >= 1 { + 2 * num_leaves - 1 + } else { + // Degenerate case: n_out == 1, no further Merkle commit needed. + // Caller should use `fold_final` for the final layer, not here. + panic!("fold_and_commit_layer requires n_out >= 2 (num_leaves >= 1)"); + }; + + // H2D zeta. + let zeta_dev = self.stream.clone_htod(&zeta_raw)?; + + // Select input and output buffers. + // Borrow checker requires us to split_borrow; use raw pointers via + // slice_mut to pass both into the kernel. + // We pass `input` via `&CudaSlice` and `output` via + // `&mut CudaSlice`. Rust borrow rules require them to be + // distinct; `a_is_input` flips between the two owned slices. + let cfg = LaunchConfig { + grid_dim: (((n_out as u32) + 128 - 1) / 128, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let n_out_u64 = n_out as u64; + + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_a) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_b) + .launch(cfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_b) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_a) + .launch(cfg)?; + } + } + + // Keccak leaves + pair-hash tree into fresh device buffer. + let mut nodes_dev = unsafe { self.stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = nodes_dev + .slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let grid = ((num_leaves as u32) + 128 - 1) / 128; + let kcfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + // Leaves read from the layer's OUTPUT eval buffer. + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&self.evals_b) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(kcfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&self.evals_a) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(kcfg)?; + } + } + } + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + self.stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // Update inv_twiddles for the next layer: `new[j] = old[2j]²` for + // j in 0..n_out/2. (If n_out == 1, skip — no next fold.) + let tw_next = n_out / 2; + if tw_next > 0 { + let grid = ((tw_next as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let tw_next_u64 = tw_next as u64; + unsafe { + self.stream + .launch_builder(&be.fri_update_twiddles) + .arg(&mut self.inv_tw) + .arg(&tw_next_u64) + .launch(cfg)?; + } + } + + // Sync and D2H. + self.stream.synchronize()?; + + // Layer evals: 3 * n_out u64 from the output buffer. + let layer_evals: Vec = if self.a_is_input { + let view = self.evals_b.slice(0..3 * n_out); + self.stream.clone_dtoh(&view)? + } else { + let view = self.evals_a.slice(0..3 * n_out); + self.stream.clone_dtoh(&view)? + }; + + // Tree nodes. + let nodes_bytes: Vec = self.stream.clone_dtoh(&nodes_dev)?; + debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); + + let mut root = vec![0u8; 32]; + root.copy_from_slice(&nodes_bytes[0..32]); + + self.a_is_input = !self.a_is_input; + self.current_n = n_out; + + Ok((root, layer_evals, nodes_bytes)) + } + + /// Final fold — no Merkle commit. Returns the single ext3 output + /// element (the FRI last_value). + pub fn fold_final(&mut self, zeta_raw: [u64; 3]) -> Result<[u64; 3]> { + let be = backend(); + let n_in = self.current_n; + let n_out = n_in / 2; + assert!(n_out >= 1); + + let zeta_dev = self.stream.clone_htod(&zeta_raw)?; + let cfg = LaunchConfig { + grid_dim: (((n_out as u32) + 128 - 1) / 128, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let n_out_u64 = n_out as u64; + + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_a) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_b) + .launch(cfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_b) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_a) + .launch(cfg)?; + } + } + + self.stream.synchronize()?; + let out_first: Vec = if self.a_is_input { + let view = self.evals_b.slice(0..3); + self.stream.clone_dtoh(&view)? + } else { + let view = self.evals_a.slice(0..3); + self.stream.clone_dtoh(&view)? + }; + self.a_is_input = !self.a_is_input; + self.current_n = n_out; + Ok([out_first[0], out_first[1], out_first[2]]) + } +} diff --git a/crypto/math-cuda/src/inverse.rs b/crypto/math-cuda/src/inverse.rs new file mode 100644 index 000000000..fc0fa0adc --- /dev/null +++ b/crypto/math-cuda/src/inverse.rs @@ -0,0 +1,428 @@ +//! Parallel Montgomery batch inverse on the GPU for ext3 elements, plus +//! the R3 OOD / R4 DEEP `compute-denoms + invert` convenience fn. + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +const SCAN_THREADS: u32 = 256; +const COMBINE_BLOCK: u32 = 256; + +/// Parallel batch inverse over ext3 elements. `a` is 3 * n u64s +/// (interleaved). Returns a fresh Vec with 3 * n inverses. +/// +/// Mirrors `FieldElement::inplace_batch_inverse` semantically; parity +/// is gated by the prove+verify round-trip in the stark test suite. +pub fn batch_inverse_ext3(a: &[u64]) -> Result> { + assert!(a.len() % 3 == 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + if n == 1 { + // Degenerate: one element. Montgomery on CPU is simpler than the + // GPU pipeline for a single value — just invert and return. + // Caller is responsible for handling n=1 (unlikely) on CPU. + return Ok(vec![0; 3]); + } + + let be = backend(); + let stream = be.next_stream(); + + // H2D input. + let a_dev = stream.clone_htod(a)?; + + // Scratch buffers. + let mut prefix_dev = stream.alloc_zeros::(n * 3)?; + let mut suffix_dev = stream.alloc_zeros::(n * 3)?; + + // Chunk sizing: SCAN_THREADS threads, one chunk per thread. + let k: u32 = SCAN_THREADS; + let c_per_thread: u64 = ((n as u64) + (k as u64) - 1) / (k as u64); + let mut chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + let n_u64 = n as u64; + let k_u64 = k as u64; + + // Phase 1: chunk prefix scan. + let cfg_scan = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (k, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.chunk_prefix_scan_ext3) + .arg(&a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut prefix_dev) + .arg(&mut chunk_totals) + .launch(cfg_scan)?; + } + + // Phase 2: exclusive scan of chunk totals (single thread). + unsafe { + stream + .launch_builder(&be.exclusive_scan_of_totals_ext3) + .arg(&chunk_totals) + .arg(&k_u64) + .arg(&mut chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + + // Phase 3: apply offsets. + unsafe { + stream + .launch_builder(&be.apply_scan_offsets_ext3) + .arg(&mut prefix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&chunk_offsets) + .launch(cfg_scan)?; + } + + // Mirror for suffix. + let mut suffix_chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut suffix_chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + unsafe { + stream + .launch_builder(&be.chunk_suffix_scan_ext3) + .arg(&a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut suffix_dev) + .arg(&mut suffix_chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_reverse_scan_of_totals_ext3) + .arg(&suffix_chunk_totals) + .arg(&k_u64) + .arg(&mut suffix_chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_reverse_scan_offsets_ext3) + .arg(&mut suffix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&suffix_chunk_offsets) + .launch(cfg_scan)?; + } + + // Compute total = prefix[n-1], invert on host. + let total = { + let last_view = prefix_dev.slice((n - 1) * 3..n * 3); + let last_host: Vec = stream.clone_dtoh(&last_view)?; + stream.synchronize()?; + invert_ext3_host([last_host[0], last_host[1], last_host[2]]) + }; + let mut inv_total_dev = stream.alloc_zeros::(3)?; + stream.memcpy_htod(&total, &mut inv_total_dev)?; + + // Combine. + let mut out_dev = stream.alloc_zeros::(n * 3)?; + let cfg_combine = LaunchConfig { + grid_dim: (((n as u32) + COMBINE_BLOCK - 1) / COMBINE_BLOCK, 1, 1), + block_dim: (COMBINE_BLOCK, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix_dev) + .arg(&suffix_dev) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg_combine)?; + } + + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Same as [`batch_inverse_ext3`] but the input is already on device +/// (typically from `compute_denoms_ext3`) — avoids one H2D round-trip. +pub fn batch_inverse_ext3_dev(a_dev: &CudaSlice, n: usize) -> Result> { + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let mut prefix_dev = stream.alloc_zeros::(n * 3)?; + let mut suffix_dev = stream.alloc_zeros::(n * 3)?; + + let k: u32 = SCAN_THREADS; + let c_per_thread: u64 = ((n as u64) + (k as u64) - 1) / (k as u64); + let mut chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + let n_u64 = n as u64; + let k_u64 = k as u64; + + let cfg_scan = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (k, 1, 1), + shared_mem_bytes: 0, + }; + + unsafe { + stream + .launch_builder(&be.chunk_prefix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut prefix_dev) + .arg(&mut chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_scan_of_totals_ext3) + .arg(&chunk_totals) + .arg(&k_u64) + .arg(&mut chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_scan_offsets_ext3) + .arg(&mut prefix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&chunk_offsets) + .launch(cfg_scan)?; + } + + let mut suffix_chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut suffix_chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + unsafe { + stream + .launch_builder(&be.chunk_suffix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut suffix_dev) + .arg(&mut suffix_chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_reverse_scan_of_totals_ext3) + .arg(&suffix_chunk_totals) + .arg(&k_u64) + .arg(&mut suffix_chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_reverse_scan_offsets_ext3) + .arg(&mut suffix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&suffix_chunk_offsets) + .launch(cfg_scan)?; + } + + let total = { + let last_view = prefix_dev.slice((n - 1) * 3..n * 3); + let last_host: Vec = stream.clone_dtoh(&last_view)?; + stream.synchronize()?; + invert_ext3_host([last_host[0], last_host[1], last_host[2]]) + }; + let mut inv_total_dev = stream.alloc_zeros::(3)?; + stream.memcpy_htod(&total, &mut inv_total_dev)?; + + let mut out_dev = stream.alloc_zeros::(n * 3)?; + let cfg_combine = LaunchConfig { + grid_dim: (((n as u32) + COMBINE_BLOCK - 1) / COMBINE_BLOCK, 1, 1), + block_dim: (COMBINE_BLOCK, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix_dev) + .arg(&suffix_dev) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg_combine)?; + } + + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Compute `denoms[k*n + i] = x[i * stride] - z_scalars[k]` for all i, k, +/// then batch-invert in place. Fuses B.1 + B.2 to avoid an intermediate +/// D2H + H2D of the denominator array. +/// +/// `x_base` is the LDE coset (base-field, at least `n * stride` u64s). +/// `z_scalars` is `k * 3` u64s (ext3 interleaved). Returns `k * n * 3` +/// u64s (the inverted denoms), flat in k-major then i-major order. +pub fn compute_and_invert_denoms_ext3( + x_base: &[u64], + stride: usize, + z_scalars: &[u64], + k_scalars: usize, + n: usize, +) -> Result> { + assert!(x_base.len() >= n * stride); + assert_eq!(z_scalars.len(), k_scalars * 3); + let total = k_scalars * n; + + let be = backend(); + let stream = be.next_stream(); + + let x_dev = stream.clone_htod(&x_base[..n * stride])?; + let z_dev = stream.clone_htod(z_scalars)?; + let mut denoms_dev = stream.alloc_zeros::(total * 3)?; + + let stride_u64 = stride as u64; + let n_u64 = n as u64; + let k_u64 = k_scalars as u64; + + // Compute denoms. + let cfg = LaunchConfig { + grid_dim: (((total as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.compute_denoms_ext3) + .arg(&x_dev) + .arg(&stride_u64) + .arg(&z_dev) + .arg(&k_u64) + .arg(&n_u64) + .arg(&mut denoms_dev) + .launch(cfg)?; + } + stream.synchronize()?; + + // Batch-invert in place (reuses the device buffer). + batch_inverse_ext3_dev(&denoms_dev, total) +} + +// ============================================================================= +// Host-side ext3 inverse (used once, for the total of the GPU prefix product). +// ============================================================================= + +const GOLDILOCKS_P: u128 = (1u128 << 64) - (1u128 << 32) + 1; + +fn gl_mul(a: u64, b: u64) -> u64 { + let prod = (a as u128) * (b as u128); + (prod % GOLDILOCKS_P) as u64 +} + +fn gl_add(a: u64, b: u64) -> u64 { + let s = (a as u128) + (b as u128); + (s % GOLDILOCKS_P) as u64 +} + +fn gl_sub(a: u64, b: u64) -> u64 { + let a128 = a as u128; + let b128 = b as u128; + if a128 >= b128 { + ((a128 - b128) % GOLDILOCKS_P) as u64 + } else { + (((GOLDILOCKS_P - b128) + a128) % GOLDILOCKS_P) as u64 + } +} + +fn gl_pow(mut base: u64, mut exp: u64) -> u64 { + let mut acc: u64 = 1; + while exp != 0 { + if exp & 1 != 0 { + acc = gl_mul(acc, base); + } + base = gl_mul(base, base); + exp >>= 1; + } + acc +} + +fn gl_inv(a: u64) -> u64 { + // Fermat: a^{p-2} + gl_pow(a, GOLDILOCKS_P as u64 - 2) +} + +/// Public re-export of the host ext3 inverse — used by `logup.rs` for the +/// single-element total inverse in its inlined batch-inverse flow. +pub fn invert_ext3_host_pub(x: [u64; 3]) -> [u64; 3] { + invert_ext3_host(x) +} + +/// Invert one ext3 element on the host. Used once per batch inverse to +/// invert the total product; the main batch inverse work stays on GPU. +fn invert_ext3_host(x: [u64; 3]) -> [u64; 3] { + // x = a + b·w + c·w² where w³ = 2. + // Compute x^{-1} using the extension field's norm: + // norm(x) = x · x_conj1 · x_conj2 (where conjugates are Frobenius images) + // For Fp[w]/(w³-2) over Fp, the norm lives in Fp. + // + // Simpler: do the full ext3 multiplication inverse via + // classical adjugate over Fp[w]. + // + // Use the closed-form adjugate for degree-3 extension: + // Let x = (a, b, c) representing a + b·w + c·w² + // Then x⁻¹ = (d, e, f) / N + // where (Newton's identities / cofactor method): + // d = a² − 2·b·c + // e = 2·c² − a·b + // f = b² − a·c + // N = a·d + 2·b·f + 2·c·e + // + // (This matches the cpu `Degree3GoldilocksExtensionField::inv`.) + let a = x[0]; + let b = x[1]; + let c = x[2]; + + let bc = gl_mul(b, c); + let d = gl_sub(gl_mul(a, a), gl_add(bc, bc)); // a² - 2bc + let cc = gl_mul(c, c); + let ab = gl_mul(a, b); + let e = gl_sub(gl_add(cc, cc), ab); // 2c² - ab + let bb = gl_mul(b, b); + let ac = gl_mul(a, c); + let f = gl_sub(bb, ac); // b² - ac + + let ad = gl_mul(a, d); + let bf = gl_mul(b, f); + let ce = gl_mul(c, e); + let two_bf = gl_add(bf, bf); + let two_ce = gl_add(ce, ce); + let norm = gl_add(ad, gl_add(two_bf, two_ce)); + + let inv_norm = gl_inv(norm); + [ + gl_mul(d, inv_norm), + gl_mul(e, inv_norm), + gl_mul(f, inv_norm), + ] +} diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs new file mode 100644 index 000000000..cdc95abd7 --- /dev/null +++ b/crypto/math-cuda/src/lde.rs @@ -0,0 +1,2210 @@ +//! Full coset LDE on device. Mirrors `Polynomial::coset_lde_full_expand` in +//! `crypto/math/src/fft/polynomial.rs` algebraically: +//! +//! Input : N evaluations (natural order) of a poly on the standard subgroup, +//! plus coset weights (size N). The weights include the `1/N` iFFT +//! normalisation, matching the `LdeTwiddles::coset_weights` format at +//! `crypto/stark/src/prover.rs:248` — i.e. `weights[i] = g^i / N`. +//! Output : N*blowup_factor evaluations (natural order) on the coset. +//! +//! On-device steps, picks a stream from the shared pool so rayon-parallel +//! callers overlap on the GPU. Twiddles are cached in the backend. + +use std::sync::Arc; + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::merkle::{launch_keccak_base, launch_keccak_ext3}; +use crate::ntt::run_ntt_body; + +/// Handle to a base-field LDE kept live on device after R1 commit. +/// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset +/// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. +#[derive(Clone)] +pub struct GpuLdeBase { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +/// Handle to an ext3 LDE kept live on device, de-interleaved into 3 base +/// slabs per column. Column `c` component `k` at u64 offset +/// `(c*3 + k) * lde_size` within `buf`. +#[derive(Clone)] +pub struct GpuLdeExt3 { + pub buf: Arc>, + pub m: usize, + pub lde_size: usize, +} + +pub fn coset_lde_base( + evals: &[u64], + blowup_factor: usize, + weights: &[u64], +) -> Result> { + let n = evals.len(); + assert!(n.is_power_of_two(), "evals length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match evals"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + if n == 0 { + return Ok(Vec::new()); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + // Device buffer of lde_size, zero-padded tail, first N filled by copy. + let mut buf = stream.alloc_zeros::(lde_size)?; + { + let mut head = buf.slice_mut(0..n); + stream.memcpy_htod(evals, &mut head)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + + // === 1. iNTT on first N: bit_reverse + 8-level-fused DIT body === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + // Note: `run_ntt_body` expects a standalone CudaSlice; we pass `buf` and + // the kernel walks the first `n_u64` elements via its own indexing. + run_ntt_body(stream.as_ref(), &mut buf, inv_tw.as_ref(), n_u64, log_n)?; + // Note: the CPU iFFT does not include 1/N — it's folded into `weights`. The + // next pointwise multiply applies both the coset shift and the 1/N factor. + + // === 2. Pointwise multiply first N by coset weights (includes 1/N) === + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // === 3. Forward NTT on full buffer === + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .launch(LaunchConfig::for_num_elems(lde_size as u32))?; + } + run_ntt_body(stream.as_ref(), &mut buf, fwd_tw.as_ref(), lde_u64, log_lde)?; + + let out = stream.clone_dtoh(&buf)?; + stream.synchronize()?; + Ok(out) +} + +/// Batched coset LDE: processes `m` columns (all the same domain) in a single +/// pipeline on one stream. One H2D per column, then per-level batched kernels +/// that launch with `grid.y = m` so a single launch does the butterflies for +/// every column at that level. +/// +/// Returns one `Vec` per input column, each of length `n * blowup_factor`. +pub fn coset_lde_batch_base( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], +) -> Result>> { + if columns.is_empty() { + return Ok(Vec::new()); + } + let m = columns.len(); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + + if n == 0 { + return Ok(vec![Vec::new(); m]); + } + let lde_size = n * blowup_factor; + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let debug_phases = std::env::var("MATH_CUDA_PHASE_TIMING").is_ok(); + let t_start = if debug_phases { Some(std::time::Instant::now()) } else { None }; + let phase = |label: &str, prev: &mut Option| { + if let Some(p) = prev.as_ref() { + let now = std::time::Instant::now(); + eprintln!(" [{:>6.2} ms] {}", (now - *p).as_secs_f64() * 1e3, label); + *prev = Some(now); + } + }; + let mut last = t_start; + + // Pinned staging. Lock and grow to max(m*n for upload, m*lde_size for + // download). Holding the guard across the whole call serialises concurrent + // batched calls that happened to hash to the same stream slot, but that's + // exactly what we want — one stream can only do one sequence at a time. + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + // SAFETY: staging is locked, the slice alias ends before we unlock. + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + if debug_phases { phase("staging lock + grow", &mut last); } + + // Pack columns into first m*n slots of the pinned buffer. Parallel: pinned + // writes are DRAM-bandwidth bound, saturates at ~8 cores on modern + // hardware, so rayon shaves 20+ ms at prover scale. + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: each task writes to a disjoint `[c*n..c*n+n]` region of + // `pinned`, and the outer `staging` lock guarantees no other call is + // using the buffer concurrently. + let dst = unsafe { + std::slice::from_raw_parts_mut( + (pinned_base_ptr as *mut u64).add(c * n), + n, + ) + }; + dst.copy_from_slice(col); + }); + if debug_phases { phase("host pack (pinned, rayon)", &mut last); } + + // Column layout: `buf[c * lde_size + r]`. Zeroed so the [n, lde_size) + // tail of each column is already the zero-pad the CPU path does. + let mut buf = stream.alloc_zeros::(m * lde_size)?; + if debug_phases { stream.synchronize()?; phase("alloc_zeros", &mut last); } + // One memcpy per column from the pinned buffer into the strided slots. + // The pinned source hits PCIe line-rate. + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + if debug_phases { stream.synchronize()?; phase("H2D cols (pinned)", &mut last); } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + if debug_phases { stream.synchronize()?; phase("twiddles + weights", &mut last); } + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // === 1. Bit-reverse first N of every column === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("bit_reverse N", &mut last); } + // === 2. iNTT body over all columns === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("iNTT body", &mut last); } + + // === 3. Pointwise multiply by coset weights (includes 1/N) === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // === 4. Bit-reverse full LDE of every column === + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + if debug_phases { stream.synchronize()?; phase("pointwise + bit_reverse LDE", &mut last); } + // === 5. Forward NTT on full LDE of every column === + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + if debug_phases { stream.synchronize()?; phase("forward NTT body", &mut last); } + + // Single big D2H into the reusable pinned staging buffer — pinned, one + // call to the driver, saturates PCIe. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + if debug_phases { phase("D2H (one shot into pinned)", &mut last); } + + // Split pinned → per-column Vecs. The first write to each virgin + // Vec page-faults, which can dominate total time (~75 ms for 128 MB). + // Parallelise so the fault cost spreads across CPU cores. + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + let out: Vec> = (0..m) + .into_par_iter() + .map(|c| { + let mut v = Vec::::with_capacity(lde_size); + unsafe { v.set_len(lde_size) }; + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + v.copy_from_slice(src); + v + }) + .collect(); + if debug_phases { phase("copy out (rayon pinned → Vecs)", &mut last); } + drop(staging); + Ok(out) +} + +/// Like `coset_lde_batch_base` but writes directly into caller-provided +/// output slices instead of allocating fresh `Vec`s. Each output slice +/// must already have length `n * blowup_factor`. Saves ~50–100 ms of pageable +/// allocator work + page faults at prover scale because the caller's Vecs +/// have been sized once and are reused across calls. +pub fn coset_lde_batch_base_into( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + let n = columns[0].len(); + assert!(n.is_power_of_two(), "column length must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match column length"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), n, "all columns must be the same size"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size, "each output must be lde_size"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + for (c, col) in columns.iter().enumerate() { + pinned[c * n..c * n + n].copy_from_slice(col); + } + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT bit-reverse + body, pointwise mul, forward bit-reverse + body. + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + stream.synchronize()?; + + // Parallel copy pinned → caller outputs. Caller's Vecs may still fault + // on first write; we spread that cost across rayon cores. + #[allow(unused_imports)] + use rayon::prelude::*; + let pinned_ptr = pinned.as_ptr() as usize; + outputs + .par_iter_mut() + .enumerate() + .for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + dst.copy_from_slice(src); + }); + drop(staging); + Ok(()) +} + +/// Variant of `coset_lde_batch_base_into` that also emits the Keccak-256 +/// Merkle leaf hashes from the LDE output — all on GPU, no second H2D of +/// the LDE data. Leaves are computed reading columns at bit-reversed rows +/// (matching `commit_columns_bit_reversed` on the CPU side). +/// +/// `hashed_leaves_out` must be `lde_size * 32` bytes (one 32-byte digest +/// per output row, in natural row order). +pub fn coset_lde_batch_base_into_with_leaf_hash( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c, outer staging lock held. + let dst = unsafe { + std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) + }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + // pointwise coset scale + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT on full LDE slab + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Keccak-256 leaf hashing directly on the device LDE buffer. + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_base( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H the LDE into the pinned LDE staging and the hashes into a + // dedicated pinned hash staging, in parallel on the same stream. Both + // at pinned PCIe line-rate — pageable D2H of the 128 MB hash buffer + // would otherwise cost ~100 ms per main-trace commit. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + let hashes_u64_len = (lde_size * 32 + 7) / 8; + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + // `memcpy_dtoh` needs a byte slice. Reinterpret the u64 pinned buffer + // as bytes — same allocation, just typed differently. + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut( + hashes_pinned.as_mut_ptr() as *mut u8, + lde_size * 32, + ) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Copy pinned → caller outputs in parallel with the hash memcpy. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + dst.copy_from_slice(src); + }); + // Rayon-parallel memcpy of 128 MB from pinned → caller. Single-threaded + // `copy_from_slice` faults virgin pageable pages one at a time; the + // mm_struct rwsem serialises them into ~100 ms at 1M-fib scale. Chunk + // the slice so ~N cores pre-fault+write in parallel. + const CHUNK: usize = 64 * 1024; // 64 KiB ≈ 16 pages per chunk + let pinned_hash_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_hash_ptr as *const u8).add(i * CHUNK), + dst.len(), + ) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Like `coset_lde_batch_base_into_with_leaf_hash`, but also builds the full +/// Merkle tree on device and returns the `2*lde_size - 1` node buffer back +/// to the caller in `merkle_nodes_out` (byte length `(2*lde_size - 1) * 32`). +/// +/// The leaf hashes are never exposed to the caller — they stay on device and +/// feed straight into the pair-hash tree kernel, avoiding the +/// pinned→pageable→pinned round-trip that the separate-step GPU tree build +/// would pay. +pub fn coset_lde_batch_base_into_with_merkle_tree( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Fused LDE + leaf-hash + Merkle tree build. If `keep_device_buf` is true, +/// returns an `Arc>` wrapping the LDE device buffer so callers +/// (R2–R4 GPU paths) can reuse the LDE without a re-H2D. +pub fn coset_lde_batch_base_into_with_merkle_tree_keep( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_base_into_with_merkle_tree_inner( + columns, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + let handle = opt.expect("keep_device_buf=true must return Some"); + Ok(handle) +} + +fn coset_lde_batch_base_into_with_merkle_tree_inner( + columns: &[&[u64]], + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + let n = columns[0].len(); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(m * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(m * lde_size) }; + + use rayon::prelude::*; + let pinned_base_ptr = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let dst = unsafe { + std::slice::from_raw_parts_mut((pinned_base_ptr as *mut u64).add(c * n), n) + }; + dst.copy_from_slice(col); + }); + + let mut buf = stream.alloc_zeros::(m * lde_size)?; + for c in 0..m { + let mut dst = buf.slice_mut(c * lde_size..c * lde_size + n); + stream.memcpy_htod(&pinned[c * n..c * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let m_u32 = m as u32; + + // iNTT + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + m_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + // forward NTT at LDE size + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), m_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + m_u32, + )?; + + // Allocate the full node buffer; leaves occupy the tail slab, inner + // nodes are written by the pair-hash level kernel below. `alloc` (not + // `alloc_zeros`) is safe because every byte is written before it is + // read: leaf kernel fills the tail, tree kernel fills the head. + // + // The leaf kernel writes to `nodes_dev` starting at byte offset + // `(lde_size - 1) * 32`; we pass the base pointer as-is because the + // kernel indexes linearly from `hashed_leaves_out[row_idx * 32]`, so we + // build an offset device slice and feed that to the launch. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = + ((lde_size as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — mirror the CPU `build(nodes, leaves_len)` scan. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H the LDE and the tree nodes via pinned staging. + stream.memcpy_dtoh(&buf, &mut pinned[..m * lde_size])?; + + let tree_u64_len = (total_nodes * 32 + 7) / 8; + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut( + tree_pinned.as_mut_ptr() as *mut u8, + total_nodes * 32, + ) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Parallel memcpy pinned → caller. + let pinned_ptr = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_ptr as *const u64).add(c * lde_size), + lde_size, + ) + }; + dst.copy_from_slice(src); + }); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_tree_ptr as *const u8).add(i * CHUNK), + dst.len(), + ) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeBase { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Ext3 variant of `coset_lde_batch_base_into_with_leaf_hash`: run an LDE +/// over ext3 columns AND emit Keccak-256 Merkle leaves, all in one on-device +/// pipeline. +pub fn coset_lde_batch_ext3_into_with_leaf_hash( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + hashed_leaves_out: &mut [u8], +) -> Result<()> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert_eq!(hashed_leaves_out.len(), lde_size * 32); + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3 + 0]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Keccak-256 on the de-interleaved device buffer (3M base slabs). + let mut hashes_dev = stream.alloc_zeros::(lde_size * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &buf, + col_stride_u64, + m as u64, + lde_u64, + &mut hashes_dev, + )?; + + // D2H LDE (mb * lde_size u64) and hashes (lde_size * 32 bytes). + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let hashes_u64_len = (lde_size * 32 + 7) / 8; + let hashes_staging_slot = be.pinned_hashes(); + let mut hashes_staging = hashes_staging_slot.lock().unwrap(); + hashes_staging.ensure_capacity(hashes_u64_len, &be.ctx)?; + let hashes_pinned = unsafe { hashes_staging.as_mut_slice(hashes_u64_len) }; + let hashes_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut( + hashes_pinned.as_mut_ptr() as *mut u8, + lde_size * 32, + ) + }; + stream.memcpy_dtoh(&hashes_dev, hashes_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs, parallel. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3 + 0] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Parallel memcpy of pinned hashes → caller. + const CHUNK: usize = 64 * 1024; + let hash_src_ptr = hashes_pinned_bytes.as_ptr() as usize; + hashed_leaves_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (hash_src_ptr as *const u8).add(i * CHUNK), + dst.len(), + ) + }; + dst.copy_from_slice(src); + }); + drop(hashes_staging); + drop(staging); + Ok(()) +} + +/// Ext3 variant of the fused `coset_lde_batch_base_into_with_merkle_tree`. +/// LDE + leaf hashing + inner-tree build, all on device; D2Hs only the LDE +/// evaluations and the full `2*lde_size - 1` node buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + false, + ) + .map(|_| ()) +} + +/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — +/// returns an `Arc>` handle to the de-interleaved LDE device +/// buffer. +pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result { + let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns, + n, + blowup_factor, + weights, + outputs, + merkle_nodes_out, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn coset_lde_batch_ext3_into_with_merkle_tree_inner( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], + keep_device_buf: bool, +) -> Result> { + if columns.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = columns.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(None); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Allocate full tree buffer; leaf kernel writes to the tail slab. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (lde_size - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + lde_size * 32); + let log_num_rows_leaves = lde_size.trailing_zeros() as u64; + let num_cols_u64 = m as u64; + let grid = ((lde_size as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_cols_u64) + .arg(&lde_u64) + .arg(&log_num_rows_leaves) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels. + { + let mut level_begin: u64 = (lde_size - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE (mb * lde_size u64) and tree nodes. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (total_nodes * 32 + 7) / 8; + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tree_pinned.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_tree_ptr as *const u8).add(i * CHUNK), + dst.len(), + ) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Batched ext3 polynomial → coset evaluation. +/// +/// Input: M ext3 columns of `n` coefficients each (interleaved, 3n u64). +/// Output: M ext3 columns of `n * blowup_factor` evaluations each at the +/// offset-coset. +/// +/// Skips the iFFT stage of [`coset_lde_batch_ext3_into`] (input is +/// coefficients, not evaluations). Weights encode the coset shift: +/// `weights[k] = offset^k` (NO 1/N because iFFT normalisation doesn't apply). +/// +/// Used by the stark prover to GPU-accelerate +/// `evaluate_polynomial_on_lde_domain` calls inside the +/// `number_of_parts > 2` branch of the composition-polynomial LDE. +pub fn evaluate_poly_coset_batch_ext3_into( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + false, + ) + .map(|_| ()) +} + +/// Same as [`evaluate_poly_coset_batch_ext3_into`] but retains the de- +/// interleaved LDE device buffer as a `GpuLdeExt3` handle. Lets R2 commit +/// and R4 DEEP composition read the composition-parts LDE without +/// re-H2D'ing. +pub fn evaluate_poly_coset_batch_ext3_into_keep( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result { + let opt = evaluate_poly_coset_batch_ext3_into_inner( + coefs, + n, + blowup_factor, + weights, + outputs, + true, + )?; + Ok(opt.expect("keep_device_buf=true must return Some")) +} + +fn evaluate_poly_coset_batch_ext3_into_inner( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + keep_device_buf: bool, +) -> Result> { + if coefs.is_empty() { + assert_eq!(outputs.len(), 0); + return Ok(None); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + if n == 0 { + return Ok(None); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // Apply coset scaling: x[k] *= weights[k] for k in 0..n (no iFFT first). + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + + // Bit-reverse full lde_size slab, then forward DIT NTT. + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + if keep_device_buf { + Ok(Some(GpuLdeExt3 { + buf: std::sync::Arc::new(buf), + m, + lde_size, + })) + } else { + drop(buf); + Ok(None) + } +} + +/// Fused variant of [`evaluate_poly_coset_batch_ext3_into`]: in addition to +/// the LDE output, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves at bit-reversed indices + pair-hash inner tree). +/// +/// `merkle_nodes_out` must have byte length `(2 * lde_size - 1) * 32`. +/// Requires `lde_size >= 2`. +pub fn evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + coefs: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], + merkle_nodes_out: &mut [u8], +) -> Result<()> { + if coefs.is_empty() { + return Ok(()); + } + let m = coefs.len(); + assert_eq!(outputs.len(), m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + for c in coefs.iter() { + assert_eq!(c.len(), 3 * n); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size); + } + assert!(lde_size >= 2); + let total_nodes = 2 * lde_size - 1; + assert_eq!(merkle_nodes_out.len(), total_nodes * 32); + if n == 0 { + return Ok(()); + } + let log_lde = lde_size.trailing_zeros() as u64; + + let mb = 3 * m; + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + coefs.par_iter().enumerate().for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3) * n), n) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), n) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut((pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), n) + }; + for i in 0..n { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((n as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(LaunchConfig { + grid_dim: ((lde_size as u32).div_ceil(256), mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + })?; + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + // Build the row-pair Merkle tree on device. + // + // Row-pair commit: each leaf hashes 2 rows (bit-reversed indices) → + // num_leaves = lde_size / 2. Tree size: 2*num_leaves - 1 = lde_size - 1. + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let log_num_rows = log_lde; + let num_parts_u64 = m as u64; + let grid = ((num_leaves as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&lde_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // D2H LDE and tree. + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + let tree_u64_len = (tight_total_nodes * 32 + 7) / 8; + let tree_staging_slot = be.pinned_hashes(); + let mut tree_staging = tree_staging_slot.lock().unwrap(); + tree_staging.ensure_capacity(tree_u64_len, &be.ctx)?; + let tree_pinned = unsafe { tree_staging.as_mut_slice(tree_u64_len) }; + let tree_pinned_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut( + tree_pinned.as_mut_ptr() as *mut u8, + tight_total_nodes * 32, + ) + }; + stream.memcpy_dtoh(&nodes_dev, tree_pinned_bytes)?; + stream.synchronize()?; + + // Re-interleave pinned → caller ext3 outputs. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + + // Copy pinned tree → caller nodes_out. `merkle_nodes_out.len() == + // total_nodes * 32` is oversized relative to our tight tree; we write + // only the first `tight_total_nodes * 32` bytes and the caller trims. + // Expose the tight byte count via the slice length so the caller can + // construct the MerkleTree with the right node count. + assert!(merkle_nodes_out.len() >= tight_total_nodes * 32); + const CHUNK: usize = 64 * 1024; + let pinned_tree_ptr = tree_pinned_bytes.as_ptr() as usize; + merkle_nodes_out[..tight_total_nodes * 32] + .par_chunks_mut(CHUNK) + .enumerate() + .for_each(|(i, dst)| { + let src = unsafe { + std::slice::from_raw_parts( + (pinned_tree_ptr as *const u8).add(i * CHUNK), + dst.len(), + ) + }; + dst.copy_from_slice(src); + }); + drop(tree_staging); + drop(staging); + Ok(()) +} + +/// Batched coset LDE for Goldilocks **cubic extension** columns. +/// +/// A degree-3 extension element is `(a, b, c)` in memory (three contiguous +/// u64s). The NTT butterfly multiplies `v = (a, b, c)` by a base-field +/// twiddle `t`: `t * v = (t*a, t*b, t*c)`. Addition is componentwise. So an +/// NTT over M ext3 columns is algebraically equivalent to **3M parallel +/// base-field NTTs** sharing the same twiddles and coset weights. We +/// exploit this to reuse the base-field kernels with no modification: +/// +/// 1. Host pack de-interleaves each ext3 column into 3 consecutive +/// base-field slabs inside the pinned staging buffer (slab 0 has all the +/// a-components, slab 1 all the b's, slab 2 all the c's — 3M base slabs +/// in total). +/// 2. Existing `bit_reverse_permute_batched` / `ntt_dit_*_batched` / +/// `pointwise_mul_batched` run over those 3M base slabs on device. +/// 3. D2H, then re-interleave 3 slabs per output ext3 column. +/// +/// Input/output layout: each slice is 3*n or 3*n*blowup u64s, packed as +/// `[a0, b0, c0, a1, b1, c1, ...]` — the natural `[FieldElement]` +/// memory representation. +pub fn coset_lde_batch_ext3_into( + columns: &[&[u64]], + n: usize, + blowup_factor: usize, + weights: &[u64], + outputs: &mut [&mut [u64]], +) -> Result<()> { + if columns.is_empty() { + return Ok(()); + } + let m = columns.len(); + assert_eq!(outputs.len(), m, "outputs must match columns count"); + assert!(n.is_power_of_two(), "n must be a power of two"); + assert_eq!(weights.len(), n, "weights length must match n"); + assert!(blowup_factor.is_power_of_two(), "blowup must be power of two"); + for c in columns.iter() { + assert_eq!(c.len(), 3 * n, "each ext3 column must be 3*n u64s"); + } + let lde_size = n * blowup_factor; + for o in outputs.iter() { + assert_eq!(o.len(), 3 * lde_size, "each output must be 3*lde_size u64s"); + } + if n == 0 { + return Ok(()); + } + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + + // 3 base slabs per ext3 column; slab index `c*3 + k` holds component `k`. + let mb = 3 * m; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + // Pack: for each ext3 column, write 3 base slabs into pinned. The slab + // for column c, component k lives at `pinned[(c*3 + k)*n .. (c*3+k)*n + n]`. + // We de-interleave from the interleaved `[a, b, c, a, b, c, ...]` input. + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + columns.par_iter().enumerate().for_each(|(c, col)| { + // SAFETY: disjoint regions per c; staging lock held. + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 0) * n), + n, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * n), + n, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * n), + n, + ) + }; + for i in 0..n { + slab_a[i] = col[i * 3 + 0]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // Allocate + zero-pad device buffer holding 3M slabs of `lde_size`. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + // H2D: slab by slab into the first N slots of each `lde_size`-slab. + for s in 0..mb { + let mut dst = buf.slice_mut(s * lde_size..s * lde_size + n); + stream.memcpy_htod(&pinned[s * n..s * n + n], &mut dst)?; + } + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let col_stride_u64 = lde_size as u64; + let mb_u32 = mb as u32; + + // === Butterflies: identical to the base-field batched path, but with + // grid.y = 3M instead of M. === + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&n_u64) + .arg(&log_n) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + col_stride_u64, + mb_u32, + )?; + { + let grid_x = (n as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_batched) + .arg(&mut buf) + .arg(&weights_dev) + .arg(&n_u64) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + { + let grid_x = (lde_size as u32).div_ceil(256); + let cfg = LaunchConfig { + grid_dim: (grid_x, mb_u32, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_permute_batched) + .arg(&mut buf) + .arg(&lde_u64) + .arg(&log_lde) + .arg(&col_stride_u64) + .launch(cfg)?; + } + } + run_batched_ntt_body( + stream.as_ref(), + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + col_stride_u64, + mb_u32, + )?; + + stream.memcpy_dtoh(&buf, &mut pinned[..mb * lde_size])?; + stream.synchronize()?; + + // Unpack: for each output column, re-interleave 3 slabs back into the + // ext3-per-element layout. + let pinned_const = pinned.as_ptr() as usize; + outputs.par_iter_mut().enumerate().for_each(|(c, dst)| { + let slab_a = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 0) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts( + (pinned_const as *const u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + dst[i * 3 + 0] = slab_a[i]; + dst[i * 3 + 1] = slab_b[i]; + dst[i * 3 + 2] = slab_c[i]; + } + }); + drop(staging); + Ok(()) +} + +/// Run the DIT butterfly body of a bit-reversed-input NTT over `m` batched +/// columns in one device buffer. Same fusion strategy as `run_ntt_body`: +/// first 8 levels shmem-fused (coalesced), subsequent levels one kernel each. +fn run_batched_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, + col_stride: u64, + m: u32, +) -> Result<()> { + let be = backend(); + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .arg(&col_stride) + .launch(cfg)?; + } + } else { + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + } + + let grid_x = ((n / 2) as u32).div_ceil(256).max(1); + let cfg = LaunchConfig { + grid_dim: (grid_x, m, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level_batched) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&col_stride) + .launch(cfg)?; + } + } + Ok(()) +} + diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs new file mode 100644 index 000000000..3ec17a189 --- /dev/null +++ b/crypto/math-cuda/src/lib.rs @@ -0,0 +1,158 @@ +//! GPU backend for the lambda-vm STARK prover. +//! +//! Primary entry point: [`lde::coset_lde_base`]. Everything else (`ntt`, +//! element-wise arith) is either internal to the LDE pipeline or used by the +//! parity test suite. + +pub mod barycentric; +pub mod deep; +pub mod device; +pub mod fri; +pub mod inverse; +pub mod lde; +pub mod logup; +pub mod merkle; +pub mod ntt; + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::device::{Backend, backend}; + +pub type Result = std::result::Result; + +/// Toolchain sanity: plain wrapping u64 vector add. Not a field op. +pub fn vector_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.vector_add_u64) +} + +/// Goldilocks field add on device, element-wise. Inputs may be non-canonical. +pub fn gl_add_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_add) +} + +pub fn gl_sub_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_sub) +} + +pub fn gl_mul_u64(a: &[u64], b: &[u64]) -> Result> { + launch_binary_u64(a, b, |be| &be.gl_mul) +} + +pub fn gl_neg_u64(a: &[u64]) -> Result> { + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.gl_neg) + .arg(&a_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 multiply. `a` and `b` are 3n u64s (interleaved +/// [a0,a1,a2,b0,b1,b2,...]). Test helper for the `ext3.cuh` header. +pub fn ext3_mul_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_mul) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Element-wise ext3 add. +pub fn ext3_add_u64(a: &[u64], b: &[u64]) -> Result> { + assert_eq!(a.len(), b.len()); + assert_eq!(a.len() % 3, 0); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(3 * n)?; + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.ext3_add) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} + +fn launch_binary_u64(a: &[u64], b: &[u64], pick: F) -> Result> +where + F: for<'a> Fn(&'a Backend) -> &'a cudarc::driver::CudaFunction, +{ + assert_eq!(a.len(), b.len(), "length mismatch"); + let n = a.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let a_dev = stream.clone_htod(a)?; + let b_dev = stream.clone_htod(b)?; + let mut c_dev = stream.alloc_zeros::(n)?; + + let cfg = LaunchConfig::for_num_elems(n as u32); + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(pick(be)) + .arg(&a_dev) + .arg(&b_dev) + .arg(&mut c_dev) + .arg(&n_u64) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&c_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/logup.rs b/crypto/math-cuda/src/logup.rs new file mode 100644 index 000000000..0e3fce8a1 --- /dev/null +++ b/crypto/math-cuda/src/logup.rs @@ -0,0 +1,637 @@ +//! LogUp aux-trace-build term-column compute on device. +//! +//! For one interaction pair (a, b): +//! 1. logup_pair_fingerprint — reads main trace columns from host buffer +//! (H2D once per call), interprets a bytecode per pair that encodes +//! the BusValue/Packing/LinearTerm structure, emits 2n ext3 fingerprints. +//! 2. batch_inverse_ext3_dev — reuses the existing parallel Montgomery +//! scan on the fingerprint buffer in place. +//! 3. logup_pair_term_assembly — reads inverted fingerprints + evaluates +//! Multiplicity descriptors (from bytecode) to emit n ext3 term values. +//! +//! The bytecode format is shared between the CPU-side serializer (in +//! crypto/stark/src/lookup.rs) and the CUDA kernels (in +//! crypto/math-cuda/kernels/logup.cu). Keep them in lock-step. + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +// Op kinds — mirror the CUDA #defines in logup.cu +pub const OP_PACK_DIRECT: u8 = 0; +pub const OP_PACK_WORD2L: u8 = 1; +pub const OP_PACK_WORD4L: u8 = 2; +pub const OP_PACK_DWORDWL: u8 = 3; +pub const OP_PACK_DWORDHHW: u8 = 4; +pub const OP_PACK_DWORDWHH: u8 = 5; +pub const OP_PACK_DWORDHL: u8 = 6; +pub const OP_PACK_DWORDBL: u8 = 7; +pub const OP_PACK_QUADHL: u8 = 8; +pub const OP_PACK_QUADWL: u8 = 9; +pub const OP_LINEAR: u8 = 10; + +pub const MULT_ONE: u8 = 0; +pub const MULT_COLUMN: u8 = 1; +pub const MULT_SUM: u8 = 2; +pub const MULT_NEGATED: u8 = 3; +pub const MULT_DIFF: u8 = 4; +pub const MULT_SUM3: u8 = 5; +pub const MULT_LINEAR: u8 = 6; + +/// 32-byte packed op — `#[repr(C)]` must match `FingerprintOp` in logup.cu. +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct FingerprintOp { + pub kind: u8, + pub pad0: [u8; 3], + pub alpha_offset: u32, + pub start_col: u32, + pub num_linear_terms: u32, + pub linear_term_offset: u32, + pub pad1: [u32; 2], +} + +/// 16-byte linear term. `value` is a **canonical** Goldilocks field element +/// in `[0, p)` — the serializer handles the conversion from signed i64 or +/// unsigned u64 (including large values that exceed i64::MAX). +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct LinearTerm { + pub kind: u8, // 0 = Column, 2 = Constant + pub pad: [u8; 3], + pub column: u32, + pub value: u64, +} + +pub const LT_KIND_COLUMN: u8 = 0; +pub const LT_KIND_CONSTANT: u8 = 2; + +/// 24-byte multiplicity descriptor. +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct MultiplicityDesc { + pub kind: u8, + pub pad: [u8; 3], + pub cols: [u32; 3], + pub num_linear_terms: u32, + pub linear_term_offset: u32, +} + +impl Default for MultiplicityDesc { + fn default() -> Self { + Self { + kind: MULT_ONE, + pad: [0; 3], + cols: [0; 3], + num_linear_terms: 0, + linear_term_offset: 0, + } + } +} + +/// Device-resident main columns — hold this once per aux-build so every +/// pair reuses the same H2D copy instead of re-uploading the ~240 MB +/// main trace for every interaction pair. +pub struct DeviceMainCols { + pub dev: CudaSlice, + pub num_cols: usize, + pub n: usize, +} + +/// Upload the column-major main trace (`num_main_cols * n` u64s) to the +/// device once. Pair kernels then reference it via `&DeviceMainCols`. +pub fn upload_main_cols(main_cols_host: &[u64], num_main_cols: usize, n: usize) + -> Result +{ + assert_eq!(main_cols_host.len(), num_main_cols * n); + let be = backend(); + let stream = be.next_stream(); + let dev = stream.clone_htod(main_cols_host)?; + stream.synchronize()?; + Ok(DeviceMainCols { + dev, + num_cols: num_main_cols, + n, + }) +} + +/// Variant of `logup_pair_term_column` that reuses a pre-uploaded +/// `DeviceMainCols`. This is the fast path for aux-build, where 30+ +/// pairs all share the same main trace. +#[allow(clippy::too_many_arguments)] +pub fn logup_pair_term_column_on_device( + main: &DeviceMainCols, + bus_id_a: u64, + bus_id_b: u64, + ops_a: &[FingerprintOp], + ops_b: &[FingerprintOp], + linear_terms: &[LinearTerm], + alpha_powers: &[u64], + z: &[u64; 3], + mult_a: &MultiplicityDesc, + mult_b: &MultiplicityDesc, + negate_a: bool, + negate_b: bool, +) -> Result> { + let n = main.n; + let be = backend(); + let stream = be.next_stream(); + + let ops_a_dev: CudaSlice = upload_ops(&stream, ops_a)?; + let ops_b_dev: CudaSlice = upload_ops(&stream, ops_b)?; + let lt_dev: CudaSlice = upload_linear_terms(&stream, linear_terms)?; + let mult_a_dev: CudaSlice = upload_mult(&stream, mult_a)?; + let mult_b_dev: CudaSlice = upload_mult(&stream, mult_b)?; + let alpha_dev = stream.clone_htod(alpha_powers)?; + let z_dev = stream.clone_htod(z)?; + + let mut fp_dev = stream.alloc_zeros::(2 * n * 3)?; + + let col_stride = n as u64; + let n_u64 = n as u64; + let ops_a_count = ops_a.len() as u32; + let ops_b_count = ops_b.len() as u32; + + let cfg = LaunchConfig { + grid_dim: (((n as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.logup_pair_fingerprint) + .arg(&main.dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(&bus_id_a) + .arg(&bus_id_b) + .arg(&ops_a_dev) + .arg(&ops_a_count) + .arg(&ops_b_dev) + .arg(&ops_b_count) + .arg(<_dev) + .arg(&alpha_dev) + .arg(&z_dev) + .arg(&mut fp_dev) + .launch(cfg)?; + } + + let inv_fp_dev = run_batch_inverse_on_device(&stream, &fp_dev, 2 * n)?; + + let mut term_dev = stream.alloc_zeros::(n * 3)?; + let neg_a: u8 = negate_a as u8; + let neg_b: u8 = negate_b as u8; + unsafe { + stream + .launch_builder(&be.logup_pair_term_assembly) + .arg(&inv_fp_dev) + .arg(&main.dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(<_dev) + .arg(&mult_a_dev) + .arg(&mult_b_dev) + .arg(&neg_a) + .arg(&neg_b) + .arg(&mut term_dev) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&term_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Single-interaction variant using a shared `DeviceMainCols`. +#[allow(clippy::too_many_arguments)] +pub fn logup_single_term_column_on_device( + main: &DeviceMainCols, + bus_id: u64, + ops: &[FingerprintOp], + linear_terms: &[LinearTerm], + alpha_powers: &[u64], + z: &[u64; 3], + mult: &MultiplicityDesc, + negate: bool, +) -> Result> { + let n = main.n; + let be = backend(); + let stream = be.next_stream(); + + let ops_dev: CudaSlice = upload_ops(&stream, ops)?; + let lt_dev: CudaSlice = upload_linear_terms(&stream, linear_terms)?; + let mult_dev: CudaSlice = upload_mult(&stream, mult)?; + let alpha_dev = stream.clone_htod(alpha_powers)?; + let z_dev = stream.clone_htod(z)?; + + let mut fp_dev = stream.alloc_zeros::(n * 3)?; + + let col_stride = n as u64; + let n_u64 = n as u64; + let ops_count = ops.len() as u32; + + let cfg = LaunchConfig { + grid_dim: (((n as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.logup_single_fingerprint) + .arg(&main.dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(&bus_id) + .arg(&ops_dev) + .arg(&ops_count) + .arg(<_dev) + .arg(&alpha_dev) + .arg(&z_dev) + .arg(&mut fp_dev) + .launch(cfg)?; + } + + let inv_fp_dev = run_batch_inverse_on_device(&stream, &fp_dev, n)?; + + let mut term_dev = stream.alloc_zeros::(n * 3)?; + let neg: u8 = negate as u8; + unsafe { + stream + .launch_builder(&be.logup_single_term_assembly) + .arg(&inv_fp_dev) + .arg(&main.dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(<_dev) + .arg(&mult_dev) + .arg(&neg) + .arg(&mut term_dev) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&term_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Run the fingerprint + batch-inverse + term-assembly pipeline for ONE +/// interaction pair. Produces an `Vec` of size `3 * n` (ext3 +/// interleaved) representing the term column. +/// +/// `main_cols_host`: column-major, `num_main_cols * n` u64s. +/// `ops_a / ops_b`: serialised FingerprintOp slices for each side. +/// `linear_terms`: shared pool indexed by op.linear_term_offset + +/// multiplicity.linear_term_offset. +/// `alpha_powers`: `3 * max_bus_elements` u64 (ext3 interleaved). +/// `z`: 3 u64. +#[allow(clippy::too_many_arguments)] +pub fn logup_pair_term_column( + main_cols_host: &[u64], + num_main_cols: usize, + n: usize, + bus_id_a: u64, + bus_id_b: u64, + ops_a: &[FingerprintOp], + ops_b: &[FingerprintOp], + linear_terms: &[LinearTerm], + alpha_powers: &[u64], + z: &[u64; 3], + mult_a: &MultiplicityDesc, + mult_b: &MultiplicityDesc, + negate_a: bool, + negate_b: bool, +) -> Result> { + assert_eq!(main_cols_host.len(), num_main_cols * n); + + let be = backend(); + let stream = be.next_stream(); + + // H2D main cols + bytecode. + let main_dev = stream.clone_htod(main_cols_host)?; + let ops_a_dev: CudaSlice = upload_ops(&stream, ops_a)?; + let ops_b_dev: CudaSlice = upload_ops(&stream, ops_b)?; + let lt_dev: CudaSlice = upload_linear_terms(&stream, linear_terms)?; + let mult_a_dev: CudaSlice = upload_mult(&stream, mult_a)?; + let mult_b_dev: CudaSlice = upload_mult(&stream, mult_b)?; + let alpha_dev = stream.clone_htod(alpha_powers)?; + let z_dev = stream.clone_htod(z)?; + + // Fingerprint buffer: 2n ext3. + let mut fp_dev = stream.alloc_zeros::(2 * n * 3)?; + + let col_stride = n as u64; + let n_u64 = n as u64; + let bus_a = bus_id_a; + let bus_b = bus_id_b; + let ops_a_count = ops_a.len() as u32; + let ops_b_count = ops_b.len() as u32; + + let cfg = LaunchConfig { + grid_dim: (((n as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.logup_pair_fingerprint) + .arg(&main_dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(&bus_a) + .arg(&bus_b) + .arg(&ops_a_dev) + .arg(&ops_a_count) + .arg(&ops_b_dev) + .arg(&ops_b_count) + .arg(<_dev) + .arg(&alpha_dev) + .arg(&z_dev) + .arg(&mut fp_dev) + .launch(cfg)?; + } + + // Batch-invert the 2n fingerprints in place using our parallel scan. + // The existing `batch_inverse_ext3_dev` expects a &CudaSlice and + // returns a new Vec (host). For the fused flow we want to keep + // the inverted fingerprints on device; reuse the lower-level ops. + // Simplest: run it host-side (it D2H'd and we'd H2D back — wasteful). + // + // Better: replicate the scan-phase launches here, writing back to + // `fp_dev`. Avoids the round-trip entirely. + let inv_fp_dev = run_batch_inverse_on_device(&stream, &fp_dev, 2 * n)?; + + // Term assembly. + let mut term_dev = stream.alloc_zeros::(n * 3)?; + let neg_a: u8 = negate_a as u8; + let neg_b: u8 = negate_b as u8; + unsafe { + stream + .launch_builder(&be.logup_pair_term_assembly) + .arg(&inv_fp_dev) + .arg(&main_dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(<_dev) + .arg(&mult_a_dev) + .arg(&mult_b_dev) + .arg(&neg_a) + .arg(&neg_b) + .arg(&mut term_dev) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&term_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Single-interaction variant (used for the absorbed odd interaction). +#[allow(clippy::too_many_arguments)] +pub fn logup_single_term_column( + main_cols_host: &[u64], + num_main_cols: usize, + n: usize, + bus_id: u64, + ops: &[FingerprintOp], + linear_terms: &[LinearTerm], + alpha_powers: &[u64], + z: &[u64; 3], + mult: &MultiplicityDesc, + negate: bool, +) -> Result> { + assert_eq!(main_cols_host.len(), num_main_cols * n); + + let be = backend(); + let stream = be.next_stream(); + + let main_dev = stream.clone_htod(main_cols_host)?; + let ops_dev: CudaSlice = upload_ops(&stream, ops)?; + let lt_dev: CudaSlice = upload_linear_terms(&stream, linear_terms)?; + let mult_dev: CudaSlice = upload_mult(&stream, mult)?; + let alpha_dev = stream.clone_htod(alpha_powers)?; + let z_dev = stream.clone_htod(z)?; + + let mut fp_dev = stream.alloc_zeros::(n * 3)?; + + let col_stride = n as u64; + let n_u64 = n as u64; + let ops_count = ops.len() as u32; + + let cfg = LaunchConfig { + grid_dim: (((n as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.logup_single_fingerprint) + .arg(&main_dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(&bus_id) + .arg(&ops_dev) + .arg(&ops_count) + .arg(<_dev) + .arg(&alpha_dev) + .arg(&z_dev) + .arg(&mut fp_dev) + .launch(cfg)?; + } + + let inv_fp_dev = run_batch_inverse_on_device(&stream, &fp_dev, n)?; + + let mut term_dev = stream.alloc_zeros::(n * 3)?; + let neg: u8 = negate as u8; + unsafe { + stream + .launch_builder(&be.logup_single_term_assembly) + .arg(&inv_fp_dev) + .arg(&main_dev) + .arg(&col_stride) + .arg(&n_u64) + .arg(<_dev) + .arg(&mult_dev) + .arg(&neg) + .arg(&mut term_dev) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&term_dev)?; + stream.synchronize()?; + Ok(out) +} + +// ============================================================================= +// Internals: upload helpers + re-runnable batch inverse on device +// ============================================================================= + +fn upload_ops( + stream: &std::sync::Arc, + ops: &[FingerprintOp], +) -> Result> { + let bytes = unsafe { + core::slice::from_raw_parts( + ops.as_ptr() as *const u8, + ops.len() * core::mem::size_of::(), + ) + }; + if bytes.is_empty() { + // cudarc disallows zero-length allocs; use a 1-byte dummy. + let dummy = [0u8; 1]; + return Ok(stream.clone_htod(&dummy)?); + } + Ok(stream.clone_htod(bytes)?) +} + +fn upload_linear_terms( + stream: &std::sync::Arc, + terms: &[LinearTerm], +) -> Result> { + let bytes = unsafe { + core::slice::from_raw_parts( + terms.as_ptr() as *const u8, + terms.len() * core::mem::size_of::(), + ) + }; + if bytes.is_empty() { + let dummy = [0u8; 1]; + return Ok(stream.clone_htod(&dummy)?); + } + Ok(stream.clone_htod(bytes)?) +} + +fn upload_mult( + stream: &std::sync::Arc, + m: &MultiplicityDesc, +) -> Result> { + let bytes = unsafe { + core::slice::from_raw_parts( + m as *const MultiplicityDesc as *const u8, + core::mem::size_of::(), + ) + }; + Ok(stream.clone_htod(bytes)?) +} + +/// Inline version of parallel Montgomery batch inverse that runs entirely +/// on device without D2H'ing the scan result. Mirrors the logic in +/// `crate::inverse::batch_inverse_ext3_dev` but is duplicated here to keep +/// the fingerprint buffer on the same stream. +fn run_batch_inverse_on_device( + stream: &std::sync::Arc, + a_dev: &CudaSlice, + n: usize, +) -> Result> { + let be = backend(); + let mut prefix_dev = stream.alloc_zeros::(n * 3)?; + let mut suffix_dev = stream.alloc_zeros::(n * 3)?; + + let k: u32 = 256; + let c_per_thread: u64 = ((n as u64) + (k as u64) - 1) / (k as u64); + let mut chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + let n_u64 = n as u64; + let k_u64 = k as u64; + + let cfg_scan = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (k, 1, 1), + shared_mem_bytes: 0, + }; + + unsafe { + stream + .launch_builder(&be.chunk_prefix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut prefix_dev) + .arg(&mut chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_scan_of_totals_ext3) + .arg(&chunk_totals) + .arg(&k_u64) + .arg(&mut chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_scan_offsets_ext3) + .arg(&mut prefix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&chunk_offsets) + .launch(cfg_scan)?; + } + + let mut suf_ct = stream.alloc_zeros::((k as usize) * 3)?; + let mut suf_off = stream.alloc_zeros::((k as usize) * 3)?; + unsafe { + stream + .launch_builder(&be.chunk_suffix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut suffix_dev) + .arg(&mut suf_ct) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_reverse_scan_of_totals_ext3) + .arg(&suf_ct) + .arg(&k_u64) + .arg(&mut suf_off) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_reverse_scan_offsets_ext3) + .arg(&mut suffix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&suf_off) + .launch(cfg_scan)?; + } + + // D2H last prefix element, invert on host, H2D inv_total. + let total = { + let last_view = prefix_dev.slice((n - 1) * 3..n * 3); + let last_host: Vec = stream.clone_dtoh(&last_view)?; + stream.synchronize()?; + crate::inverse::invert_ext3_host_pub([last_host[0], last_host[1], last_host[2]]) + }; + let mut inv_total_dev = stream.alloc_zeros::(3)?; + stream.memcpy_htod(&total, &mut inv_total_dev)?; + + let mut out_dev = stream.alloc_zeros::(n * 3)?; + let cfg_combine = LaunchConfig { + grid_dim: (((n as u32) + 255) / 256, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix_dev) + .arg(&suffix_dev) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg_combine)?; + } + + Ok(out_dev) +} diff --git a/crypto/math-cuda/src/merkle.rs b/crypto/math-cuda/src/merkle.rs new file mode 100644 index 000000000..18c2e14d1 --- /dev/null +++ b/crypto/math-cuda/src/merkle.rs @@ -0,0 +1,414 @@ +//! GPU Keccak-256 leaf hashing for Merkle commits. +//! +//! Matches `FieldElementVectorBackend::hash_data` in +//! `crypto/crypto/src/merkle_tree/backends/field_element_vector.rs`, combined +//! with the `reverse_index` row read pattern used in +//! `commit_columns_bit_reversed` at `crypto/stark/src/prover.rs:368`. +//! +//! Caller supplies base-field column slabs already laid out as +//! `[col * col_stride + row]` (the same layout `coset_lde_batch_base_into` +//! writes to the pinned staging buffer). The kernel bit-reverses `row_idx`, +//! reads each column's canonical u64 at that row, byte-swaps it into a +//! Keccak lane, absorbs lane-by-lane, and squeezes 32 bytes per leaf. +//! +//! For ext3 columns the layout is `[col*3*col_stride + k*col_stride + row]` +//! — three base slabs per ext3 column — and the kernel reads three u64s per +//! column in component order 0,1,2 to match `FieldElement::::write_bytes_be`. + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +/// Run GPU Keccak-256 leaf hashing on a base-field column buffer. +/// +/// `columns` must hold `num_cols * col_stride` u64s with column `c`'s data +/// at `[c*col_stride .. c*col_stride + num_rows]`. Returns `num_rows * 32` +/// hash bytes in natural (non-bit-reversed) row order. +pub fn keccak_leaves_base( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_base( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Ext3 variant — columns interleaved as three base slabs per ext3 column. +/// `columns.len() >= num_cols * 3 * col_stride`. +pub fn keccak_leaves_ext3( + columns: &[u64], + col_stride: usize, + num_cols: usize, + num_rows: usize, +) -> Result> { + assert!(num_rows.is_power_of_two()); + assert!(columns.len() >= num_cols * 3 * col_stride); + let be = backend(); + let stream = be.next_stream(); + let cols_dev = stream.clone_htod(&columns[..num_cols * 3 * col_stride])?; + let mut out_dev = stream.alloc_zeros::(num_rows * 32)?; + launch_keccak_ext3( + stream.as_ref(), + &cols_dev, + col_stride as u64, + num_cols as u64, + num_rows as u64, + &mut out_dev, + )?; + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Block size for Keccak kernels. Per-thread register footprint is ~60 regs +/// (25-lane state + auxiliaries); the default 256 threads/block pushes the +/// block register file past the hardware limit on sm_120 (Blackwell). 128 +/// keeps us inside the budget with some head-room. +const KECCAK_BLOCK_DIM: u32 = 128; + +fn keccak_launch_cfg(num_rows: u64) -> LaunchConfig { + let grid = ((num_rows as u32) + KECCAK_BLOCK_DIM - 1) / KECCAK_BLOCK_DIM; + LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (KECCAK_BLOCK_DIM, 1, 1), + shared_mem_bytes: 0, + } +} + +pub(crate) fn launch_keccak_base( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} + +/// Given `hashed_leaves` of length `leaves_len * 32`, build the full Merkle +/// tree on device and return the complete node buffer `(2*leaves_len - 1) * +/// 32` bytes in the standard layout: +/// +/// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0), and +/// `nodes[leaves_len - 1..]` are the leaves themselves. +/// +/// Matches the CPU `crypto/crypto/src/merkle_tree/merkle.rs` construction so +/// the resulting `nodes` Vec plugs straight into `MerkleTree { root, nodes }` +/// for downstream proof generation. +/// +/// `leaves_len` must be a power of two and ≥ 2. +pub fn build_merkle_tree_on_device(hashed_leaves: &[u8]) -> Result> { + assert!(hashed_leaves.len() % 32 == 0); + let leaves_len = hashed_leaves.len() / 32; + assert!(leaves_len >= 2, "tree needs at least two leaves"); + assert!( + leaves_len.is_power_of_two(), + "leaves_len must be a power of two" + ); + + let total_nodes = 2 * leaves_len - 1; + let be = backend(); + let stream = be.next_stream(); + + // Allocate the full node buffer without zero-fill — we overwrite the + // leaf half via H2D immediately, and every inner node is written by the + // pair-hash kernel below. + // SAFETY: every byte is written before it is read: leaves are filled by + // the H2D below; inner nodes are filled by the level loop that follows. + let mut nodes_dev = unsafe { stream.alloc::(total_nodes * 32) }?; + let leaves_offset_bytes = (leaves_len - 1) * 32; + // SAFETY: target slice `nodes_dev[leaves_offset_bytes..]` has exactly + // `leaves_len * 32 == hashed_leaves.len()` bytes capacity. + { + let mut slice = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + hashed_leaves.len()); + stream.memcpy_htod(hashed_leaves, &mut slice)?; + } + + // Build level by level. The CPU `build(nodes, leaves_len)` starts with + // level_begin_index = leaves_len - 1 + // level_end_index = 2 * level_begin_index + // and each iteration computes: + // new_level_begin_index = level_begin_index / 2 + // new_level_length = level_begin_index - new_level_begin_index + // The parents occupy [new_level_begin_index, level_begin_index); the + // children occupy [level_begin_index, level_end_index + 1). + let mut level_begin: u64 = (leaves_len - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + + let cfg = keccak_launch_cfg(n_pairs); + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Row-pair Keccak leaf + Merkle tree build for R2 composition-polynomial +/// commit. `parts_interleaved` is `num_parts` slices, each holding an ext3 +/// LDE column interleaved as `[a0,a1,a2, b0,b1,b2, ...]` of length `3*lde_size`. +/// +/// Returns `(2*(lde_size/2) - 1) * 32` bytes of tree nodes in the standard +/// layout (root at byte offset 0, leaves in the tail). +pub fn build_comp_poly_tree_from_evals_ext3( + parts_interleaved: &[&[u64]], +) -> Result> { + assert!(!parts_interleaved.is_empty()); + let m = parts_interleaved.len(); + let ext3_elems = parts_interleaved[0].len() / 3; + assert_eq!( + parts_interleaved[0].len(), + 3 * ext3_elems, + "ext3 buffer length must be 3 * lde_size" + ); + for p in parts_interleaved.iter() { + assert_eq!(p.len(), 3 * ext3_elems); + } + let lde_size = ext3_elems; + assert!(lde_size.is_power_of_two() && lde_size >= 2); + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + + let be = backend(); + let stream = be.next_stream(); + let staging_slot = be.pinned_staging(); + + // Stage: de-interleave each part into 3 base slabs in pinned memory. + let mb = 3 * m; + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(mb * lde_size, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(mb * lde_size) }; + + use rayon::prelude::*; + let pinned_ptr_u = pinned.as_mut_ptr() as usize; + parts_interleaved + .par_iter() + .enumerate() + .for_each(|(c, col)| { + let slab_a = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3) * lde_size), + lde_size, + ) + }; + let slab_b = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 1) * lde_size), + lde_size, + ) + }; + let slab_c = unsafe { + std::slice::from_raw_parts_mut( + (pinned_ptr_u as *mut u64).add((c * 3 + 2) * lde_size), + lde_size, + ) + }; + for i in 0..lde_size { + slab_a[i] = col[i * 3]; + slab_b[i] = col[i * 3 + 1]; + slab_c[i] = col[i * 3 + 2]; + } + }); + + // H2D the de-interleaved parts. + let mut buf = stream.alloc_zeros::(mb * lde_size)?; + stream.memcpy_htod(&pinned[..mb * lde_size], &mut buf)?; + + // Leaves into tail of a tight node buffer. + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let col_stride_u64 = lde_size as u64; + let num_parts_u64 = m as u64; + let num_rows_u64 = lde_size as u64; + let log_num_rows = lde_size.trailing_zeros() as u64; + let grid = ((num_leaves as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_comp_poly_leaves_ext3) + .arg(&buf) + .arg(&col_stride_u64) + .arg(&num_parts_u64) + .arg(&num_rows_u64) + .arg(&log_num_rows) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + drop(staging); + Ok(out) +} + +/// Build a FRI-layer Merkle tree on device from an interleaved ext3 eval +/// vector. Each leaf hashes two consecutive ext3 values; `num_leaves = +/// evals.len() / 6` (since each ext3 is 3 u64s). +/// +/// Returns the `(2*num_leaves - 1) * 32`-byte node buffer in standard layout. +pub fn build_fri_layer_tree_from_evals_ext3(evals: &[u64]) -> Result> { + assert!(evals.len() % 6 == 0, "evals must hold whole pair-leaves"); + let num_evals = evals.len() / 3; + let num_leaves = num_evals / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let tight_total_nodes = 2 * num_leaves - 1; + if tight_total_nodes == 0 { + return Ok(Vec::new()); + } + + let be = backend(); + let stream = be.next_stream(); + + let evals_dev = stream.clone_htod(evals)?; + let mut nodes_dev = unsafe { stream.alloc::(tight_total_nodes * 32) }?; + + // Leaf kernel: num_leaves threads, one leaf each. + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let grid = ((num_leaves as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&evals_dev) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(cfg)?; + } + } + + // Inner tree levels — identical to the R2 version. + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = ((n_pairs as u32) + 128 - 1) / 128; + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + let out = stream.clone_dtoh(&nodes_dev)?; + stream.synchronize()?; + Ok(out) +} + +pub(crate) fn launch_keccak_ext3( + stream: &CudaStream, + cols_dev: &CudaSlice, + col_stride: u64, + num_cols: u64, + num_rows: u64, + out_dev: &mut CudaSlice, +) -> Result<()> { + let be = backend(); + let log_num_rows = num_rows.trailing_zeros() as u64; + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_ext3_batched) + .arg(cols_dev) + .arg(&col_stride) + .arg(&num_cols) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(out_dev) + .launch(cfg)?; + } + Ok(()) +} diff --git a/crypto/math-cuda/src/ntt.rs b/crypto/math-cuda/src/ntt.rs new file mode 100644 index 000000000..0ebb015ea --- /dev/null +++ b/crypto/math-cuda/src/ntt.rs @@ -0,0 +1,211 @@ +//! Forward and inverse NTT over Goldilocks base field. Matches the algebraic +//! contract of `math::polynomial::Polynomial::evaluate_fft` / +//! `interpolate_fft`: +//! input = n elements in natural order +//! output = n elements in natural order. +//! +//! Parity is checked by `tests/ntt.rs` against the CPU implementation. + +use cudarc::driver::{LaunchConfig, PushKernelArg}; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsField}; + +use crate::Result; +use crate::device::backend; + +/// Host-side twiddle table: `[ω^0, ω^1, …, ω^{n/2-1}]` where ω is the +/// primitive n-th root of unity. Exposed for `device::Backend::cached_twiddles` +/// and for direct use in tests / benches. +pub fn twiddles_forward(log_n: u64) -> Vec { + let omega = *GoldilocksField::get_primitive_root_of_unity(log_n) + .expect("primitive root") + .value(); + powers_of(omega, 1usize << (log_n - 1)) +} + +/// Inverse twiddle table: `[ω^{-i}]` for i in [0, n/2). +pub fn twiddles_inverse(log_n: u64) -> Vec { + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).expect("primitive root"); + let omega_inv = FieldElement::::inv(&omega).expect("inverse"); + powers_of(*omega_inv.value(), 1usize << (log_n - 1)) +} + +fn powers_of(base: u64, count: usize) -> Vec { + let mut out = Vec::with_capacity(count); + let mut w = 1u64; + for _ in 0..count { + out.push(w); + w = GoldilocksField::mul(&w, &base); + } + out +} + +/// Forward NTT on a slice of `n = 2^log_n` Goldilocks coefficients. Takes +/// natural-order input and returns natural-order evaluations. +pub fn forward(coeffs: &[u64]) -> Result> { + ntt_inplace(coeffs, /*forward=*/ true) +} + +/// Inverse NTT on a slice of `n = 2^log_n` Goldilocks evaluations. Takes +/// natural-order evaluations and returns natural-order coefficients. Includes +/// the 1/n scaling. +pub fn inverse(evals: &[u64]) -> Result> { + ntt_inplace(evals, /*forward=*/ false) +} + +fn ntt_inplace(input: &[u64], forward: bool) -> Result> { + let n = input.len(); + assert!(n.is_power_of_two(), "ntt length must be a power of two"); + if n <= 1 { + return Ok(input.to_vec()); + } + let log_n = n.trailing_zeros() as u64; + + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(input)?; + let tw_dev = if forward { + be.fwd_twiddles_for(log_n)? + } else { + be.inv_twiddles_for(log_n)? + }; + + let n_u64 = n as u64; + + // 1. Bit-reverse: natural → bit-reversed. + unsafe { + stream + .launch_builder(&be.bit_reverse_permute) + .arg(&mut x_dev) + .arg(&n_u64) + .arg(&log_n) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + // 2. DIT butterfly levels. For log_n >= 8 we fuse 8 levels per kernel via + // the shmem kernel; for very small sizes (< 256 elements) we stick with + // the per-level kernel because the shmem block dimensions assume n ≥ 256. + run_ntt_body( + stream.as_ref(), + &mut x_dev, + tw_dev.as_ref(), + n_u64, + log_n, + )?; + + // 3. For iNTT, multiply by 1/n. + if !forward { + let n_fe = FieldElement::::from(n as u64); + let inv_n = *n_fe.inv().expect("n is non-zero").value(); + unsafe { + stream + .launch_builder(&be.scalar_mul) + .arg(&mut x_dev) + .arg(&inv_n) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Run the butterfly body of a bit-reversed-input DIT NTT. Split out so the +/// LDE orchestrator can reuse it on the same device buffer. +pub(crate) fn run_ntt_body( + stream: &cudarc::driver::CudaStream, + x_dev: &mut cudarc::driver::CudaSlice, + tw_dev: &cudarc::driver::CudaSlice, + n: u64, + log_n: u64, +) -> Result<()> { + let be = backend(); + // Levels 0..min(log_n, 8): one shmem-fused launch. Loads are fully + // coalesced (base_step=0 → `row = tid`) and 8 butterfly rounds stay on + // chip. This is the big DRAM-bandwidth win. + let fused = core::cmp::min(log_n, 8); + if fused >= 8 { + let grid_x = (n / 256) as u32; + let cfg = LaunchConfig { + grid_dim: (grid_x, 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + let base_step = 0u64; + unsafe { + stream + .launch_builder(&be.ntt_dit_8_levels) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&base_step) + .launch(cfg)?; + } + } else { + // Sub-256-element NTT. Use per-level. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in 0..fused { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + } + + // Levels 8..log_n: per-level kernels. Loads are fully coalesced in the + // per-level path; switching to fused-with-row-remap at base_step>0 tanks + // DRAM throughput enough to wipe out the launch savings. + let half_cfg = LaunchConfig::for_num_elems((n / 2) as u32); + for level in fused..log_n { + unsafe { + stream + .launch_builder(&be.ntt_dit_level) + .arg(&mut *x_dev) + .arg(tw_dev) + .arg(&n) + .arg(&log_n) + .arg(&level) + .launch(half_cfg)?; + } + } + Ok(()) +} + +/// Pointwise multiply: `x[i] *= w[i]`. +pub fn pointwise_mul(x: &[u64], w: &[u64]) -> Result> { + assert_eq!(x.len(), w.len()); + let n = x.len(); + if n == 0 { + return Ok(Vec::new()); + } + let be = backend(); + let stream = be.next_stream(); + + let mut x_dev = stream.clone_htod(x)?; + let w_dev = stream.clone_htod(w)?; + + let n_u64 = n as u64; + unsafe { + stream + .launch_builder(&be.pointwise_mul) + .arg(&mut x_dev) + .arg(&w_dev) + .arg(&n_u64) + .launch(LaunchConfig::for_num_elems(n as u32))?; + } + + let out = stream.clone_dtoh(&x_dev)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/tests/barycentric.rs b/crypto/math-cuda/tests/barycentric.rs new file mode 100644 index 000000000..dcb47327a --- /dev/null +++ b/crypto/math-cuda/tests/barycentric.rs @@ -0,0 +1,145 @@ +//! Parity: GPU barycentric sum vs CPU. We don't call the upstream +//! `interpolate_coset_eval_*_with_g_n_inv` directly because the GPU kernel +//! returns only the unscaled sum — the caller applies the ext3 scale. We +//! replicate the same unscaled sum on CPU for comparison. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn canon_triplet(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn canon_triplet_raw(t: &[u64]) -> [u64; 3] { + [ + GoldilocksField::canonical(&t[0]), + GoldilocksField::canonical(&t[1]), + GoldilocksField::canonical(&t[2]), + ] +} + +fn random_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} +fn random_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([random_fp(rng), random_fp(rng), random_fp(rng)]) +} + +#[test] +fn barycentric_base_sum_matches_cpu() { + for &(log_n, num_cols) in &[(4u32, 1usize), (8, 5), (10, 17), (12, 3)] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(100 + log_n as u64 * 7 + num_cols as u64); + + let coset_points: Vec = (0..n).map(|_| random_fp(&mut rng)).collect(); + let inv_denoms: Vec = (0..n).map(|_| random_fp3(&mut rng)).collect(); + + // Lay out columns base: column c contiguous slab of n u64s. + let cols_fp: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| random_fp(&mut rng)).collect()) + .collect(); + let mut columns_flat = vec![0u64; num_cols * n]; + for (c, col) in cols_fp.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + columns_flat[c * n + r] = *e.value(); + } + } + let points_raw: Vec = coset_points.iter().map(|e| *e.value()).collect(); + let inv_denoms_raw: Vec = inv_denoms + .iter() + .flat_map(|e| [*e.value()[0].value(), *e.value()[1].value(), *e.value()[2].value()]) + .collect(); + + let gpu = math_cuda::barycentric::barycentric_base( + &columns_flat, + n, + &points_raw, + &inv_denoms_raw, + n, + num_cols, + ) + .unwrap(); + + for (c, col) in cols_fp.iter().enumerate() { + // CPU reference sum. Force ext3 by embedding the base product. + let mut sum = Fp3::zero(); + for i in 0..n { + let pe_base: Fp = &coset_points[i] * &col[i]; // F × F = F + // Base × ext3 = ext3 (base is on the left — IsSubFieldOf direction). + let pe_ext3: Fp3 = &pe_base * &inv_denoms[i]; // F × E = E + sum = &sum + &pe_ext3; + } + let gpu_sum = canon_triplet_raw(&gpu[c * 3..(c + 1) * 3]); + let cpu_sum = canon_triplet(&sum); + assert_eq!( + gpu_sum, cpu_sum, + "base col {c} log_n={log_n} num_cols={num_cols}" + ); + } + } +} + +#[test] +fn barycentric_ext3_sum_matches_cpu() { + for &(log_n, num_cols) in &[(4u32, 1usize), (8, 3), (10, 7)] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 * 11 + num_cols as u64); + + let coset_points: Vec = (0..n).map(|_| random_fp(&mut rng)).collect(); + let inv_denoms: Vec = (0..n).map(|_| random_fp3(&mut rng)).collect(); + let cols_fp3: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| random_fp3(&mut rng)).collect()) + .collect(); + + // De-interleaved layout: 3 base slabs per ext3 column. + let mut columns_flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in cols_fp3.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + columns_flat[(c * 3 + 0) * n + r] = *e.value()[0].value(); + columns_flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + columns_flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } + } + let points_raw: Vec = coset_points.iter().map(|e| *e.value()).collect(); + let inv_denoms_raw: Vec = inv_denoms + .iter() + .flat_map(|e| [*e.value()[0].value(), *e.value()[1].value(), *e.value()[2].value()]) + .collect(); + + let gpu = math_cuda::barycentric::barycentric_ext3( + &columns_flat, + n, + &points_raw, + &inv_denoms_raw, + n, + num_cols, + ) + .unwrap(); + + for (c, col) in cols_fp3.iter().enumerate() { + let mut sum = Fp3::zero(); + for i in 0..n { + let pe: Fp3 = &coset_points[i] * &col[i]; // F × E = E + let term: Fp3 = &pe * &inv_denoms[i]; // E × E = E + sum = &sum + &term; + } + let gpu_sum = canon_triplet_raw(&gpu[c * 3..(c + 1) * 3]); + let cpu_sum = canon_triplet(&sum); + assert_eq!( + gpu_sum, cpu_sum, + "ext3 col {c} log_n={log_n} num_cols={num_cols}" + ); + } + } +} diff --git a/crypto/math-cuda/tests/barycentric_strided.rs b/crypto/math-cuda/tests/barycentric_strided.rs new file mode 100644 index 000000000..7f9d0f910 --- /dev/null +++ b/crypto/math-cuda/tests/barycentric_strided.rs @@ -0,0 +1,152 @@ +//! Parity: strided barycentric kernels (used by R3 OOD on device LDE handles) +//! match the non-strided kernels fed a pre-strided column buffer. + +use std::sync::Arc; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn run_base(log_trace: u32, blowup: usize, num_cols: usize, seed: u64) { + let n = 1usize << log_trace; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let lde_data: Vec> = (0..num_cols) + .map(|_| (0..lde_size).map(|_| rand_fp(&mut rng)).collect()) + .collect(); + let coset_points: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let inv_denoms_ext3: Vec = (0..(n * 3)).map(|_| rng.r#gen::()).collect(); + + // Pack full LDE column-major for device. + let mut lde_flat = vec![0u64; num_cols * lde_size]; + for (c, col) in lde_data.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + lde_flat[c * lde_size + r] = *v.value(); + } + } + let be = math_cuda::device::backend(); + let stream = be.next_stream(); + let lde_dev = stream.clone_htod(&lde_flat).unwrap(); + stream.synchronize().unwrap(); + let handle = math_cuda::lde::GpuLdeBase { + buf: Arc::new(lde_dev), + m: num_cols, + lde_size, + }; + + // Pre-strided buffer for non-strided reference: trace-size picks of each col. + let mut pre_strided = vec![0u64; num_cols * n]; + for c in 0..num_cols { + for i in 0..n { + pre_strided[c * n + i] = lde_flat[c * lde_size + i * blowup]; + } + } + + let reference = math_cuda::barycentric::barycentric_base( + &pre_strided, + n, + &coset_points, + &inv_denoms_ext3, + n, + num_cols, + ) + .unwrap(); + + let strided = math_cuda::barycentric::barycentric_base_on_device( + &handle, + blowup, + &coset_points, + &inv_denoms_ext3, + n, + ) + .unwrap(); + + assert_eq!(reference, strided, "base strided mismatch (log_trace={log_trace}, blowup={blowup}, cols={num_cols})"); +} + +fn run_ext3(log_trace: u32, blowup: usize, num_cols: usize, seed: u64) { + let n = 1usize << log_trace; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let lde_data: Vec> = (0..num_cols) + .map(|_| (0..lde_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let coset_points: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let inv_denoms_ext3: Vec = (0..(n * 3)).map(|_| rng.r#gen::()).collect(); + + // Pack LDE de-interleaved: (m*3) × lde_size. + let mut lde_flat = vec![0u64; num_cols * 3 * lde_size]; + for (c, col) in lde_data.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + lde_flat[(c * 3) * lde_size + r] = *v.value()[0].value(); + lde_flat[(c * 3 + 1) * lde_size + r] = *v.value()[1].value(); + lde_flat[(c * 3 + 2) * lde_size + r] = *v.value()[2].value(); + } + } + let be = math_cuda::device::backend(); + let stream = be.next_stream(); + let lde_dev = stream.clone_htod(&lde_flat).unwrap(); + stream.synchronize().unwrap(); + let handle = math_cuda::lde::GpuLdeExt3 { + buf: Arc::new(lde_dev), + m: num_cols, + lde_size, + }; + + // Pre-strided buffer for non-strided reference. + let mut pre_strided = vec![0u64; num_cols * 3 * n]; + for c in 0..num_cols { + for i in 0..n { + pre_strided[(c * 3) * n + i] = lde_flat[(c * 3) * lde_size + i * blowup]; + pre_strided[(c * 3 + 1) * n + i] = lde_flat[(c * 3 + 1) * lde_size + i * blowup]; + pre_strided[(c * 3 + 2) * n + i] = lde_flat[(c * 3 + 2) * lde_size + i * blowup]; + } + } + let reference = math_cuda::barycentric::barycentric_ext3( + &pre_strided, + n, + &coset_points, + &inv_denoms_ext3, + n, + num_cols, + ) + .unwrap(); + + let strided = math_cuda::barycentric::barycentric_ext3_on_device( + &handle, + blowup, + &coset_points, + &inv_denoms_ext3, + n, + ) + .unwrap(); + + assert_eq!(reference, strided, "ext3 strided mismatch"); +} + +#[test] +fn bary_base_strided_small() { + for (log_t, blowup, cols) in [(4u32, 2usize, 3usize), (8, 4, 10), (12, 2, 5)] { + run_base(log_t, blowup, cols, 1000 + log_t as u64); + } +} + +#[test] +fn bary_ext3_strided_small() { + for (log_t, blowup, cols) in [(4u32, 2usize, 2usize), (8, 4, 5), (10, 2, 3)] { + run_ext3(log_t, blowup, cols, 2000 + log_t as u64); + } +} diff --git a/crypto/math-cuda/tests/batch_inverse.rs b/crypto/math-cuda/tests/batch_inverse.rs new file mode 100644 index 000000000..fe240762d --- /dev/null +++ b/crypto/math-cuda/tests/batch_inverse.rs @@ -0,0 +1,92 @@ +//! Parity: GPU parallel batch inverse matches CPU +//! `FieldElement::inplace_batch_inverse` on ext3 elements. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + loop { + let v = rng.r#gen::(); + // Avoid zero — batch inverse requires all non-zero. + if v != 0 { + return Fp::from_raw(v); + } + } +} +fn rand_fp3_nonzero(rng: &mut ChaCha8Rng) -> Fp3 { + // Random non-zero ext3: at least one component non-zero, all in [1, p). + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn canon3(a: &[u64]) -> Vec { + a.iter() + .enumerate() + .map(|(i, v)| { + // Each u64 is canonicalised independently (ext3 = 3 base coords). + let _ = i; + GoldilocksField::canonical(v) + }) + .collect() +} + +fn run(n: usize, seed: u64) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let xs: Vec = (0..n).map(|_| rand_fp3_nonzero(&mut rng)).collect(); + + // CPU reference: inplace_batch_inverse. + let mut cpu = xs.clone(); + FieldElement::inplace_batch_inverse(&mut cpu).expect("batch inverse non-zero"); + + // GPU. + let input_u64 = ext3_to_u64s(&xs); + let gpu_u64 = math_cuda::inverse::batch_inverse_ext3(&input_u64).unwrap(); + + let cpu_u64 = ext3_to_u64s(&cpu); + let gpu_canon = canon3(&gpu_u64); + let cpu_canon = canon3(&cpu_u64); + + for i in 0..n { + let g = &gpu_canon[i * 3..(i + 1) * 3]; + let c = &cpu_canon[i * 3..(i + 1) * 3]; + assert_eq!(g, c, "mismatch at i={i} n={n}"); + } +} + +#[test] +fn batch_inverse_small() { + for n in [2usize, 3, 5, 16, 63, 255, 256, 257] { + run(n, 100 + n as u64); + } +} + +#[test] +fn batch_inverse_medium() { + for n in [1024usize, 4096, 8192] { + run(n, 500 + n as u64); + } +} + +#[test] +fn batch_inverse_large() { + // Matches R3 OOD / R4 DEEP sizes for fib_1M (domain_size = 2^18, + // num_denoms_max = 2^18 × 4). + run(1 << 18, 999); + run(1 << 20, 12345); +} diff --git a/crypto/math-cuda/tests/bench_quick.rs b/crypto/math-cuda/tests/bench_quick.rs new file mode 100644 index 000000000..561331b74 --- /dev/null +++ b/crypto/math-cuda/tests/bench_quick.rs @@ -0,0 +1,356 @@ +//! Informal timing comparison for single-column and multi-column LDE. +//! Ignored by default; run with `cargo test ... -- --ignored --nocapture`. + +use std::time::Instant; + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rayon::prelude::*; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64).inv().unwrap().value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_18_blowup_4() { + let log_n = 18; + let blowup = 4; + let n = 1usize << log_n; + let lde = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(1); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + const TRIALS: u32 = 10; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let mut buf: Vec = input.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + std::hint::black_box(&buf); + } + let cpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "single-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_2_to_16_blowup_4() { + let log_n = 16; + let blowup = 4; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(2); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + let weights = coset_weights(n, 7); + + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + + const TRIALS: u32 = 20; + + let t0 = Instant::now(); + for _ in 0..TRIALS { + let _ = math_cuda::lde::coset_lde_base(&input, blowup, &weights).unwrap(); + } + let gpu_ns = t0.elapsed().as_nanos() / TRIALS as u128; + println!("single-column LDE 2^{log_n} blowup={blowup}: gpu={gpu_ns}ns"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_parallel() { + // Simulates the prover's Phase A: many columns processed via rayon. + // log_n = 16 keeps memory footprint manageable while still stressing streams. + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let lde = n * blowup; + let num_cols = 64; + + // Warm up. + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + // Build input data. + let mut rng = ChaCha8Rng::seed_from_u64(11); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde.trailing_zeros() as u64).unwrap(); + + // GPU: rayon parallel across columns, each column picks a stream. + let t0 = Instant::now(); + let _gpu_results: Vec> = columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + + // CPU: same rayon parallel pattern as the prover's `expand_columns_to_lde`. + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "{num_cols}-column LDE 2^{log_n} blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_prover_scale() { + // Realistic large-table shape from the 1M-fib prover: ~1M rows, blowup 4, + // a few dozen columns. This is what actually runs in expand_columns_to_lde. + let log_n = 20u32; // 1M rows + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 20; + + let mut rng = ChaCha8Rng::seed_from_u64(31); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..8 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_samples.push(t0.elapsed().as_nanos()); + } + gpu_samples.sort(); + let gpu_ns = gpu_samples[gpu_samples.len() / 2]; // median + + let mut cpu_samples = Vec::with_capacity(10); + for _ in 0..10 { + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + cpu_samples.push(t0.elapsed().as_nanos()); + } + cpu_samples.sort(); + let cpu_ns = cpu_samples[cpu_samples.len() / 2]; // median + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "prover-scale batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (median of 10)", + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_batched_vs_rayon_cpu() { + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let mut rng = ChaCha8Rng::seed_from_u64(21); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Warm up every stream slot so subsequent iterations don't pay the + // one-time pinned staging alloc cost. + let warm_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + for _ in 0..64 { + let _ = + math_cuda::lde::coset_lde_batch_base(&warm_slices, blowup, &weights).unwrap(); + } + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n as u64).unwrap(); + let fwd_tw = LayerTwiddles::::new( + (n * blowup).trailing_zeros() as u64, + ) + .unwrap(); + + // GPU batched — first run may include lazy device init; do a few to stabilise. + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let mut gpu_ns = u128::MAX; + for _ in 0..5 { + let t0 = Instant::now(); + let _ = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + gpu_ns = gpu_ns.min(t0.elapsed().as_nanos()); + } + + // CPU rayon (same pattern as prover). + let mut cpu_bufs: Vec> = columns + .iter() + .map(|c| c.iter().map(|&x| Fp::from_raw(x)).collect()) + .collect(); + let t0 = Instant::now(); + cpu_bufs.par_iter_mut().for_each(|buf| { + Polynomial::coset_lde_full_expand::( + buf, blowup, &weights_fp, &inv_tw, &fwd_tw, + ) + .unwrap(); + }); + let cpu_ns = t0.elapsed().as_nanos(); + + let ratio = cpu_ns as f64 / gpu_ns as f64; + println!( + "batched {num_cols} cols, log_n={log_n}, blowup={blowup}: cpu={cpu_ns}ns gpu={gpu_ns}ns ratio={ratio:.2}x (cores={})", + rayon::current_num_threads(), + ); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_serialized_gpu() { + use std::sync::Mutex; + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(13); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + // Single global Mutex so only one thread at a time calls GPU. + let gpu_lock = Mutex::new(()); + let t0 = Instant::now(); + let _: Vec> = columns + .par_iter() + .map(|col| { + let _guard = gpu_lock.lock().unwrap(); + math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap() + }) + .collect(); + let gpu_ns = t0.elapsed().as_nanos(); + println!("GPU mutex-serialised from rayon: {gpu_ns}ns for {num_cols} cols"); +} + +#[test] +#[ignore = "informal perf probe; run with --ignored"] +fn bench_lde_multi_column_gpu_limited_threads() { + // Same as multi_column_parallel but forces rayon to use only 8 threads + // (matching the GPU stream pool rough capacity). Tests whether oversubscribed + // rayon + many streams is the bottleneck. + let gpu_pool = rayon::ThreadPoolBuilder::new() + .num_threads(8) + .build() + .unwrap(); + + let log_n = 16u32; + let blowup = 4usize; + let n = 1usize << log_n; + let num_cols = 64; + + let _ = math_cuda::lde::coset_lde_base( + &vec![0u64; n], + blowup, + &coset_weights(n, 7), + ) + .unwrap(); + + let mut rng = ChaCha8Rng::seed_from_u64(12); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + let weights = coset_weights(n, 7); + + let t0 = Instant::now(); + let _gpu_results: Vec> = gpu_pool.install(|| { + columns + .par_iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect() + }); + let gpu_ns = t0.elapsed().as_nanos(); + + let t0 = Instant::now(); + let _serial_gpu_results: Vec> = columns + .iter() + .map(|col| math_cuda::lde::coset_lde_base(col, blowup, &weights).unwrap()) + .collect(); + let gpu_serial_ns = t0.elapsed().as_nanos(); + + println!( + "GPU-only 8-thread: gpu-parallel={gpu_ns}ns gpu-serial={gpu_serial_ns}ns speedup={:.2}x", + gpu_serial_ns as f64 / gpu_ns as f64, + ); +} diff --git a/crypto/math-cuda/tests/comp_poly_tree.rs b/crypto/math-cuda/tests/comp_poly_tree.rs new file mode 100644 index 000000000..94ede1f33 --- /dev/null +++ b/crypto/math-cuda/tests/comp_poly_tree.rs @@ -0,0 +1,225 @@ +//! Parity: GPU fused `evaluate_poly_coset_batch_ext3_into_with_merkle_tree` +//! (LDE + row-pair Keccak leaves + Merkle inner tree) against the same CPU +//! pipeline produced by `commit_composition_polynomial`. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use math::traits::ByteConversion; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn reverse_index(i: u64, n: u64) -> u64 { + let log_n = n.trailing_zeros(); + i.reverse_bits() >> (64 - log_n) +} + +fn offset_weights(n: usize, offset: u64) -> Vec { + let mut w = Vec::with_capacity(n); + let mut cur = 1u64; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &offset); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn canon_ext3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +/// CPU: evaluate polynomial on coset via `Polynomial::evaluate_offset_fft`. +fn cpu_evaluate(coefs: &[Fp3], blowup: usize, offset: &Fp) -> Vec { + let poly = Polynomial::new(coefs); + Polynomial::evaluate_offset_fft::(&poly, blowup, None, offset).unwrap() +} + +fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +/// CPU: `commit_composition_polynomial`-style tree root over num_rows/2 leaves. +fn cpu_tree_nodes(parts: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = parts[0].len(); + let num_parts = parts.len(); + let num_leaves = num_rows / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let byte_len = 24; + + let hashed_leaves: Vec<[u8; 32]> = (0..num_leaves) + .map(|leaf_idx| { + let br_0 = reverse_index(2 * leaf_idx as u64, num_rows as u64) as usize; + let br_1 = reverse_index(2 * leaf_idx as u64 + 1, num_rows as u64) as usize; + let total_bytes = 2 * num_parts * byte_len; + let mut buf = vec![0u8; total_bytes]; + let mut offset = 0; + for part in parts.iter() { + part[br_0].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + for part in parts.iter() { + part[br_1].write_bytes_be(&mut buf[offset..offset + byte_len]); + offset += byte_len; + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut r = [0u8; 32]; + r.copy_from_slice(&h.finalize()); + r + }) + .collect(); + + let total = 2 * num_leaves - 1; + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for (i, leaf) in hashed_leaves.iter().enumerate() { + nodes[num_leaves - 1 + i] = *leaf; + } + let mut level_begin = num_leaves - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for j in 0..n_pairs { + let left = nodes[level_begin + 2 * j]; + let right = nodes[level_begin + 2 * j + 1]; + nodes[new_begin + j] = cpu_hash_pair(&left, &right); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_n: u32, blowup: usize, num_parts: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + assert!(lde_size >= 2); + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + // Random ext3 coefficient vectors per part. + let parts_cpu: Vec> = (0..num_parts) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + // CPU LDE via evaluate_offset_fft, then CPU tree. + let offset_u64 = rng.r#gen::() | 1; + let offset = Fp::from_raw(offset_u64); + let cpu_lde_parts: Vec> = parts_cpu + .iter() + .map(|c| cpu_evaluate(c, blowup, &offset)) + .collect(); + let cpu_nodes = cpu_tree_nodes(&cpu_lde_parts); + + // GPU fused call. + let weights = offset_weights(n, offset_u64); + let coefs_u64: Vec> = parts_cpu.iter().map(|c| ext3_to_u64s(c)).collect(); + let coefs_slices: Vec<&[u64]> = coefs_u64.iter().map(|v| v.as_slice()).collect(); + let mut outputs_raw: Vec> = (0..num_parts).map(|_| vec![0u64; 3 * lde_size]).collect(); + let mut outputs_slices: Vec<&mut [u64]> = outputs_raw + .iter_mut() + .map(|v| v.as_mut_slice()) + .collect(); + let total_nodes = 2 * lde_size - 1; + let mut nodes_bytes = vec![0u8; total_nodes * 32]; + + math_cuda::lde::evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + &coefs_slices, + n, + blowup, + &weights, + &mut outputs_slices, + &mut nodes_bytes, + ) + .unwrap(); + + // Compare LDE parts. + for (c, cpu_col) in cpu_lde_parts.iter().enumerate() { + let gpu_col = u64s_to_ext3(&outputs_raw[c]); + for i in 0..lde_size { + assert_eq!( + canon_ext3(&gpu_col[i]), + canon_ext3(&cpu_col[i]), + "LDE mismatch part {c} row {i} log_n={log_n} blowup={blowup}" + ); + } + } + + // Compare tree nodes. GPU writes `2*num_leaves - 1 = lde_size - 1` nodes. + let num_leaves = lde_size / 2; + let tight_total = 2 * num_leaves - 1; + assert_eq!(cpu_nodes.len(), tight_total); + for i in 0..tight_total { + let g = &nodes_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!( + g, c, + "tree node {i} mismatch at log_n={log_n} blowup={blowup} parts={num_parts}" + ); + } +} + +#[test] +fn comp_poly_tree_small() { + for log_n in 2u32..=5 { + for &blowup in &[2usize, 4, 8] { + for &parts in &[1usize, 2, 4] { + run_parity(log_n, blowup, parts, 1000 + log_n as u64 * 31 + parts as u64); + } + } + } +} + +#[test] +fn comp_poly_tree_medium() { + for &(log_n, blowup, parts) in &[(10u32, 4usize, 4usize), (12, 2, 3)] { + run_parity(log_n, blowup, parts, 2000 + log_n as u64 * 11 + parts as u64); + } +} + +#[test] +fn comp_poly_tree_large() { + run_parity(14, 2, 4, 9999); +} diff --git a/crypto/math-cuda/tests/deep.rs b/crypto/math-cuda/tests/deep.rs new file mode 100644 index 000000000..4a03ddc50 --- /dev/null +++ b/crypto/math-cuda/tests/deep.rs @@ -0,0 +1,286 @@ +//! Parity: GPU deep_composition_ext3 vs a direct CPU port of the same +//! row-wise summation. Uses random inputs — not the full stark LDE path. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField, IsSubFieldOf}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_raw(e: &Fp3) -> [u64; 3] { + [*e.value()[0].value(), *e.value()[1].value(), *e.value()[2].value()] +} + +fn canon3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +/// CPU reference: exact port of `compute_deep_composition_poly_evaluations`. +#[allow(clippy::too_many_arguments)] +fn cpu_deep( + main_lde: &[Vec], // num_main cols × lde_size + aux_lde: &[Vec], // num_aux cols × lde_size + h_lde: &[Vec], // num_parts × lde_size + h_ood: &[Fp3], // num_parts + trace_ood: &[Vec], // num_total_cols × num_eval_points + gammas_h: &[Fp3], // num_parts + gammas_tr: &[Vec], // num_total_cols × num_eval_points + inv_h: &[Fp3], // domain_size + inv_t: &[Vec], // num_eval_points × domain_size + blowup_factor: usize, + domain_size: usize, +) -> Vec { + let num_parts = h_lde.len(); + let num_main = main_lde.len(); + let num_aux = aux_lde.len(); + let num_eval_points = if trace_ood.is_empty() { + 0 + } else { + trace_ood[0].len() + }; + + (0..domain_size) + .map(|i| { + let row = i * blowup_factor; + let mut result = Fp3::zero(); + // H-terms + for j in 0..num_parts { + let num = &h_lde[j][row] - &h_ood[j]; + result += &gammas_h[j] * &num * &inv_h[i]; + } + // Main + for j in 0..num_main { + for k in 0..num_eval_points { + let t_val = &main_lde[j][row]; + let t_ood = &trace_ood[j][k]; + let num = t_val - t_ood; // base − ext3 = ext3 + result += &gammas_tr[j][k] * &num * &inv_t[k][i]; + } + } + // Aux + for j in 0..num_aux { + let trace_j = num_main + j; + for k in 0..num_eval_points { + let t_val = &aux_lde[j][row]; + let t_ood = &trace_ood[trace_j][k]; + let num = t_val - t_ood; + result += &gammas_tr[trace_j][k] * &num * &inv_t[k][i]; + } + } + result + }) + .collect() +} + +fn run_parity( + log_domain_size: u32, + blowup_factor: usize, + num_main: usize, + num_aux: usize, + num_parts: usize, + num_eval_points: usize, + seed: u64, +) { + let domain_size = 1usize << log_domain_size; + let lde_size = domain_size * blowup_factor; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + let main_lde: Vec> = (0..num_main) + .map(|_| (0..lde_size).map(|_| rand_fp(&mut rng)).collect()) + .collect(); + let aux_lde: Vec> = (0..num_aux) + .map(|_| (0..lde_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let h_lde: Vec> = (0..num_parts) + .map(|_| (0..lde_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let h_ood: Vec = (0..num_parts).map(|_| rand_fp3(&mut rng)).collect(); + let num_total_cols = num_main + num_aux; + let trace_ood: Vec> = (0..num_total_cols) + .map(|_| (0..num_eval_points).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let gammas_h: Vec = (0..num_parts).map(|_| rand_fp3(&mut rng)).collect(); + let gammas_tr: Vec> = (0..num_total_cols) + .map(|_| (0..num_eval_points).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let inv_h: Vec = (0..domain_size).map(|_| rand_fp3(&mut rng)).collect(); + let inv_t: Vec> = (0..num_eval_points) + .map(|_| (0..domain_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + + // CPU reference. + let cpu_out = cpu_deep( + &main_lde, &aux_lde, &h_lde, &h_ood, &trace_ood, &gammas_h, &gammas_tr, &inv_h, &inv_t, + blowup_factor, domain_size, + ); + + // GPU: upload main & aux LDEs into device buffers and wrap in handles. + use math_cuda::lde::{GpuLdeBase, GpuLdeExt3}; + let be = math_cuda::device::backend(); + let stream = be.next_stream(); + + // main_lde → col-major u64: m × lde_size + let mut main_flat = vec![0u64; num_main * lde_size]; + for (c, col) in main_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + main_flat[c * lde_size + r] = *v.value(); + } + } + let main_dev = stream.clone_htod(&main_flat).unwrap(); + + // aux_lde → de-interleaved: (m*3) × lde_size + let mut aux_flat = vec![0u64; num_aux * 3 * lde_size]; + for (c, col) in aux_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + let [a, b, c0] = ext3_to_raw(v); + aux_flat[(c * 3) * lde_size + r] = a; + aux_flat[(c * 3 + 1) * lde_size + r] = b; + aux_flat[(c * 3 + 2) * lde_size + r] = c0; + } + } + let aux_dev = stream.clone_htod(&aux_flat).unwrap(); + stream.synchronize().unwrap(); + + let main_handle = GpuLdeBase { + buf: std::sync::Arc::new(main_dev), + m: num_main, + lde_size, + }; + let aux_handle = if num_aux > 0 { + Some(GpuLdeExt3 { + buf: std::sync::Arc::new(aux_dev), + m: num_aux, + lde_size, + }) + } else { + drop(aux_dev); + None + }; + + // h_parts → de-interleaved: num_parts*3 × lde_size + let mut h_flat = vec![0u64; num_parts * 3 * lde_size]; + for (p, col) in h_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + let [a, b, c0] = ext3_to_raw(v); + h_flat[(p * 3) * lde_size + r] = a; + h_flat[(p * 3 + 1) * lde_size + r] = b; + h_flat[(p * 3 + 2) * lde_size + r] = c0; + } + } + + let mut h_ood_flat = vec![0u64; num_parts * 3]; + for (j, e) in h_ood.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + h_ood_flat[j * 3] = a; + h_ood_flat[j * 3 + 1] = b; + h_ood_flat[j * 3 + 2] = c; + } + let mut trace_ood_flat = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in trace_ood.iter().enumerate() { + for (k, e) in col.iter().enumerate() { + let idx = (j * num_eval_points + k) * 3; + let [a, b, c] = ext3_to_raw(e); + trace_ood_flat[idx] = a; + trace_ood_flat[idx + 1] = b; + trace_ood_flat[idx + 2] = c; + } + } + let mut gammas_h_flat = vec![0u64; num_parts * 3]; + for (j, e) in gammas_h.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + gammas_h_flat[j * 3] = a; + gammas_h_flat[j * 3 + 1] = b; + gammas_h_flat[j * 3 + 2] = c; + } + let mut gammas_tr_flat = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in gammas_tr.iter().enumerate() { + for (k, e) in col.iter().enumerate() { + let idx = (j * num_eval_points + k) * 3; + let [a, b, c] = ext3_to_raw(e); + gammas_tr_flat[idx] = a; + gammas_tr_flat[idx + 1] = b; + gammas_tr_flat[idx + 2] = c; + } + } + let mut inv_h_flat = vec![0u64; domain_size * 3]; + for (i, e) in inv_h.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + inv_h_flat[i * 3] = a; + inv_h_flat[i * 3 + 1] = b; + inv_h_flat[i * 3 + 2] = c; + } + let mut inv_t_flat = vec![0u64; num_eval_points * domain_size * 3]; + for (k, layer) in inv_t.iter().enumerate() { + for (i, e) in layer.iter().enumerate() { + let idx = (k * domain_size + i) * 3; + let [a, b, c] = ext3_to_raw(e); + inv_t_flat[idx] = a; + inv_t_flat[idx + 1] = b; + inv_t_flat[idx + 2] = c; + } + } + + let gpu_raw = math_cuda::deep::deep_composition_ext3( + &main_handle, + aux_handle.as_ref(), + &h_flat, + &h_ood_flat, + &trace_ood_flat, + &gammas_h_flat, + &gammas_tr_flat, + &inv_h_flat, + &inv_t_flat, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) + .unwrap(); + + for i in 0..domain_size { + let gpu = [gpu_raw[i * 3], gpu_raw[i * 3 + 1], gpu_raw[i * 3 + 2]]; + let gpu_canon = [ + GoldilocksField::canonical(&gpu[0]), + GoldilocksField::canonical(&gpu[1]), + GoldilocksField::canonical(&gpu[2]), + ]; + let cpu_canon = canon3(&cpu_out[i]); + assert_eq!( + gpu_canon, cpu_canon, + "row {i} mismatch at log_ds={log_domain_size} main={num_main} aux={num_aux} parts={num_parts}" + ); + } +} + +#[test] +fn deep_parity_small() { + run_parity(4, 2, 3, 2, 2, 1, 100); + run_parity(6, 4, 5, 3, 2, 2, 200); +} + +#[test] +fn deep_parity_medium() { + run_parity(10, 2, 10, 5, 4, 3, 1000); +} + +#[test] +fn deep_parity_no_aux() { + run_parity(8, 2, 5, 0, 2, 2, 5000); +} diff --git a/crypto/math-cuda/tests/evaluate_coset_ext3.rs b/crypto/math-cuda/tests/evaluate_coset_ext3.rs new file mode 100644 index 000000000..a79195291 --- /dev/null +++ b/crypto/math-cuda/tests/evaluate_coset_ext3.rs @@ -0,0 +1,143 @@ +//! Parity test for `evaluate_poly_coset_batch_ext3_into`. +//! +//! Reference: `math::polynomial::Polynomial::evaluate_offset_fft` on an ext3 +//! polynomial, then canonicalise. The GPU path should produce the same +//! evaluations on the offset-coset at `n * blowup` points. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn offset_weights(n: usize, offset: u64) -> Vec { + let mut w = Vec::with_capacity(n); + let mut cur = 1u64; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &offset); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn canon_fp3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn assert_evaluate_coset(log_n: u64, blowup: usize, m: usize, offset: u64, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + // M ext3 polynomials, each of degree < n. + let polys: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let weights = offset_weights(n, offset); + + // CPU reference: evaluate each polynomial at `offset`-coset of size lde_size. + let offset_fp = Fp::from_raw(offset); + let cpu: Vec> = polys + .iter() + .map(|coefs| { + let p = Polynomial::new(coefs); + Polynomial::evaluate_offset_fft::( + &p, + blowup, + Some(n), + &offset_fp, + ) + .unwrap() + }) + .collect(); + + // GPU: flatten each poly to 3n u64s, pre-allocate 3*lde_size u64 outputs. + let flat_inputs: Vec> = polys.iter().map(|p| ext3_to_u64s(p)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + let mut flat_outputs: Vec> = (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::evaluate_poly_coset_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for c in 0..m { + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu[c].len(), "length mismatch"); + for i in 0..gpu.len() { + let g = canon_fp3(&gpu[i]); + let cc = canon_fp3(&cpu[c][i]); + assert_eq!(g, cc, "eval mismatch col={c} row={i} log_n={log_n} blowup={blowup}"); + } + } +} + +#[test] +fn ext3_evaluate_coset_small() { + for &m in &[1usize, 4] { + for log_n in 4..=10 { + for &blowup in &[2usize, 4] { + assert_evaluate_coset(log_n, blowup, m, 7, 100 + log_n * 10 + m as u64); + } + } + } +} + +#[test] +fn ext3_evaluate_coset_medium() { + for log_n in 11..=14 { + assert_evaluate_coset(log_n, 4, 2, 7, 200 + log_n); + } +} + +#[test] +fn ext3_evaluate_coset_large_one_column() { + assert_evaluate_coset(16, 4, 1, 7, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/ext3.rs b/crypto/math-cuda/tests/ext3.rs new file mode 100644 index 000000000..c9aabbc27 --- /dev/null +++ b/crypto/math-cuda/tests/ext3.rs @@ -0,0 +1,87 @@ +//! Parity: GPU ext3 arithmetic must agree (canonically) with CPU +//! `Degree3GoldilocksExtensionField` on random ext3 inputs. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +const N: usize = 10_000; + +fn random_fp3s(seed: u64, count: usize) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..count) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() +} + +fn to_u64s(col: &[Fp3]) -> Vec { + let mut v = Vec::with_capacity(col.len() * 3); + for e in col { + v.push(*e.value()[0].value()); + v.push(*e.value()[1].value()); + v.push(*e.value()[2].value()); + } + v +} + +fn canon_triplet(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +fn canon_triplet_raw(t: &[u64]) -> [u64; 3] { + [ + GoldilocksField::canonical(&t[0]), + GoldilocksField::canonical(&t[1]), + GoldilocksField::canonical(&t[2]), + ] +} + +#[test] +fn ext3_mul_matches_cpu() { + let a = random_fp3s(11, N); + let b = random_fp3s(22, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_mul_u64(&a_raw, &b_raw).unwrap(); + assert_eq!(gpu.len(), 3 * N); + for i in 0..N { + use math::field::traits::IsField; + let cpu = Degree3GoldilocksExtensionField::mul(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 mul mismatch at {i}"); + } +} + +#[test] +fn ext3_add_matches_cpu() { + let a = random_fp3s(33, N); + let b = random_fp3s(44, N); + let a_raw = to_u64s(&a); + let b_raw = to_u64s(&b); + let gpu = math_cuda::ext3_add_u64(&a_raw, &b_raw).unwrap(); + for i in 0..N { + let cpu = Degree3GoldilocksExtensionField::add(a[i].value(), b[i].value()); + let cpu_fp3 = Fp3::new(cpu); + let g = canon_triplet_raw(&gpu[i * 3..(i + 1) * 3]); + let c = canon_triplet(&cpu_fp3); + assert_eq!(g, c, "ext3 add mismatch at {i}"); + } +} diff --git a/crypto/math-cuda/tests/fri_layer_tree.rs b/crypto/math-cuda/tests/fri_layer_tree.rs new file mode 100644 index 000000000..c637ccc02 --- /dev/null +++ b/crypto/math-cuda/tests/fri_layer_tree.rs @@ -0,0 +1,111 @@ +//! Parity: GPU `build_fri_layer_tree_from_evals_ext3` vs CPU +//! `FriLayerMerkleTree::build` (PairKeccak256 backend over ext3 pairs). + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use math::traits::ByteConversion; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn cpu_hash_pair_bytes(a: &Fp3, b: &Fp3) -> [u8; 32] { + let mut buf = [0u8; 48]; + a.write_bytes_be(&mut buf[0..24]); + b.write_bytes_be(&mut buf[24..48]); + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +fn cpu_hash_pair_nodes(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +fn cpu_fri_layer_nodes(evals: &[Fp3]) -> Vec<[u8; 32]> { + let num_leaves = evals.len() / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let total = 2 * num_leaves - 1; + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for j in 0..num_leaves { + nodes[num_leaves - 1 + j] = cpu_hash_pair_bytes(&evals[2 * j], &evals[2 * j + 1]); + } + let mut level_begin = num_leaves - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for k in 0..n_pairs { + let l = nodes[level_begin + 2 * k]; + let r = nodes[level_begin + 2 * k + 1]; + nodes[new_begin + k] = cpu_hash_pair_nodes(&l, &r); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_num_leaves: u32, seed: u64) { + let num_leaves = 1usize << log_num_leaves; + let num_evals = num_leaves * 2; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..num_evals).map(|_| rand_ext3(&mut rng)).collect(); + let evals_u64 = ext3_to_u64s(&evals); + + let cpu_nodes = cpu_fri_layer_nodes(&evals); + let gpu_bytes = math_cuda::merkle::build_fri_layer_tree_from_evals_ext3(&evals_u64).unwrap(); + + assert_eq!(cpu_nodes.len() * 32, gpu_bytes.len()); + for i in 0..cpu_nodes.len() { + let g = &gpu_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!(g, c, "node {i} mismatch at log_num_leaves={log_num_leaves}"); + } +} + +#[test] +fn fri_layer_tree_small() { + for log in 1u32..=6 { + run_parity(log, 100 + log as u64); + } +} + +#[test] +fn fri_layer_tree_medium() { + for log in [10u32, 12, 14] { + run_parity(log, 500 + log as u64); + } +} + +#[test] +fn fri_layer_tree_large() { + run_parity(18, 9999); +} diff --git a/crypto/math-cuda/tests/goldilocks.rs b/crypto/math-cuda/tests/goldilocks.rs new file mode 100644 index 000000000..317ffb0f8 --- /dev/null +++ b/crypto/math-cuda/tests/goldilocks.rs @@ -0,0 +1,127 @@ +//! GPU must produce bit-identical u64 outputs to `GoldilocksField` for every op. +//! Non-canonical inputs are expected (CPU operates on the full [0, 2^64) range), +//! so the test inputs include values above the prime. + +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +const N: usize = 10_000; + +fn sample_inputs(seed: u64) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + (0..N).map(|_| rng.r#gen::()).collect() +} + +fn assert_raw_eq(op: &str, expected: &[u64], actual: &[u64]) { + assert_eq!(expected.len(), actual.len()); + for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() { + if e != a { + panic!( + "{op} mismatch at {i}: cpu={e:#018x} (canon {:#018x}), gpu={a:#018x} (canon {:#018x})", + GoldilocksField::canonical(e), + GoldilocksField::canonical(a), + ); + } + } +} + +#[test] +fn gpu_vector_add_u64_matches_wrapping() { + let a = sample_inputs(0xC0FFEE); + let b = sample_inputs(0xDEADBEEF); + let expected: Vec = a.iter().zip(&b).map(|(x, y)| x.wrapping_add(*y)).collect(); + let actual = math_cuda::vector_add_u64(&a, &b).expect("GPU vector_add_u64"); + assert_raw_eq("vector_add (wrapping)", &expected, &actual); +} + +#[test] +fn gpu_gl_add_matches_cpu() { + let a = sample_inputs(1); + let b = sample_inputs(2); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::add(x, y)) + .collect(); + let actual = math_cuda::gl_add_u64(&a, &b).expect("GPU gl_add"); + assert_raw_eq("gl_add", &expected, &actual); +} + +#[test] +fn gpu_gl_sub_matches_cpu() { + let a = sample_inputs(3); + let b = sample_inputs(4); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::sub(x, y)) + .collect(); + let actual = math_cuda::gl_sub_u64(&a, &b).expect("GPU gl_sub"); + assert_raw_eq("gl_sub", &expected, &actual); +} + +#[test] +fn gpu_gl_mul_matches_cpu() { + let a = sample_inputs(5); + let b = sample_inputs(6); + let expected: Vec = a + .iter() + .zip(&b) + .map(|(x, y)| GoldilocksField::mul(x, y)) + .collect(); + let actual = math_cuda::gl_mul_u64(&a, &b).expect("GPU gl_mul"); + assert_raw_eq("gl_mul", &expected, &actual); +} + +#[test] +fn gpu_gl_neg_matches_cpu() { + let a = sample_inputs(7); + let expected: Vec = a.iter().map(|x| GoldilocksField::neg(x)).collect(); + let actual = math_cuda::gl_neg_u64(&a).expect("GPU gl_neg"); + assert_raw_eq("gl_neg", &expected, &actual); +} + +/// Edge cases the random generator is unlikely to hit: 0, 1, p-1, p, p+1, 2p-1, +/// u64::MAX, EPSILON boundary values. Covers double-overflow / double-underflow. +#[test] +fn gpu_goldilocks_edge_cases() { + const P: u64 = 0xFFFF_FFFF_0000_0001; + const EPS: u64 = 0xFFFF_FFFF; + let edge: [u64; 11] = [ + 0, + 1, + P - 1, + P, + P + 1, + 2u64.wrapping_mul(P).wrapping_sub(1), + u64::MAX, + u64::MAX - EPS, + u64::MAX - 1, + EPS, + EPS - 1, + ]; + // All pairs via nested loops, materialised as flat a[], b[] of length edge^2. + let mut a = Vec::with_capacity(edge.len() * edge.len()); + let mut b = Vec::with_capacity(edge.len() * edge.len()); + for &x in &edge { + for &y in &edge { + a.push(x); + b.push(y); + } + } + + let cases: &[(&str, fn(&[u64], &[u64]) -> math_cuda::Result>, fn(&u64, &u64) -> u64)] = + &[ + ("gl_add", math_cuda::gl_add_u64, GoldilocksField::add), + ("gl_sub", math_cuda::gl_sub_u64, GoldilocksField::sub), + ("gl_mul", math_cuda::gl_mul_u64, GoldilocksField::mul), + ]; + + for (op, gpu_fn, cpu_fn) in cases { + let expected: Vec = a.iter().zip(&b).map(|(x, y)| cpu_fn(x, y)).collect(); + let actual = gpu_fn(&a, &b).expect("GPU op"); + assert_raw_eq(op, &expected, &actual); + } +} diff --git a/crypto/math-cuda/tests/keccak_leaves.rs b/crypto/math-cuda/tests/keccak_leaves.rs new file mode 100644 index 000000000..6186ab45e --- /dev/null +++ b/crypto/math-cuda/tests/keccak_leaves.rs @@ -0,0 +1,141 @@ +//! Parity: GPU Keccak-256 leaf hashes must match CPU +//! `FieldElementVectorBackend::::hash_data` applied to +//! bit-reversed rows (same pattern as `commit_columns_bit_reversed` in the +//! stark prover). + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsField; +use math::traits::ByteConversion; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn reverse_index(i: u64, n: u64) -> u64 { + let log_n = n.trailing_zeros(); + i.reverse_bits() >> (64 - log_n) +} + +fn cpu_leaves_base(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 8; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +fn cpu_leaves_ext3(columns: &[Vec]) -> Vec<[u8; 32]> { + let num_rows = columns[0].len(); + let num_cols = columns.len(); + let byte_len = 24; + (0..num_rows) + .map(|row_idx| { + let br = reverse_index(row_idx as u64, num_rows as u64) as usize; + let mut buf = vec![0u8; num_cols * byte_len]; + for c in 0..num_cols { + columns[c][br].write_bytes_be(&mut buf[c * byte_len..(c + 1) * byte_len]); + } + let mut h = Keccak256::new(); + h.update(&buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out + }) + .collect() +} + +#[test] +fn keccak_leaves_base_matches_cpu() { + for log_n in [4u32, 6, 8, 10, 12] { + for num_cols in [1usize, 5, 17, 41] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(100 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| Fp::from_raw(rng.r#gen::())).collect()) + .collect(); + + let cpu = cpu_leaves_base(&columns); + + // Flatten columns into a contiguous base slab layout matching + // `coset_lde_batch_base_into`'s pinned staging format: + // `[col * stride + row]`. Use stride = num_rows for compactness. + let mut flat = vec![0u64; num_cols * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[c * n + r] = *e.value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_base(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "base leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} + +#[test] +fn keccak_leaves_ext3_matches_cpu() { + for log_n in [4u32, 6, 8, 10] { + for num_cols in [1usize, 3, 11, 20] { + let n = 1 << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(200 + log_n as u64 + num_cols as u64); + let columns: Vec> = (0..num_cols) + .map(|_| { + (0..n) + .map(|_| { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) + }) + .collect() + }) + .collect(); + + let cpu = cpu_leaves_ext3(&columns); + + // GPU expects 3 base slabs per ext3 column in the order + // [col*3+0 (comp a), col*3+1 (comp b), col*3+2 (comp c)], each a + // contiguous slab of n u64s (length = num_cols * 3 * n). + let mut flat = vec![0u64; num_cols * 3 * n]; + for (c, col) in columns.iter().enumerate() { + for (r, e) in col.iter().enumerate() { + flat[(c * 3 + 0) * n + r] = *e.value()[0].value(); + flat[(c * 3 + 1) * n + r] = *e.value()[1].value(); + flat[(c * 3 + 2) * n + r] = *e.value()[2].value(); + } + } + let gpu = math_cuda::merkle::keccak_leaves_ext3(&flat, n, num_cols, n).unwrap(); + assert_eq!(gpu.len(), n * 32); + for i in 0..n { + assert_eq!( + &gpu[i * 32..(i + 1) * 32], + &cpu[i][..], + "ext3 leaf mismatch at row {i} (log_n={log_n}, cols={num_cols})" + ); + } + } + } +} diff --git a/crypto/math-cuda/tests/lde.rs b/crypto/math-cuda/tests/lde.rs new file mode 100644 index 000000000..9648f833a --- /dev/null +++ b/crypto/math-cuda/tests/lde.rs @@ -0,0 +1,112 @@ +//! Phase-5 parity: GPU `coset_lde_base` must match the CPU +//! `Polynomial::coset_lde_full_expand` for a sweep of realistic sizes and +//! blowup factors. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +/// Build the coset weights `[1/N, g/N, g²/N, …, g^{n-1}/N]` — this is the +/// layout `crypto/stark/src/prover.rs:248` uses, with `1/N` pre-folded into the +/// first coefficient so the iFFT step does not need a separate scaling pass. +fn coset_weights(n: usize, coset_offset: u64) -> Vec { + let inv_n_fe = FieldElement::::from(n as u64) + .inv() + .expect("n is non-zero"); + let mut w = Vec::with_capacity(n); + let mut cur = *inv_n_fe.value(); + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &coset_offset); + } + w +} + +fn cpu_lde(evals: &[u64], blowup_factor: usize, coset_offset: u64) -> Vec { + let n = evals.len(); + let log_n = n.trailing_zeros() as u64; + let log_lde = (n * blowup_factor).trailing_zeros() as u64; + + let inv_tw = LayerTwiddles::::new_inverse(log_n).expect("inv tw"); + let fwd_tw = LayerTwiddles::::new(log_lde).expect("fwd tw"); + let weights_raw = coset_weights(n, coset_offset); + let weights: Vec = weights_raw.iter().map(|&w| Fp::from_raw(w)).collect(); + + let mut buf: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, + blowup_factor, + &weights, + &inv_tw, + &fwd_tw, + ) + .expect("cpu lde"); + + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_lde_match(log_n: u64, blowup_factor: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + // Use a fixed, public coset offset. For lambda-vm the coset offset is the + // generator of Goldilocks' multiplicative subgroup; any non-trivial element + // works for an isolated correctness check. + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + + let cpu = cpu_lde(&evals, blowup_factor, coset_offset); + let gpu = math_cuda::lde::coset_lde_base(&evals, blowup_factor, &weights).expect("gpu lde"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch (log_n={log_n}, blowup={blowup_factor})"); + let cpu_c = canon(&cpu); + let gpu_c = canon(&gpu); + for (i, (e, a)) in cpu_c.iter().zip(&gpu_c).enumerate() { + if e != a { + panic!( + "lde mismatch log_n={log_n} blowup={blowup_factor} i={i}: cpu {e:#018x}, gpu {a:#018x}", + ); + } + } +} + +#[test] +fn lde_small() { + for log_n in 4..=10 { + for &blow in &[2usize, 4, 8] { + assert_lde_match(log_n, blow, 1_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_medium() { + for log_n in 11..=14 { + for &blow in &[2usize, 4] { + assert_lde_match(log_n, blow, 2_000 + log_n + (blow as u64)); + } + } +} + +#[test] +fn lde_large_2_to_18() { + // 2^18 × blowup 4 = 2^20 LDE — representative of Phase A trace columns. + assert_lde_match(18, 4, 0xCAFE); +} + +#[test] +fn lde_largest_2_to_20() { + // 2^20 LDE is the hot size; blowup 2 keeps total = 2^21 (within TWO_ADICITY). + assert_lde_match(20, 2, 0xF00D); +} diff --git a/crypto/math-cuda/tests/lde_batch.rs b/crypto/math-cuda/tests/lde_batch.rs new file mode 100644 index 000000000..67f975728 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch.rs @@ -0,0 +1,96 @@ +//! Batched coset LDE must agree with running the CPU single-column LDE on +//! each column independently. Sweeps a few realistic (n, blowup, m) tuples. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn cpu_lde_one(col: &[u64], blowup: usize, weights_fp: &[Fp], inv_tw: &LayerTwiddles, fwd_tw: &LayerTwiddles) -> Vec { + let mut buf: Vec = col.iter().map(|&x| Fp::from_raw(x)).collect(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf.into_iter().map(|e| *e.value()).collect() +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = + LayerTwiddles::::new((n * blowup).trailing_zeros() as u64).unwrap(); + + let slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let gpu_all = math_cuda::lde::coset_lde_batch_base(&slices, blowup, &weights).unwrap(); + assert_eq!(gpu_all.len(), m); + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + assert_eq!( + canon(&gpu_all[c]), + canon(&cpu), + "batch mismatch at col {c}, log_n={log_n}, blowup={blowup}" + ); + } +} + +#[test] +fn batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_medium() { + for &m in &[2usize, 32] { + for log_n in 11..=14 { + assert_batch(log_n, 4, m, 200 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn batch_large_one_column() { + assert_batch(18, 4, 1, 0xCAFE); +} + +#[test] +fn batch_large_32_columns() { + assert_batch(15, 4, 32, 0xBEEF); +} diff --git a/crypto/math-cuda/tests/lde_batch_ext3.rs b/crypto/math-cuda/tests/lde_batch_ext3.rs new file mode 100644 index 000000000..0a86197a5 --- /dev/null +++ b/crypto/math-cuda/tests/lde_batch_ext3.rs @@ -0,0 +1,161 @@ +//! Ext3 batched coset LDE must agree with the CPU `coset_lde_full_expand` +//! on each column independently when run over `FieldElement`. + +use math::fft::cpu::bowers_fft::LayerTwiddles; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsPrimeField}; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = *FieldElement::::from(n as u64) + .inv() + .unwrap() + .value(); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = GoldilocksField::mul(&cur, &g); + } + w +} + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + // Each Fp3 is [u64; 3] in memory; we just flatten componentwise. + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn u64s_to_ext3(raw: &[u64]) -> Vec { + assert_eq!(raw.len() % 3, 0); + let mut out = Vec::with_capacity(raw.len() / 3); + for i in 0..raw.len() / 3 { + out.push(Fp3::new([ + Fp::from_raw(raw[i * 3 + 0]), + Fp::from_raw(raw[i * 3 + 1]), + Fp::from_raw(raw[i * 3 + 2]), + ])); + } + out +} + +fn cpu_lde_one_ext3( + col: &[Fp3], + blowup: usize, + weights_fp: &[Fp], + inv_tw: &LayerTwiddles, + fwd_tw: &LayerTwiddles, +) -> Vec { + let mut buf = col.to_vec(); + Polynomial::coset_lde_full_expand::( + &mut buf, blowup, weights_fp, inv_tw, fwd_tw, + ) + .unwrap(); + buf +} + +fn canon(xs: &[u64]) -> Vec { + xs.iter().map(|x| GoldilocksField::canonical(x)).collect() +} + +fn assert_ext3_batch(log_n: u64, blowup: usize, m: usize, seed: u64) { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let columns: Vec> = (0..m) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let coset_offset: u64 = 7; + let weights = coset_weights(n, coset_offset); + let weights_fp: Vec = weights.iter().map(|&w| Fp::from_raw(w)).collect(); + let inv_tw = LayerTwiddles::::new_inverse(log_n).unwrap(); + let fwd_tw = LayerTwiddles::::new(lde_size.trailing_zeros() as u64).unwrap(); + + // Flatten each ext3 column to 3n u64s for the GPU API. + let flat_inputs: Vec> = columns.iter().map(|c| ext3_to_u64s(c)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + + // Pre-allocate outputs, each 3*lde_size u64s. + let mut flat_outputs: Vec> = + (0..m).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_ext3_into( + &input_slices, + n, + blowup, + &weights, + &mut out_slices, + ) + .unwrap(); + } + + for (c, col) in columns.iter().enumerate() { + let cpu = cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw); + let gpu: Vec = u64s_to_ext3(&flat_outputs[c]); + assert_eq!(gpu.len(), cpu.len(), "length mismatch"); + for i in 0..cpu.len() { + for k in 0..3 { + let cv = *cpu[i].value()[k].value(); + let gv = *gpu[i].value()[k].value(); + let cc = GoldilocksField::canonical(&cv); + let gc = GoldilocksField::canonical(&gv); + if cc != gc { + panic!( + "ext3 batch mismatch col={c} row={i} comp={k} log_n={log_n} blowup={blowup}: cpu={cv:#018x} (canon {cc:#018x}), gpu={gv:#018x} (canon {gc:#018x})", + ); + } + } + } + } + // Also sanity-check raw canonical equality per column. + for (c, col) in columns.iter().enumerate() { + let cpu_raw = ext3_to_u64s(&cpu_lde_one_ext3(col, blowup, &weights_fp, &inv_tw, &fwd_tw)); + assert_eq!(canon(&cpu_raw), canon(&flat_outputs[c])); + } +} + +#[test] +fn ext3_batch_small() { + for &m in &[1usize, 4, 16] { + for log_n in 4..=10 { + assert_ext3_batch(log_n, 4, m, 100 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_medium() { + for &m in &[2usize, 8] { + for log_n in 11..=14 { + assert_ext3_batch(log_n, 4, m, 300 + log_n * 10 + m as u64); + } + } +} + +#[test] +fn ext3_batch_large_one_column() { + assert_ext3_batch(16, 4, 1, 0xCAFE); +} diff --git a/crypto/math-cuda/tests/merkle_tree.rs b/crypto/math-cuda/tests/merkle_tree.rs new file mode 100644 index 000000000..34d44c767 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_tree.rs @@ -0,0 +1,92 @@ +//! Parity: GPU Merkle inner-tree construction must match the CPU +//! `crypto/crypto/src/merkle_tree/merkle.rs` `build_from_hashed_leaves` +//! (Keccak-256 pair hash at each level). + +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +fn cpu_hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +/// CPU reference: same algorithm as `build_from_hashed_leaves`. +fn cpu_merkle_nodes(leaves: &[[u8; 32]]) -> Vec<[u8; 32]> { + let leaves_len = leaves.len(); + assert!(leaves_len.is_power_of_two() && leaves_len >= 2); + let total = 2 * leaves_len - 1; + + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for (i, leaf) in leaves.iter().enumerate() { + nodes[leaves_len - 1 + i] = *leaf; + } + + let mut level_begin = leaves_len - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for j in 0..n_pairs { + let left = nodes[level_begin + 2 * j]; + let right = nodes[level_begin + 2 * j + 1]; + nodes[new_begin + j] = cpu_hash_pair(&left, &right); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_n: u32, seed: u64) { + let leaves_len = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let leaves: Vec<[u8; 32]> = (0..leaves_len) + .map(|_| { + let mut arr = [0u8; 32]; + rng.fill(&mut arr[..]); + arr + }) + .collect(); + + // Flat byte layout for the GPU entry point. + let mut flat = Vec::with_capacity(leaves_len * 32); + for l in &leaves { + flat.extend_from_slice(l); + } + + let gpu_nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(&flat).unwrap(); + assert_eq!(gpu_nodes_bytes.len(), (2 * leaves_len - 1) * 32); + + let cpu_nodes = cpu_merkle_nodes(&leaves); + + for i in 0..cpu_nodes.len() { + let g = &gpu_nodes_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!( + g, c, + "node {i} mismatch at log_n={log_n} (cpu={c:?}, gpu={g:?})" + ); + } +} + +#[test] +fn merkle_tree_small() { + for log_n in 1u32..=6 { + run_parity(log_n, 100 + log_n as u64); + } +} + +#[test] +fn merkle_tree_medium() { + for log_n in [10u32, 12, 14] { + run_parity(log_n, 500 + log_n as u64); + } +} + +#[test] +fn merkle_tree_large() { + run_parity(18, 9999); +} diff --git a/crypto/math-cuda/tests/ntt.rs b/crypto/math-cuda/tests/ntt.rs new file mode 100644 index 000000000..d7cf3680a --- /dev/null +++ b/crypto/math-cuda/tests/ntt.rs @@ -0,0 +1,136 @@ +//! Phase-3 parity: GPU forward NTT must agree with `Polynomial::evaluate_fft` +//! as a field element, across a sweep of sizes from 2^4 to 2^20. +//! +//! Non-canonical u64s can differ between CPU and GPU while representing the +//! same element; we canonicalise both sides before comparing. + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; + +fn cpu_fft(coeffs: &[u64]) -> Vec { + let elems: Vec = coeffs.iter().map(|&x| Fp::from_raw(x)).collect(); + let poly = Polynomial::new(&elems); + let evals = Polynomial::evaluate_fft::(&poly, 1, None).expect("cpu fft"); + evals.into_iter().map(|e| *e.value()).collect() +} + +fn canonicalize(xs: &[u64]) -> Vec { + xs.iter() + .map(|x| GoldilocksField::canonical(x)) + .collect() +} + +fn assert_ntt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let input: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let cpu = cpu_fft(&input); + let gpu = math_cuda::ntt::forward(&input).expect("gpu ntt"); + + assert_eq!(cpu.len(), gpu.len(), "length mismatch at log_n = {log_n}"); + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "log_n={log_n} i={i}: cpu={:#018x} (canon {:#018x}), gpu={:#018x} (canon {:#018x})", + cpu[i], cpu_c[i], gpu[i], gpu_c[i], + ); + } + } +} + +#[test] +fn ntt_sizes_small() { + for log_n in 4..=10 { + assert_ntt_match(log_n, 100 + log_n); + } +} + +#[test] +fn ntt_sizes_medium() { + for log_n in 11..=16 { + assert_ntt_match(log_n, 200 + log_n); + } +} + +#[test] +fn ntt_size_2_to_20() { + // The hot LDE size. One seed is enough; any mismatch screams loudly. + assert_ntt_match(20, 0xDEAD); +} + +#[test] +fn ntt_trivial_sizes() { + // Power-of-two below the interesting range — should still pass. + assert_ntt_match(1, 1); + assert_ntt_match(2, 2); + assert_ntt_match(3, 3); +} + +fn assert_intt_match(log_n: u64, seed: u64) { + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..n).map(|_| rng.r#gen::()).collect(); + + let elems: Vec = evals.iter().map(|&x| Fp::from_raw(x)).collect(); + let cpu_poly = + Polynomial::interpolate_fft::(&elems).expect("cpu intt"); + let cpu: Vec = cpu_poly.coefficients.into_iter().map(|e| *e.value()).collect(); + + let gpu = math_cuda::ntt::inverse(&evals).expect("gpu intt"); + + let cpu_c = canonicalize(&cpu); + let gpu_c = canonicalize(&gpu); + for i in 0..n { + if cpu_c[i] != gpu_c[i] { + panic!( + "iNTT log_n={log_n} i={i}: cpu canon {:#018x}, gpu canon {:#018x}", + cpu_c[i], gpu_c[i], + ); + } + } +} + +#[test] +fn intt_sizes_small() { + for log_n in 4..=10 { + assert_intt_match(log_n, 700 + log_n); + } +} + +#[test] +fn intt_sizes_medium() { + for log_n in 11..=16 { + assert_intt_match(log_n, 800 + log_n); + } +} + +#[test] +fn intt_size_2_to_20() { + assert_intt_match(20, 0xBEEF); +} + +#[test] +fn ntt_round_trip() { + // inverse(forward(x)) == x up to canonical form. + let log_n = 14; + let n = 1usize << log_n; + let mut rng = ChaCha8Rng::seed_from_u64(42); + let x: Vec = (0..n).map(|_| rng.r#gen::() % 0xFFFF_FFFF_0000_0001).collect(); + + let evals = math_cuda::ntt::forward(&x).expect("forward"); + let back = math_cuda::ntt::inverse(&evals).expect("inverse"); + + let x_c = canonicalize(&x); + let back_c = canonicalize(&back); + assert_eq!(x_c, back_c, "round trip failed"); +} + diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 53b205996..4d1f2cbca 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -22,6 +22,9 @@ itertools = "0.11.0" # Parallelization crates rayon = { version = "1.8.0", optional = true } +# GPU backend for trace LDE — only linked when `cuda` is enabled. +math-cuda = { path = "../math-cuda", optional = true } + # wasm wasm-bindgen = { version = "0.2", optional = true } serde-wasm-bindgen = { version = "0.5", optional = true } @@ -39,6 +42,7 @@ test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier debug-checks = [] # Enables validate_trace + bus balance report in prover parallel = ["dep:rayon", "crypto/parallel"] +cuda = ["dep:math-cuda"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] diff --git a/crypto/stark/src/fri/mod.rs b/crypto/stark/src/fri/mod.rs index 87ab66a5b..1fa7f5e2b 100644 --- a/crypto/stark/src/fri/mod.rs +++ b/crypto/stark/src/fri/mod.rs @@ -33,6 +33,24 @@ where FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send, { + // GPU fast path: keeps evals, inv_twiddles, and per-layer Merkle trees + // device-resident across log₂(domain_size) layers. Only D2H'd per + // layer: the root (32 B → transcript) + the layer's evals and tree + // nodes (needed by query_phase later). Falls back to CPU when the + // `cuda` feature is off, types mismatch, or the domain is too small. + #[cfg(feature = "cuda")] + { + if let Some(result) = crate::gpu_lde::try_fri_commit_gpu::( + number_layers, + &evals, + transcript, + coset_offset, + domain_size, + ) { + return result; + } + } + // Inverse twiddle factors for evaluation-form folding let mut inv_twiddles = compute_coset_twiddles_inv(coset_offset, domain_size); diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs new file mode 100644 index 000000000..3f4b57548 --- /dev/null +++ b/crypto/stark/src/gpu_lde.rs @@ -0,0 +1,2144 @@ +//! GPU dispatch layer for the per-column coset LDE. Lives in the stark crate +//! (not `math`) to avoid a dependency cycle between `math` and `math-cuda`. +//! +//! Handles only Goldilocks base-field columns above a size threshold; falls +//! back to CPU for extension-field columns and small columns where kernel +//! launch overhead dominates. Produces the same natural-order, non-canonical +//! LDE evaluations as the CPU path. + +use core::any::type_name; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsSubFieldOf}; + +use crate::domain::Domain; + +/// Break-even LDE size. Below this, the CPU `coset_lde_full_expand` completes +/// in a few hundred microseconds and the GPU's ~37 kernel launches plus +/// H2D/D2H round-trip is a net loss. The check is on **lde size**, not trace +/// length, because that's what determines the FFT workload. +/// +/// 2^19 is a conservative default calibrated against a 46-core machine where +/// rayon-parallel CPU LDE is already fast. Override via env var for tuning +/// on smaller machines; see `/workspace/lambda_vm/crypto/math-cuda/tests/bench_quick.rs`. +const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19; + +fn gpu_lde_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_LDE_THRESHOLD) + }) +} + +/// Atomically counted by `try_expand_column` every time it actually routes a +/// column to the GPU. Used by benchmarks to confirm the GPU path fired. +static GPU_LDE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); + +pub fn gpu_lde_calls() -> u64 { + GPU_LDE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +pub fn reset_gpu_lde_calls() { + GPU_LDE_CALLS.store(0, std::sync::atomic::Ordering::Relaxed); +} + +pub(crate) static GPU_EXTEND_HALVES_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_extend_halves_calls() -> u64 { + GPU_EXTEND_HALVES_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Try to GPU-batch all columns in one pass. +/// +/// Only engaged for Goldilocks-base tables whose LDE size is above the +/// threshold. The prover's `expand_columns_to_lde` hands us every column of +/// one table at once; those columns all share twiddles and coset weights so +/// they can be processed in a single batched pipeline on one stream. +/// +/// Returns `true` if the batch was handled on GPU (and `columns` now contains +/// the LDE evaluations). Returns `false` to let the caller run the per-column +/// CPU fallback. +#[inline] +pub(crate) fn try_expand_columns_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> bool +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return true; // nothing to do — same as CPU path + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return false; + } + if type_name::() != type_name::() { + return false; + } + // All columns within one call must be the same size (invariant of the + // caller), but double-check before unsafe extraction. + if columns.iter().any(|c| c.len() != n) { + return false; + } + + // Ext3 fast path: decompose each ext3 column into its 3 base components + // and dispatch to the base-field batched NTT with 3×M logical columns. + // Butterflies with a base-field twiddle act componentwise on ext3, so + // this is exactly equivalent to running the NTT in the extension field. + if type_name::() == type_name::() { + return try_expand_columns_batched_ext3::(columns, blowup_factor, weights); + } + + if type_name::() != type_name::() { + return false; + } + + // Extract raw u64 slices. SAFETY: type_name above confirms + // `E == GoldilocksField`, so `FieldElement` wraps u64 one-to-one. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + // Pre-size caller Vecs to lde_size so the GPU path can write directly + // into the same backing allocation the caller already holds. This skips + // the intermediate `Vec>` allocation (which would page-fault + // per column) and is the main reason `coset_lde_batch_base_into` exists. + for col in columns.iter_mut() { + // SAFETY: set_len is valid here because capacity is already >= + // lde_size (the caller sized columns via `extract_columns_main(lde_size)`) + // and we're about to overwrite every slot via the GPU copy below. + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + + // Borrow each caller Vec as a raw `&mut [u64]` slice; safe because each + // FieldElement aliases a single u64 when E == GoldilocksField. + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + // SAFETY: see above — single-u64 layout, caller still owns. + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_base_into( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + .expect("GPU batched coset LDE failed"); + true +} + +/// GPU path for `Prover::extend_half_to_lde`. +/// +/// Inside `decompose_and_extend_d2` (R2 quotient decomposition) the prover +/// does `rayon::join` of two calls: `iFFT(N on g²-coset) → FFT(2N on g-coset)` +/// over ext3 halves H0 and H1. They share the same domain/offset and sizes, +/// so we batch them into a single GPU call with M=2 ext3 columns. +/// +/// Weights = `[1/N, g^(-1)/N, g^(-2)/N, …, g^(-(N-1))/N]`. This bakes the +/// `(g²)^(-k)` input-coset-undo from `interpolate_offset_fft` together with +/// the `g^k` forward-coset-shift from `evaluate_polynomial_on_lde_domain` — +/// net is `g^(-k)` — plus the `1/N` iFFT normalisation. +/// +/// Returns `None` when the GPU path doesn't apply (too small, or CPU path +/// should be used); in that case the caller runs its existing rayon::join. +pub(crate) fn try_extend_two_halves_gpu( + h0: &[FieldElement], + h1: &[FieldElement], + squared_offset: &FieldElement, + domain: &Domain, +) -> Option<(Vec>, Vec>)> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, +{ + if h0.len() != h1.len() { + return None; + } + let n = h0.len(); + let blowup = 2; // extend_half_to_lde extends N → 2N always + let lde_size = n * blowup; + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + GPU_EXTEND_HALVES_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // squared_offset should be `g²`. We recover `g` as `domain.coset_offset` + // and use it to build the `g^(-k) / N` weights. + let _ = squared_offset; // unused (we derive weights from domain) + + // Flatten ext3 slices to raw 3*n u64 buffers. + let to_u64 = |col: &[FieldElement]| -> Vec { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }; + let h0_raw = to_u64(h0); + let h1_raw = to_u64(h1); + + // weights[k] = g^(-k) / N as a u64. + let inv_n = FieldElement::::from(n as u64) + .inv() + .expect("N nonzero"); + let g = &domain.coset_offset; + let g_inv = g.inv().expect("g nonzero"); + let mut weights_u64 = Vec::with_capacity(n); + let mut w = inv_n.clone(); + for _ in 0..n { + // F == GoldilocksField by type_name check above, so value is u64. + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w = w * &g_inv; + } + + // Pre-allocate outputs. + let mut lde_h0 = vec![FieldElement::::zero(); lde_size]; + let mut lde_h1 = vec![FieldElement::::zero(); lde_size]; + + GPU_LDE_CALLS.fetch_add(6, std::sync::atomic::Ordering::Relaxed); // 2 ext3 cols × 3 components + { + let inputs: [&[u64]; 2] = [&h0_raw, &h1_raw]; + // View each output Vec> as &mut [u64] of length 3*lde_size. + let out0_ptr = lde_h0.as_mut_ptr() as *mut u64; + let out1_ptr = lde_h1.as_mut_ptr() as *mut u64; + // SAFETY: ext3 FieldElement is [u64; 3] in memory, and the Vec has len + // = lde_size so the backing is 3*lde_size u64s. + let out0_slice = unsafe { core::slice::from_raw_parts_mut(out0_ptr, 3 * lde_size) }; + let out1_slice = unsafe { core::slice::from_raw_parts_mut(out1_ptr, 3 * lde_size) }; + let mut outputs: [&mut [u64]; 2] = [out0_slice, out1_slice]; + math_cuda::lde::coset_lde_batch_ext3_into( + &inputs, + n, + blowup, + &weights_u64, + &mut outputs, + ) + .expect("GPU extend_half_to_lde failed"); + } + + Some((lde_h0, lde_h1)) +} + +/// GPU path for Round 4's DEEP-poly LDE extension. +/// +/// The CPU pipeline at `prover.rs:1107` is +/// ```ignore +/// let deep_poly = Polynomial::interpolate_fft::(&deep_evals)?; +/// let mut lde_evals = Polynomial::evaluate_fft::(&deep_poly, 1, Some(domain_size))?; +/// in_place_bit_reverse_permute(&mut lde_evals); +/// ``` +/// +/// That is an iFFT over `N = deep_evals.len()` ext3 elements followed by an +/// FFT evaluation on `domain_size` points — the **standard** (non-coset) LDE +/// on the extension field with weights `[1/N, ..., 1/N]`. We reuse +/// `coset_lde_batch_ext3_into` with a uniform `1/N` weight vector; the +/// single ext3 column is handled internally as 3 base-field slabs. The +/// caller keeps its trailing `in_place_bit_reverse_permute`, so output +/// order is unchanged. +pub(crate) fn try_r4_deep_poly_lde_gpu( + deep_evals: &[FieldElement], + domain_size: usize, +) -> Option>> +where + E: IsField, +{ + let n = deep_evals.len(); + if n == 0 || !n.is_power_of_two() { + return None; + } + if domain_size < n || !domain_size.is_power_of_two() { + return None; + } + let blowup = domain_size / n; + if blowup < 2 { + return None; + } + if domain_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + + GPU_R4_LDE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Uniform weights = 1/N (no coset shift, just iFFT normalisation). + let inv_n_u64 = { + let fe = FieldElement::::from(n as u64) + .inv() + .expect("N non-zero"); + *fe.value() + }; + let weights = vec![inv_n_u64; n]; + + // Input: single ext3 column, 3n u64s. + let input_raw: Vec = { + let len = n * 3; + let ptr = deep_evals.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }; + let inputs: [&[u64]; 1] = [&input_raw]; + + let mut out_vec = vec![FieldElement::::zero(); domain_size]; + { + let out_ptr = out_vec.as_mut_ptr() as *mut u64; + let out_slice = unsafe { core::slice::from_raw_parts_mut(out_ptr, 3 * domain_size) }; + let mut outputs: [&mut [u64]; 1] = [out_slice]; + math_cuda::lde::coset_lde_batch_ext3_into( + &inputs, + n, + blowup, + &weights, + &mut outputs, + ) + .expect("GPU R4 deep-poly LDE failed"); + } + Some(out_vec) +} + +pub(crate) static GPU_R4_LDE_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_r4_lde_calls() -> u64 { + GPU_R4_LDE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// GPU path for the composition-polynomial LDE in the `number_of_parts > 2` +/// branch of `round_2_compute_composition_polynomial` (prover.rs:920). The +/// caller already has the polynomial parts; we batch their evaluations at +/// the `domain_size × blowup_factor` coset in a single GPU call. +/// +/// Each part is padded to `domain_size` coefficients. Weights = `offset^k` +/// (coset shift, no 1/N normalisation — input is coefficients). +pub(crate) fn try_evaluate_parts_on_lde_gpu( + parts_coefs: &[&[FieldElement]], + blowup_factor: usize, + domain_size: usize, + offset: &FieldElement, +) -> Option>>> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, +{ + try_evaluate_parts_on_lde_gpu_impl(parts_coefs, blowup_factor, domain_size, offset, false) + .map(|(v, _)| v) +} + +/// Same as [`try_evaluate_parts_on_lde_gpu`] but also retains the +/// composition-parts LDE device buffer as a `GpuLdeExt3` handle. Used by +/// `round_2_compute_composition_polynomial` to feed R2 commit and R4 +/// DEEP composition without re-H2D'ing. +pub(crate) fn try_evaluate_parts_on_lde_gpu_keep( + parts_coefs: &[&[FieldElement]], + blowup_factor: usize, + domain_size: usize, + offset: &FieldElement, +) -> Option<(Vec>>, math_cuda::lde::GpuLdeExt3)> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, +{ + let (v, h) = try_evaluate_parts_on_lde_gpu_impl( + parts_coefs, + blowup_factor, + domain_size, + offset, + true, + )?; + Some((v, h.expect("keep=true returns Some handle"))) +} + +fn try_evaluate_parts_on_lde_gpu_impl( + parts_coefs: &[&[FieldElement]], + blowup_factor: usize, + domain_size: usize, + offset: &FieldElement, + keep: bool, +) -> Option<( + Vec>>, + Option, +)> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, +{ + if parts_coefs.is_empty() { + return Some((Vec::new(), None)); + } + if !domain_size.is_power_of_two() || !blowup_factor.is_power_of_two() { + return None; + } + let lde_size = domain_size * blowup_factor; + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + let m = parts_coefs.len(); + + GPU_PARTS_LDE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Weights: `offset^k` for k in 0..domain_size. F == Goldilocks. + let mut weights_u64 = Vec::with_capacity(domain_size); + let mut w = FieldElement::::one(); + for _ in 0..domain_size { + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w = w * offset; + } + + let mut part_bufs: Vec> = Vec::with_capacity(m); + for part in parts_coefs.iter() { + let mut buf = vec![0u64; 3 * domain_size]; + let len = part.len().min(domain_size); + let src_ptr = part.as_ptr() as *const u64; + let src_len = len * 3; + let src = unsafe { core::slice::from_raw_parts(src_ptr, src_len) }; + buf[..src_len].copy_from_slice(src); + part_bufs.push(buf); + } + let input_slices: Vec<&[u64]> = part_bufs.iter().map(|v| v.as_slice()).collect(); + + let mut outputs: Vec>> = (0..m) + .map(|_| vec![FieldElement::::zero(); lde_size]) + .collect(); + let handle = { + let mut out_slices: Vec<&mut [u64]> = outputs + .iter_mut() + .map(|o| { + let ptr = o.as_mut_ptr() as *mut u64; + unsafe { core::slice::from_raw_parts_mut(ptr, 3 * lde_size) } + }) + .collect(); + if keep { + Some( + math_cuda::lde::evaluate_poly_coset_batch_ext3_into_keep( + &input_slices, + domain_size, + blowup_factor, + &weights_u64, + &mut out_slices, + ) + .expect("GPU parts LDE (keep) failed"), + ) + } else { + math_cuda::lde::evaluate_poly_coset_batch_ext3_into( + &input_slices, + domain_size, + blowup_factor, + &weights_u64, + &mut out_slices, + ) + .expect("GPU parts LDE failed"); + None + } + }; + Some((outputs, handle)) +} + +/// Fused variant of [`try_evaluate_parts_on_lde_gpu`]: in addition to the +/// LDE parts, builds the R2 composition-polynomial Merkle tree on device +/// (row-pair Keccak leaves + pair-hash inner tree). Returns both the parts +/// (still needed downstream for R4 openings) and the finished tree. +#[allow(dead_code)] +pub(crate) fn try_evaluate_parts_on_lde_and_commit_gpu( + parts_coefs: &[&[FieldElement]], + blowup_factor: usize, + domain_size: usize, + offset: &FieldElement, +) -> Option<( + Vec>>, + crypto::merkle_tree::merkle::MerkleTree, +)> +where + F: math::field::traits::IsFFTField + IsField, + E: IsField, + F: IsSubFieldOf, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if parts_coefs.is_empty() { + return None; + } + if !domain_size.is_power_of_two() || !blowup_factor.is_power_of_two() { + return None; + } + let lde_size = domain_size * blowup_factor; + if lde_size < gpu_lde_threshold() { + return None; + } + if lde_size < 2 { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + let m = parts_coefs.len(); + + GPU_PARTS_LDE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Weights: `offset^k`. + let mut weights_u64 = Vec::with_capacity(domain_size); + let mut w = FieldElement::::one(); + for _ in 0..domain_size { + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w = w * offset; + } + + // Pack parts into per-part 3*domain_size u64 buffers (zero-padded). + let mut part_bufs: Vec> = Vec::with_capacity(m); + for part in parts_coefs.iter() { + let mut buf = vec![0u64; 3 * domain_size]; + let len = part.len().min(domain_size); + let src_ptr = part.as_ptr() as *const u64; + let src_len = len * 3; + let src = unsafe { core::slice::from_raw_parts(src_ptr, src_len) }; + buf[..src_len].copy_from_slice(src); + part_bufs.push(buf); + } + let input_slices: Vec<&[u64]> = part_bufs.iter().map(|v| v.as_slice()).collect(); + + let mut outputs: Vec>> = (0..m) + .map(|_| vec![FieldElement::::zero(); lde_size]) + .collect(); + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + let mut nodes_bytes = vec![0u8; tight_total_nodes * 32]; + { + let mut out_slices: Vec<&mut [u64]> = outputs + .iter_mut() + .map(|o| { + let ptr = o.as_mut_ptr() as *mut u64; + unsafe { core::slice::from_raw_parts_mut(ptr, 3 * lde_size) } + }) + .collect(); + math_cuda::lde::evaluate_poly_coset_batch_ext3_into_with_merkle_tree( + &input_slices, + domain_size, + blowup_factor, + &weights_u64, + &mut out_slices, + &mut nodes_bytes, + ) + .expect("GPU ext3 evaluate+commit failed"); + } + + // Build the MerkleTree from the device-produced nodes. + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(tight_total_nodes); + for i in 0..tight_total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((outputs, tree)) +} + +/// Build a FRI-layer Merkle tree from already-folded evaluations using the +/// GPU pair-leaf kernel + pair-hash inner tree. +/// +/// Not currently wired — benchmarking showed the win per layer (GPU tree +/// vs rayon tree) is eaten by the H2D of each layer's eval slab since the +/// evals are in pageable CPU Vec form at call time. A fused on-device FRI +/// (fold + leaves + tree all staying on device across layers) would flip +/// this but is deferred to the "LDE on GPU across rounds" item. +#[allow(dead_code)] +pub(crate) fn try_build_fri_layer_tree_gpu( + evals: &[FieldElement], +) -> Option> +where + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + let num_evals = evals.len(); + if num_evals < 2 || !num_evals.is_power_of_two() { + return None; + } + let num_leaves = num_evals / 2; + // Higher threshold than the generic LDE path because each FRI layer + // H2Ds a fresh eval slab; tiny layers can't amortise that. + if num_leaves < gpu_fri_tree_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // SAFETY: E == Ext3 whose BaseType is [FieldElement; 3] = + // contiguous [u64; 3] at runtime. + let evals_raw: &[u64] = + unsafe { core::slice::from_raw_parts(evals.as_ptr() as *const u64, num_evals * 3) }; + let nodes_bytes = math_cuda::merkle::build_fri_layer_tree_from_evals_ext3(evals_raw) + .expect("GPU FRI layer tree build failed"); + + let tight_total_nodes = 2 * num_leaves - 1; + debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(tight_total_nodes); + for i in 0..tight_total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Build the R2 composition-polynomial Merkle tree from already-computed +/// LDE parts using the GPU row-pair leaf kernel + pair-hash inner tree. +/// Takes H2D for every call — only worth doing when the tree is large enough +/// that CPU rayon Merkle build exceeds the round-trip cost. +pub(crate) fn try_build_comp_poly_tree_gpu( + lde_parts: &[Vec>], +) -> Option> +where + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if lde_parts.is_empty() { + return None; + } + let lde_size = lde_parts[0].len(); + if !lde_size.is_power_of_two() || lde_size < 2 { + return None; + } + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + // All parts same length. + if lde_parts.iter().any(|p| p.len() != lde_size) { + return None; + } + + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // SAFETY: E == Ext3 whose BaseType is [FieldElement; 3] = + // contiguous [u64; 3] at runtime. + let raw_parts: Vec<&[u64]> = lde_parts + .iter() + .map(|p| unsafe { core::slice::from_raw_parts(p.as_ptr() as *const u64, p.len() * 3) }) + .collect(); + + let nodes_bytes = math_cuda::merkle::build_comp_poly_tree_from_evals_ext3(&raw_parts) + .expect("GPU comp-poly tree build failed"); + + let num_leaves = lde_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(tight_total_nodes); + for i in 0..tight_total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +pub(crate) static GPU_PARTS_LDE_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_parts_lde_calls() -> u64 { + GPU_PARTS_LDE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Combined GPU LDE + Merkle leaf hash for the base-field main trace. +/// +/// Keeps LDE output on device, runs Keccak-256 on the device buffer directly, +/// D2Hs both LDE columns (for Round 2-4 reuse) and hashed leaves (for tree +/// construction). Avoids the second H2D that a separate GPU Merkle commit +/// path would require. +/// +/// On success: resizes each `columns[c]` to `lde_size` with the LDE output, +/// and returns `Vec` — the Keccak-256 hashed leaves in natural +/// row order, ready to pass to `BatchedMerkleTree::build_from_hashed_leaves`. +#[allow(dead_code)] +pub(crate) fn try_expand_and_leaf_hash_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return Some(Vec::new()); + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + // Allocate as Vec<[u8; 32]> directly so we both skip the zero-fill pass + // AND avoid re-chunking afterwards. Fresh pages still fault on first + // write (inside the GPU-side memcpy), but only once each. + let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); + // SAFETY: we fill every byte via memcpy_dtoh below. + unsafe { leaves.set_len(lde_size) }; + let hashed_bytes_ptr = leaves.as_mut_ptr() as *mut u8; + let hashed_bytes: &mut [u8] = + unsafe { std::slice::from_raw_parts_mut(hashed_bytes_ptr, lde_size * 32) }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_base_into_with_leaf_hash( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + .expect("GPU LDE+leaf-hash failed"); + + Some(leaves) +} + +pub(crate) static GPU_LEAF_HASH_CALLS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(0); +pub fn gpu_leaf_hash_calls() -> u64 { + GPU_LEAF_HASH_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// Fused variant: LDE + leaf-hash + Merkle tree build, all on device. Skips +/// the pinned→pageable→pinned leaf dance of the separate-step pipeline. +/// Returns the filled `MerkleTree` alongside populating `columns` with +/// the LDE-expanded evaluations. +#[allow(dead_code)] +pub(crate) fn try_expand_leaf_and_tree_batched( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + // Tree layout needs `2*lde_size - 1` nodes; must be a power-of-two leaf + // count. LDE size is always pow2 here (checked above). + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + // SAFETY: every byte is written by the D2H below. + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU LDE+leaf-hash+tree failed"); + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Same as [`try_expand_leaf_and_tree_batched`] but ALSO retains the LDE +/// device buffer so R2–R4 GPU paths can reuse the LDE without a re-H2D. +/// Returns `(tree, gpu_handle)` on success, `None` if the GPU path doesn't +/// apply (same gates as the non-`_keep` variant). +pub(crate) fn try_expand_leaf_and_tree_batched_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeBase, +)> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + col.iter() + .map(|e| unsafe { *(e.value() as *const _ as *const u64) }) + .collect() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len(); + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + + GPU_LDE_CALLS.fetch_add(columns.len() as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let handle = math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( + &slices, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU LDE+leaf-hash+tree+keep failed"); + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Ext3 variant of [`try_expand_leaf_and_tree_batched`]. Same fused flow +/// (LDE → leaf-hash → tree build) but over ext3 columns via the three-slab +/// decomposition; `B::Node = [u8; 32]` by construction for +/// `BatchKeccak256Backend`. +#[allow(dead_code)] +pub(crate) fn try_expand_leaf_and_tree_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if lde_size < 2 { + return None; + } + + // SAFETY: `E == Degree3Goldilocks`; each `FieldElement` is + // memory-equivalent to `[u64; 3]`. Copy out a Vec view per column. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + + GPU_LDE_CALLS.fetch_add((columns.len() * 3) as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash+tree failed"); + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Same as [`try_expand_leaf_and_tree_batched_ext3`] but also returns the +/// ext3 LDE device buffer (de-interleaved 3-slab layout) so downstream GPU +/// rounds can reuse it. +pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option<( + crypto::merkle_tree::merkle::MerkleTree, + math_cuda::lde::GpuLdeExt3, +)> +where + F: IsField, + E: IsField, + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + if columns.is_empty() { + return None; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if lde_size < 2 { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let total_nodes = 2 * lde_size - 1; + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + unsafe { nodes.set_len(total_nodes) }; + let nodes_bytes: &mut [u8] = unsafe { + core::slice::from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, total_nodes * 32) + }; + + GPU_LDE_CALLS.fetch_add((columns.len() * 3) as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let handle = math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + nodes_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash+tree+keep failed"); + + let tree = crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes)?; + Some((tree, handle)) +} + +/// Ext3 variant of [`try_expand_and_leaf_hash_batched`] for the aux trace. +/// Decomposes each ext3 column into three base slabs, runs the LDE + Keccak +/// ext3 kernel in one on-device pipeline, re-interleaves LDE output back to +/// ext3 layout, and returns hashed leaves. +#[allow(dead_code)] +pub(crate) fn try_expand_and_leaf_hash_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> Option> +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return Some(Vec::new()); + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if columns.iter().any(|c| c.len() != n) { + return None; + } + + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + unsafe { col.set_len(lde_size) }; + } + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + + let mut leaves: Vec<[u8; 32]> = Vec::with_capacity(lde_size); + unsafe { leaves.set_len(lde_size) }; + let hashed_bytes: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(leaves.as_mut_ptr() as *mut u8, lde_size * 32) + }; + + GPU_LDE_CALLS.fetch_add((columns.len() * 3) as u64, std::sync::atomic::Ordering::Relaxed); + GPU_LEAF_HASH_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_ext3_into_with_leaf_hash( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + hashed_bytes, + ) + .expect("GPU ext3 LDE+leaf-hash failed"); + + Some(leaves) +} + +/// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be +/// `Degree3GoldilocksExtensionField` by type_name match at the caller. +fn try_expand_columns_batched_ext3( + columns: &mut [Vec>], + blowup_factor: usize, + weights: &[FieldElement], +) -> bool +where + F: IsField, + E: IsField, +{ + if columns.is_empty() { + return true; + } + let n = columns[0].len(); + let lde_size = n.saturating_mul(blowup_factor); + + // SAFETY: caller confirmed `E == Degree3GoldilocksExtensionField` via + // type_name. That means `FieldElement` wraps `[FieldElement; 3]`, + // which is memory-equivalent to `[u64; 3]`. A `&[FieldElement]` of + // length `n` is therefore a contiguous `3 * n * 8` byte buffer. + let raw_columns: Vec> = columns + .iter() + .map(|col| { + let len = col.len() * 3; + let ptr = col.as_ptr() as *const u64; + // Copy rather than borrow: the caller still owns `col` and will + // reuse its backing storage after we resize + rewrite below. + unsafe { core::slice::from_raw_parts(ptr, len) }.to_vec() + }) + .collect(); + // F is `type_name::() == GoldilocksField` by caller precondition; + // `F::BaseType == u64`, so we can read each `w.value()` as a `*const u64`. + let weights_u64: Vec = weights + .iter() + .map(|w| unsafe { *(w.value() as *const _ as *const u64) }) + .collect(); + + // Pre-size each ext3 column to lde_size so its backing Vec has the right + // length for the output re-interleave. Capacity must already be >= + // lde_size (caller's `extract_columns_main(lde_size)` ensures this). + for col in columns.iter_mut() { + debug_assert!(col.capacity() >= lde_size); + // SAFETY: overwritten fully by the GPU path below. + unsafe { col.set_len(lde_size) }; + } + + // View each column's backing memory as a `&mut [u64]` of length + // `3*lde_size`. Safe because ext3 elements are `[u64; 3]` layouts. + let mut raw_outputs: Vec<&mut [u64]> = columns + .iter_mut() + .map(|col| { + let ptr = col.as_mut_ptr() as *mut u64; + let len = col.len() * 3; + unsafe { core::slice::from_raw_parts_mut(ptr, len) } + }) + .collect(); + + let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); + // Account each ext3 column as 3 logical GPU LDE "calls" (base-field + // components) so the counter matches the base-field batched path. + GPU_LDE_CALLS.fetch_add((columns.len() * 3) as u64, std::sync::atomic::Ordering::Relaxed); + math_cuda::lde::coset_lde_batch_ext3_into( + &slices, + n, + blowup_factor, + &weights_u64, + &mut raw_outputs, + ) + .expect("GPU batched ext3 coset LDE failed"); + true +} + +// ============================================================================ +// GPU barycentric OOD evaluation +// ============================================================================ +// +// Infrastructure for future use: these wrappers drive +// `math_cuda::barycentric::barycentric_{base,ext3}` and apply the trailing ext3 +// scalar on host. See the CPU reference in +// `crypto/math/src/polynomial/mod.rs::interpolate_coset_eval_*_with_g_n_inv`. +// +// NOT currently wired into the prover — a benchmark on fib_iterative_{1M, 4M} +// showed the CPU path (rayon over ~50 columns) already finishes in <1 ms wall +// because the GPU is busy with LDE and Merkle on parallel streams, so moving +// R3 OOD to the GPU just serialises work without freeing CPU wall time. +// Kept here and covered by parity tests in `crypto/math-cuda/tests/barycentric.rs` +// because it remains a net win for single-table or very-large-trace workloads. +// +// The GPU kernel returns the unscaled sum +// S = Σ_i point_i · eval_i · inv_denom_i +// per column; the final barycentric value is +// f(z) = scalar · (z^N − g^N) · S +// with `scalar = n_inv · g_n_inv` kept in the base field. + +static GPU_BARY_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +pub fn gpu_bary_calls() -> u64 { + GPU_BARY_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +static GPU_DEEP_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +pub fn gpu_deep_calls() -> u64 { + GPU_DEEP_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +static GPU_FRI_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +pub fn gpu_fri_calls() -> u64 { + GPU_FRI_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// GPU-resident FRI commit phase. Keeps evals, twiddles, and per-layer +/// trees on device across all folds. Mirrors +/// `commit_phase_from_evaluations` on CPU (transcript interleaving +/// unchanged — each layer's zeta is sampled from the host transcript, +/// each layer's root is D2H'd and appended there). +/// +/// Returns `None` to fall back to CPU (small domain, type mismatch, etc.). +#[allow(clippy::type_complexity)] +pub(crate) fn try_fri_commit_gpu( + number_layers: usize, + evals: &[FieldElement], + transcript: &mut impl crypto::fiat_shamir::is_transcript::IsStarkTranscript, + coset_offset: &FieldElement, + domain_size: usize, +) -> Option<( + FieldElement, + Vec>>, +)> +where + F: math::field::traits::IsFFTField + IsSubFieldOf, + E: IsField, + FieldElement: math::traits::AsBytes + Sync + Send, + FieldElement: math::traits::AsBytes + Sync + Send, +{ + use math::fft::cpu::bit_reversing::in_place_bit_reverse_permute; + use math::fft::cpu::roots_of_unity::get_powers_of_primitive_root_coset; + + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + if !domain_size.is_power_of_two() || domain_size < gpu_lde_threshold() { + return None; + } + if evals.len() != domain_size || number_layers < 1 { + return None; + } + if domain_size < (1 << 3) { + return None; + } + + GPU_FRI_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Compute initial inv_twiddles on host — same recipe as + // `compute_coset_twiddles_inv`. + let half = domain_size / 2; + let order = domain_size.trailing_zeros() as u64; + let mut points = get_powers_of_primitive_root_coset(order, half, coset_offset) + .expect("coset twiddles available"); + in_place_bit_reverse_permute(&mut points); + FieldElement::inplace_batch_inverse(&mut points).expect("twiddle inverse"); + + // Raw u64 views: E == Ext3 (3 u64) for evals, F == Gl (1 u64) for twiddles. + let evals_raw: &[u64] = + unsafe { core::slice::from_raw_parts(evals.as_ptr() as *const u64, domain_size * 3) }; + let tw_raw: &[u64] = + unsafe { core::slice::from_raw_parts(points.as_ptr() as *const u64, half) }; + + let mut state = math_cuda::fri::FriCommitState::new(evals_raw, tw_raw, domain_size) + .expect("FRI state alloc"); + + let mut fri_layer_list = + Vec::>>::with_capacity(number_layers); + let mut current_coset_offset = coset_offset.clone(); + let mut current_domain_size = domain_size; + + for _ in 1..number_layers { + let zeta: FieldElement = transcript.sample_field_element(); + current_coset_offset = current_coset_offset.square(); + current_domain_size /= 2; + + // SAFETY: E == Ext3 (layout [u64; 3]). + let zeta_raw: [u64; 3] = unsafe { + let p = &zeta as *const FieldElement as *const u64; + [*p, *p.add(1), *p.add(2)] + }; + + let (root_bytes, layer_evals_raw, nodes_bytes) = + state.fold_and_commit_layer(zeta_raw).expect("FRI fold+commit"); + + let mut root_arr = [0u8; 32]; + root_arr.copy_from_slice(&root_bytes[..32]); + + // Re-chunk tree nodes into Vec<[u8; 32]> for MerkleTree. + let num_leaves = current_domain_size / 2; + let tight_total_nodes = 2 * num_leaves - 1; + debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(tight_total_nodes); + for i in 0..tight_total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + let merkle_tree = + crypto::merkle_tree::merkle::MerkleTree::>::from_precomputed_nodes(nodes) + .expect("FRI MerkleTree build"); + + // Rebuild the layer's ext3 evals from raw u64s. + debug_assert_eq!(layer_evals_raw.len(), 3 * current_domain_size); + let mut layer_evals: Vec> = Vec::with_capacity(current_domain_size); + unsafe { layer_evals.set_len(current_domain_size) }; + unsafe { + core::ptr::copy_nonoverlapping( + layer_evals_raw.as_ptr(), + layer_evals.as_mut_ptr() as *mut u64, + current_domain_size * 3, + ); + } + + fri_layer_list.push(crate::fri::fri_commitment::FriLayer::new( + &layer_evals, + merkle_tree, + current_coset_offset.clone().to_extension(), + current_domain_size, + )); + + transcript.append_bytes(&root_arr); + } + + // Final fold. + let zeta: FieldElement = transcript.sample_field_element(); + let zeta_raw: [u64; 3] = unsafe { + let p = &zeta as *const FieldElement as *const u64; + [*p, *p.add(1), *p.add(2)] + }; + let last_raw = state.fold_final(zeta_raw).expect("FRI final fold"); + + // SAFETY: E == Ext3; build FieldElement from raw u64s. + let last_value: FieldElement = unsafe { + let mut e: FieldElement = core::mem::zeroed(); + let ptr = &mut e as *mut FieldElement as *mut u64; + *ptr = last_raw[0]; + *ptr.add(1) = last_raw[1]; + *ptr.add(2) = last_raw[2]; + e + }; + + transcript.append_field_element(&last_value); + + Some((last_value, fri_layer_list)) +} + +/// R3 OOD barycentric over the **main** (base-field) LDE read directly from +/// the device handle with stride `row_stride = blowup_factor`. Applies the +/// same trailing `scalar * vanishing * sum` ext3 scale on host that +/// `interpolate_coset_eval_with_g_n_inv` does. +#[allow(clippy::too_many_arguments)] +pub(crate) fn try_barycentric_base_on_handle( + lde_trace: &crate::trace::LDETraceTable, + row_stride: usize, + coset_points: &[FieldElement], + coset_offset_pow_n: &FieldElement, + n_inv: &FieldElement, + g_n_inv: &FieldElement, + z_pow_n: &FieldElement, + inv_denoms: &[FieldElement], +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + let main = lde_trace.gpu_main()?; + let num_cols = main.m; + if num_cols == 0 { + return Some(Vec::new()); + } + let n = coset_points.len(); + if !n.is_power_of_two() || n < gpu_bary_threshold() { + return None; + } + if inv_denoms.len() != n || main.lde_size != n * row_stride { + return None; + } + + GPU_BARY_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let points_raw: &[u64] = + unsafe { core::slice::from_raw_parts(coset_points.as_ptr() as *const u64, n) }; + let inv_denoms_raw: &[u64] = + unsafe { core::slice::from_raw_parts(inv_denoms.as_ptr() as *const u64, 3 * n) }; + + let sums_raw = math_cuda::barycentric::barycentric_base_on_device( + main, + row_stride, + points_raw, + inv_denoms_raw, + n, + ) + .expect("GPU barycentric_base_on_device failed"); + + let scalar = ood_ext3_scalar::(coset_offset_pow_n, n_inv, g_n_inv, z_pow_n); + Some(apply_ext3_scalar::(&sums_raw, scalar, num_cols)) +} + +/// Ext3 counterpart reading the aux LDE handle. +#[allow(clippy::too_many_arguments)] +pub(crate) fn try_barycentric_ext3_on_handle( + lde_trace: &crate::trace::LDETraceTable, + row_stride: usize, + coset_points: &[FieldElement], + coset_offset_pow_n: &FieldElement, + n_inv: &FieldElement, + g_n_inv: &FieldElement, + z_pow_n: &FieldElement, + inv_denoms: &[FieldElement], +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + let aux = lde_trace.gpu_aux()?; + let num_cols = aux.m; + if num_cols == 0 { + return Some(Vec::new()); + } + let n = coset_points.len(); + if !n.is_power_of_two() || n < gpu_bary_threshold() { + return None; + } + if inv_denoms.len() != n || aux.lde_size != n * row_stride { + return None; + } + + GPU_BARY_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let points_raw: &[u64] = + unsafe { core::slice::from_raw_parts(coset_points.as_ptr() as *const u64, n) }; + let inv_denoms_raw: &[u64] = + unsafe { core::slice::from_raw_parts(inv_denoms.as_ptr() as *const u64, 3 * n) }; + + let sums_raw = math_cuda::barycentric::barycentric_ext3_on_device( + aux, + row_stride, + points_raw, + inv_denoms_raw, + n, + ) + .expect("GPU barycentric_ext3_on_device failed"); + + let scalar = ood_ext3_scalar::(coset_offset_pow_n, n_inv, g_n_inv, z_pow_n); + Some(apply_ext3_scalar::(&sums_raw, scalar, num_cols)) +} + +/// GPU path for `compute_deep_composition_poly_evaluations`. Returns the N +/// trace-size coset evaluations of the deep-composition polynomial as a +/// `Vec>` (same type as the CPU path), or `None` when the +/// GPU is skipped (small tables, handle absent, type mismatch). +/// +/// Reads the main/aux LDE from the device handles stored on the +/// `LDETraceTable` by R1, avoiding a re-H2D of the largest tensor in R4. +/// Composition-parts LDE + scalar arrays are still H2D'd fresh each call. +#[allow(clippy::too_many_arguments)] +pub(crate) fn try_deep_composition_gpu( + lde_trace: &crate::trace::LDETraceTable, + h_lde_parts: &[Vec>], + h_parts_gpu: Option<&math_cuda::lde::GpuLdeExt3>, + h_ood: &[FieldElement], + trace_ood_cols: &[Vec>], // num_total_cols × num_eval_points + gammas_h: &[FieldElement], + gammas_tr_flat: &[Vec>], // num_total_cols × num_eval_points + inv_h: &[FieldElement], + inv_t: &[Vec>], // num_eval_points × domain_size + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + + let main_handle = lde_trace.gpu_main()?.clone(); + let aux_handle_opt = lde_trace.gpu_aux().cloned(); + let num_main = main_handle.m; + let lde_size = main_handle.lde_size; + if lde_size < gpu_lde_threshold() { + return None; + } + let num_aux = aux_handle_opt.as_ref().map(|a| a.m).unwrap_or(0); + let num_parts = h_lde_parts.len(); + let num_total_cols = num_main + num_aux; + + if h_lde_parts.iter().any(|p| p.len() != lde_size) { + return None; + } + + GPU_DEEP_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // If a device handle is present for h_parts, skip the host-side pack. + // Falls back to packing Vec> → flat u64 and H2D'ing in the + // impl otherwise. + let h_flat_opt: Option> = if h_parts_gpu.is_some() { + None + } else { + let mut h_flat = vec![0u64; num_parts * 3 * lde_size]; + #[cfg(feature = "parallel")] + let iter = h_lde_parts.par_iter().enumerate(); + #[cfg(not(feature = "parallel"))] + let iter = h_lde_parts.iter().enumerate(); + let ptr = h_flat.as_mut_ptr() as usize; + iter.for_each(|(p, col)| { + // SAFETY: E == Ext3; FieldElement is [u64; 3] at runtime. + let src = unsafe { core::slice::from_raw_parts(col.as_ptr() as *const u64, lde_size * 3) }; + unsafe { + let base = ptr as *mut u64; + let slab0 = base.add((p * 3) * lde_size); + let slab1 = base.add((p * 3 + 1) * lde_size); + let slab2 = base.add((p * 3 + 2) * lde_size); + for r in 0..lde_size { + *slab0.add(r) = src[r * 3]; + *slab1.add(r) = src[r * 3 + 1]; + *slab2.add(r) = src[r * 3 + 2]; + } + } + }); + Some(h_flat) + }; + + // Pack scalar arrays: h_ood, trace_ood, gammas_h, gammas_tr, inv_h, inv_t. + let e3_raw = |e: &FieldElement| -> [u64; 3] { + // SAFETY: E == Ext3; memory layout [u64; 3]. + unsafe { + let p = e as *const FieldElement as *const u64; + [*p, *p.add(1), *p.add(2)] + } + }; + + let mut h_ood_flat = vec![0u64; num_parts * 3]; + for (j, e) in h_ood.iter().enumerate() { + let v = e3_raw(e); + h_ood_flat[j * 3] = v[0]; + h_ood_flat[j * 3 + 1] = v[1]; + h_ood_flat[j * 3 + 2] = v[2]; + } + assert_eq!(trace_ood_cols.len(), num_total_cols); + let mut trace_ood_flat = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in trace_ood_cols.iter().enumerate() { + debug_assert_eq!(col.len(), num_eval_points); + for (k, e) in col.iter().enumerate() { + let v = e3_raw(e); + let idx = (j * num_eval_points + k) * 3; + trace_ood_flat[idx] = v[0]; + trace_ood_flat[idx + 1] = v[1]; + trace_ood_flat[idx + 2] = v[2]; + } + } + let mut gammas_h_flat = vec![0u64; num_parts * 3]; + for (j, e) in gammas_h.iter().enumerate() { + let v = e3_raw(e); + gammas_h_flat[j * 3] = v[0]; + gammas_h_flat[j * 3 + 1] = v[1]; + gammas_h_flat[j * 3 + 2] = v[2]; + } + assert_eq!(gammas_tr_flat.len(), num_total_cols); + let mut gammas_tr_out = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in gammas_tr_flat.iter().enumerate() { + debug_assert_eq!(col.len(), num_eval_points); + for (k, e) in col.iter().enumerate() { + let v = e3_raw(e); + let idx = (j * num_eval_points + k) * 3; + gammas_tr_out[idx] = v[0]; + gammas_tr_out[idx + 1] = v[1]; + gammas_tr_out[idx + 2] = v[2]; + } + } + // SAFETY: E == Ext3; each FieldElement is `[u64; 3]`. Cast the + // contiguous Vec> layer to a `&[u64]` and memcpy once, + // instead of a per-element u64 copy loop. + let inv_h_flat: Vec = unsafe { + core::slice::from_raw_parts(inv_h.as_ptr() as *const u64, inv_h.len() * 3) + } + .to_vec(); + assert_eq!(inv_t.len(), num_eval_points); + let mut inv_t_flat: Vec = Vec::with_capacity(num_eval_points * domain_size * 3); + unsafe { inv_t_flat.set_len(num_eval_points * domain_size * 3) }; + { + let dst_ptr = inv_t_flat.as_mut_ptr() as usize; + #[cfg(feature = "parallel")] + let iter = (0..num_eval_points).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_eval_points; + iter.for_each(|k| { + let layer = &inv_t[k]; + let src = unsafe { + core::slice::from_raw_parts(layer.as_ptr() as *const u64, domain_size * 3) + }; + unsafe { + let dst = (dst_ptr as *mut u64).add(k * domain_size * 3); + core::ptr::copy_nonoverlapping(src.as_ptr(), dst, domain_size * 3); + } + }); + } + + let raw_out = if let Some(h_gpu) = h_parts_gpu { + math_cuda::deep::deep_composition_ext3_with_dev_parts( + &main_handle, + aux_handle_opt.as_ref(), + h_gpu, + &h_ood_flat, + &trace_ood_flat, + &gammas_h_flat, + &gammas_tr_out, + &inv_h_flat, + &inv_t_flat, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) + .expect("GPU deep composition (dev parts) failed") + } else { + math_cuda::deep::deep_composition_ext3( + &main_handle, + aux_handle_opt.as_ref(), + h_flat_opt.as_ref().expect("host h_flat packed").as_slice(), + &h_ood_flat, + &trace_ood_flat, + &gammas_h_flat, + &gammas_tr_out, + &inv_h_flat, + &inv_t_flat, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) + .expect("GPU deep composition failed") + }; + + // Transmute raw u64s → FieldElement. Requires E == Ext3 layout, which + // the type_name check above verifies. + let mut out: Vec> = Vec::with_capacity(domain_size); + unsafe { out.set_len(domain_size) }; + let dst_ptr = out.as_mut_ptr() as *mut u64; + unsafe { + core::ptr::copy_nonoverlapping(raw_out.as_ptr(), dst_ptr, domain_size * 3); + } + Some(out) +} + +// ============================================================================ +// GPU Merkle inner-tree construction +// ============================================================================ +// +// After the GPU keccak leaf-hash kernels produce a flat `[u8; 32]` leaf vec, +// the inner tree construction on CPU via `build_from_hashed_leaves` is a +// rayon-parallel pair-hash scan that still takes ~50-100 ms per table on a +// 46-core host. Delegating it to `math_cuda::merkle::build_merkle_tree_on_device` +// pushes it below 10 ms — the leaf buffer is already on host (it came out of +// `try_expand_and_leaf_hash_batched`), we H2D it once, the GPU does ~log₂(N) +// small kernel launches, and we D2H the full `2*leaves_len - 1` node array. + +static GPU_MERKLE_TREE_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +pub fn gpu_merkle_tree_calls() -> u64 { + GPU_MERKLE_TREE_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +/// FRI layers shrink by 2× each round; the last few layers are tiny. Below +/// this leaf count, keep the tree build on CPU. +#[allow(dead_code)] +const DEFAULT_GPU_FRI_TREE_THRESHOLD: usize = 1 << 19; + +#[allow(dead_code)] +fn gpu_fri_tree_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_FRI_TREE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_FRI_TREE_THRESHOLD) + }) +} + +/// Build a Merkle tree from already-hashed leaves using the GPU pair-hash +/// kernel. Returns the filled `MerkleTree` in the same layout as the CPU +/// `build_from_hashed_leaves` would produce — plug straight in anywhere the +/// prover expected that. +/// +/// Returns `None` if the GPU path is disabled by threshold (`leaves_len < +/// GPU_MERKLE_TREE_THRESHOLD`), falling back to the caller's CPU path. +/// +/// Currently unwired in the prover: benchmarking showed the savings from +/// the GPU pair-hash are eaten by the H2D of leaves + D2H of the tree +/// because the leaves are in pageable memory (they're the caller's Vec from +/// `try_expand_and_leaf_hash_batched`). A proper fusion would keep the +/// leaf buffer on device and run the tree kernel immediately on the GPU +/// copy — left as future work. +#[allow(dead_code)] +pub(crate) fn try_build_merkle_tree_gpu( + hashed_leaves: &[B::Node], +) -> Option> +where + B: crypto::merkle_tree::traits::IsMerkleTreeBackend, +{ + let leaves_len = hashed_leaves.len(); + if leaves_len < gpu_merkle_tree_threshold() || !leaves_len.is_power_of_two() || leaves_len < 2 { + return None; + } + GPU_MERKLE_TREE_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Flatten host-side leaves into a contiguous byte buffer for the GPU + // kernel. SAFETY: `[u8; 32]` is POD and the slice is contiguous. + let leaves_bytes: &[u8] = unsafe { + core::slice::from_raw_parts(hashed_leaves.as_ptr() as *const u8, leaves_len * 32) + }; + let nodes_bytes = math_cuda::merkle::build_merkle_tree_on_device(leaves_bytes) + .expect("GPU merkle tree build failed"); + + let total_nodes = 2 * leaves_len - 1; + debug_assert_eq!(nodes_bytes.len(), total_nodes * 32); + + // Re-chunk into `Vec<[u8; 32]>` without re-allocating. We'd need an + // explicit copy because Vec and Vec<[u8; 32]> have different + // layouts in the allocator metadata (align differs on some platforms). + let mut nodes: Vec<[u8; 32]> = Vec::with_capacity(total_nodes); + for i in 0..total_nodes { + let mut n = [0u8; 32]; + n.copy_from_slice(&nodes_bytes[i * 32..(i + 1) * 32]); + nodes.push(n); + } + + crypto::merkle_tree::merkle::MerkleTree::::from_precomputed_nodes(nodes) +} + +/// Below this (tree size), stay on CPU — rayon pair-hash is already well +/// under a millisecond for small N and would lose to any PCIe round-trip. +const DEFAULT_GPU_MERKLE_TREE_THRESHOLD: usize = 1 << 15; + +fn gpu_merkle_tree_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_MERKLE_TREE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_MERKLE_TREE_THRESHOLD) + }) +} + +/// Below this (trace-size) barycentric length we stay on CPU — the rayon path +/// already completes in well under a millisecond and PCIe round-trip would +/// dominate. +#[allow(dead_code)] +const DEFAULT_GPU_BARY_THRESHOLD: usize = 1 << 14; + +#[allow(dead_code)] +fn gpu_bary_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_BARY_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_BARY_THRESHOLD) + }) +} + +/// One ext3 scalar `(n_inv · g_n_inv) · (z^N − g^N)`; caller reads as `[u64;3]`. +#[allow(dead_code)] +fn ood_ext3_scalar( + coset_offset_pow_n: &FieldElement, + n_inv: &FieldElement, + g_n_inv: &FieldElement, + z_pow_n: &FieldElement, +) -> [u64; 3] +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + // (z^N − g^N) in E — done via sub_subfield (E − F → E). + let vanishing = z_pow_n.sub_subfield(coset_offset_pow_n); + let base_scalar = n_inv * g_n_inv; // F × F → F + let scalar_ext3: FieldElement = &base_scalar * &vanishing; // F × E → E + // SAFETY: E == Degree3Goldilocks; backing is `[FieldElement; 3]` + // which is memory-equivalent to `[u64; 3]`. + let ptr = &scalar_ext3 as *const FieldElement as *const u64; + unsafe { [*ptr, *ptr.add(1), *ptr.add(2)] } +} + +/// Multiply each raw GPU ext3 sum by the host-computed ext3 scalar. +/// `sums_raw` is `3 * num_cols` u64s (interleaved). +#[allow(dead_code)] +fn apply_ext3_scalar( + sums_raw: &[u64], + scalar: [u64; 3], + num_cols: usize, +) -> Vec> +where + E: IsField, +{ + use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; + use math::field::goldilocks::GoldilocksField; + type Gl = GoldilocksField; + type Ext3 = Degree3GoldilocksExtensionField; + + debug_assert_eq!(sums_raw.len(), 3 * num_cols); + debug_assert_eq!(type_name::(), type_name::()); + + let scalar_e: FieldElement = FieldElement::::new([ + FieldElement::::from_raw(scalar[0]), + FieldElement::::from_raw(scalar[1]), + FieldElement::::from_raw(scalar[2]), + ]); + + let mut out: Vec> = Vec::with_capacity(num_cols); + for c in 0..num_cols { + let s: FieldElement = FieldElement::::new([ + FieldElement::::from_raw(sums_raw[c * 3]), + FieldElement::::from_raw(sums_raw[c * 3 + 1]), + FieldElement::::from_raw(sums_raw[c * 3 + 2]), + ]); + let final_ext3 = &s * &scalar_e; + // SAFETY: E == Ext3 at runtime; same layout. + let final_e: FieldElement = unsafe { + core::mem::transmute_copy::, FieldElement>(&final_ext3) + }; + out.push(final_e); + } + out +} + +/// Batched barycentric OOD evaluation over M base-field columns at a single +/// ext3 evaluation point. Returns `Some(vec_of_M_ext3)` on GPU dispatch, or +/// `None` if the caller should fall back to CPU. +#[allow(dead_code)] +pub(crate) fn try_barycentric_base_ood_gpu( + columns: &[Vec>], + coset_points: &[FieldElement], + coset_offset_pow_n: &FieldElement, + n_inv: &FieldElement, + g_n_inv: &FieldElement, + z_pow_n: &FieldElement, + inv_denoms: &[FieldElement], +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + let num_cols = columns.len(); + if num_cols == 0 { + return Some(Vec::new()); + } + let n = columns[0].len(); + if !n.is_power_of_two() || n < gpu_bary_threshold() { + return None; + } + if coset_points.len() != n || inv_denoms.len() != n { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + // All columns must share the same length `n`. + for c in columns.iter() { + if c.len() != n { + return None; + } + } + + GPU_BARY_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Pack columns contiguously: column c at offset c*n. Skip the zero-fill + // prologue — we overwrite every byte below. `set_len` before write is + // safe because `u64` has no drop glue. + let total = num_cols * n; + let mut columns_flat: Vec = Vec::with_capacity(total); + unsafe { columns_flat.set_len(total) }; + { + // Parallel pack: each column's slab is independent. + let flat_ptr = columns_flat.as_mut_ptr() as usize; + #[cfg(feature = "parallel")] + let iter = (0..num_cols).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_cols; + iter.for_each(|c| { + // SAFETY: disjoint slabs; no two `c`s overlap. F == Goldilocks. + unsafe { + let dst = (flat_ptr as *mut u64).add(c * n); + let src = columns[c].as_ptr() as *const u64; + core::ptr::copy_nonoverlapping(src, dst, n); + } + }); + } + let points_raw: &[u64] = + unsafe { core::slice::from_raw_parts(coset_points.as_ptr() as *const u64, n) }; + let inv_denoms_raw: &[u64] = + unsafe { core::slice::from_raw_parts(inv_denoms.as_ptr() as *const u64, 3 * n) }; + + let sums_raw = math_cuda::barycentric::barycentric_base( + &columns_flat, + n, + points_raw, + inv_denoms_raw, + n, + num_cols, + ) + .expect("GPU barycentric_base failed"); + + let scalar = ood_ext3_scalar::(coset_offset_pow_n, n_inv, g_n_inv, z_pow_n); + Some(apply_ext3_scalar::(&sums_raw, scalar, num_cols)) +} + +/// Batched barycentric OOD evaluation over M ext3 columns at a single ext3 +/// evaluation point. Same contract as [`try_barycentric_base_ood_gpu`]. +#[allow(dead_code)] +pub(crate) fn try_barycentric_ext3_ood_gpu( + columns: &[Vec>], + coset_points: &[FieldElement], + coset_offset_pow_n: &FieldElement, + n_inv: &FieldElement, + g_n_inv: &FieldElement, + z_pow_n: &FieldElement, + inv_denoms: &[FieldElement], +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + let num_cols = columns.len(); + if num_cols == 0 { + return Some(Vec::new()); + } + let n = columns[0].len(); + if !n.is_power_of_two() || n < gpu_bary_threshold() { + return None; + } + if coset_points.len() != n || inv_denoms.len() != n { + return None; + } + if type_name::() != type_name::() { + return None; + } + if type_name::() != type_name::() { + return None; + } + for c in columns.iter() { + if c.len() != n { + return None; + } + } + + GPU_BARY_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // De-interleaved layout: slab (c*3 + k) at offset (c*3+k)*n. Skip + // zero-fill (we overwrite every byte). Parallelise the de-interleave. + let total = num_cols * 3 * n; + let mut columns_flat: Vec = Vec::with_capacity(total); + unsafe { columns_flat.set_len(total) }; + { + let flat_ptr = columns_flat.as_mut_ptr() as usize; + #[cfg(feature = "parallel")] + let iter = (0..num_cols).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = 0..num_cols; + iter.for_each(|c| { + // SAFETY: E == Ext3 whose BaseType is [FieldElement;3] = + // contiguous [u64;3] at runtime; disjoint per-c slabs. + unsafe { + let src = columns[c].as_ptr() as *const u64; + let base = flat_ptr as *mut u64; + let slab0 = base.add((c * 3) * n); + let slab1 = base.add((c * 3 + 1) * n); + let slab2 = base.add((c * 3 + 2) * n); + for r in 0..n { + *slab0.add(r) = *src.add(r * 3); + *slab1.add(r) = *src.add(r * 3 + 1); + *slab2.add(r) = *src.add(r * 3 + 2); + } + } + }); + } + let points_raw: &[u64] = + unsafe { core::slice::from_raw_parts(coset_points.as_ptr() as *const u64, n) }; + let inv_denoms_raw: &[u64] = + unsafe { core::slice::from_raw_parts(inv_denoms.as_ptr() as *const u64, 3 * n) }; + + let sums_raw = math_cuda::barycentric::barycentric_ext3( + &columns_flat, + n, + points_raw, + inv_denoms_raw, + n, + num_cols, + ) + .expect("GPU barycentric_ext3 failed"); + + let scalar = ood_ext3_scalar::(coset_offset_pow_n, n_inv, g_n_inv, z_pow_n); + Some(apply_ext3_scalar::(&sums_raw, scalar, num_cols)) +} diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index 09ca16ed4..39005dc1c 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -8,10 +8,14 @@ pub mod domain; pub mod examples; pub mod frame; pub mod fri; +#[cfg(feature = "cuda")] +pub mod gpu_lde; pub mod grinding; #[cfg(feature = "instruments")] pub mod instruments; pub mod lookup; +#[cfg(feature = "cuda")] +pub mod logup_gpu; pub mod proof; pub mod prover; pub mod table; diff --git a/crypto/stark/src/logup_gpu.rs b/crypto/stark/src/logup_gpu.rs new file mode 100644 index 000000000..403cacaa6 --- /dev/null +++ b/crypto/stark/src/logup_gpu.rs @@ -0,0 +1,567 @@ +//! GPU dispatch for `compute_logup_batched_term_column`: takes a pair of +//! `BusInteraction`s and evaluates the full fingerprint + batch-invert + +//! term-assembly pipeline on the device. +//! +//! The serializer lives here (on the stark side) so it can see the +//! `BusValue` / `Multiplicity` types without creating a math ↔ stark +//! dependency cycle. The matching kernels are in +//! `crypto/math-cuda/kernels/logup.cu`. +//! +//! Only Goldilocks main trace + Fp3 extension is supported — everything +//! else returns `None` and the caller runs the CPU path. +//! +//! Canonicalization contract: all coefficients the GPU sees are already +//! canonical Goldilocks field elements in `[0, p)`. The kernels do not +//! sign-handle; they treat every `value` as a plain u64 coefficient. + +use core::any::type_name; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsField, IsSubFieldOf}; +use math_cuda::logup::{ + self, FingerprintOp, LinearTerm as GpuLinearTerm, MultiplicityDesc, +}; + +use crate::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; + +/// Goldilocks modulus p = 2^64 - 2^32 + 1. Used for the i64/u64 → canonical +/// u64 reduction so the kernel never sees a non-canonical coefficient. +const GOLDILOCKS_P: u64 = 0xFFFF_FFFF_0000_0001; + +#[inline(always)] +fn canonical_u64(v: u64) -> u64 { + if v >= GOLDILOCKS_P { v - GOLDILOCKS_P } else { v } +} + +#[inline(always)] +fn canonical_i64(v: i64) -> u64 { + if v >= 0 { + canonical_u64(v as u64) + } else { + // -|v| mod p = p - (|v| mod p). |v| fits in u64 since i64 only + // reaches -2^63 < 2^64, and (-v) as u64 handles that cleanly. + let abs = v.unsigned_abs(); + let c = canonical_u64(abs); + if c == 0 { 0 } else { GOLDILOCKS_P - c } + } +} + +/// Serializer output — one call's full bytecode bundle. Owned so the host +/// wrapper can upload contiguous slices without extra copies. +pub struct PairBytecode { + pub ops_a: Vec, + pub ops_b: Vec, + pub linear_terms: Vec, + pub mult_a: MultiplicityDesc, + pub mult_b: MultiplicityDesc, + pub bus_id_a: u64, + pub bus_id_b: u64, + pub negate_a: bool, + pub negate_b: bool, + pub max_bus_elements: usize, +} + +/// Translate one interaction's `values` into a list of `FingerprintOp`s, +/// appending any referenced LinearTerms to `pool`. `alpha_offset_start` +/// should be 1 (slot 0 is reserved for bus_id * alpha[0]). +fn encode_bus_values( + values: &[BusValue], + alpha_offset_start: usize, + pool: &mut Vec, +) -> Vec { + let mut ops = Vec::with_capacity(values.len()); + let mut alpha_offset = alpha_offset_start as u32; + for bv in values { + match bv { + BusValue::Packed { start_column, packing } => { + let kind = packing_to_op_kind(*packing); + let consumed = packing.num_bus_elements() as u32; + ops.push(FingerprintOp { + kind, + pad0: [0; 3], + alpha_offset, + start_col: *start_column as u32, + num_linear_terms: 0, + linear_term_offset: 0, + pad1: [0; 2], + }); + alpha_offset += consumed; + } + BusValue::Linear(terms) => { + let offset = pool.len() as u32; + for t in terms { + pool.push(lower_linear_term(t)); + } + ops.push(FingerprintOp { + kind: logup::OP_LINEAR, + pad0: [0; 3], + alpha_offset, + start_col: 0, + num_linear_terms: terms.len() as u32, + linear_term_offset: offset, + pad1: [0; 2], + }); + alpha_offset += 1; + } + } + } + ops +} + +fn packing_to_op_kind(p: Packing) -> u8 { + match p { + Packing::Direct => logup::OP_PACK_DIRECT, + Packing::Word2L => logup::OP_PACK_WORD2L, + Packing::Word4L => logup::OP_PACK_WORD4L, + Packing::DWordWL => logup::OP_PACK_DWORDWL, + Packing::DWordHHW => logup::OP_PACK_DWORDHHW, + Packing::DWordWHH => logup::OP_PACK_DWORDWHH, + Packing::DWordHL => logup::OP_PACK_DWORDHL, + Packing::DWordBL => logup::OP_PACK_DWORDBL, + Packing::QuadHL => logup::OP_PACK_QUADHL, + Packing::QuadWL => logup::OP_PACK_QUADWL, + } +} + +fn lower_linear_term(t: &LinearTerm) -> GpuLinearTerm { + match *t { + LinearTerm::Column { coefficient, column } => GpuLinearTerm { + kind: logup::LT_KIND_COLUMN, + pad: [0; 3], + column: column as u32, + value: canonical_i64(coefficient), + }, + LinearTerm::ColumnUnsigned { coefficient, column } => GpuLinearTerm { + kind: logup::LT_KIND_COLUMN, + pad: [0; 3], + column: column as u32, + value: canonical_u64(coefficient), + }, + LinearTerm::Constant(value) => GpuLinearTerm { + kind: logup::LT_KIND_CONSTANT, + pad: [0; 3], + column: 0, + value: canonical_i64(value), + }, + } +} + +fn encode_multiplicity( + m: &Multiplicity, + pool: &mut Vec, +) -> MultiplicityDesc { + match m { + Multiplicity::One => MultiplicityDesc { + kind: logup::MULT_ONE, + ..Default::default() + }, + Multiplicity::Column(c) => MultiplicityDesc { + kind: logup::MULT_COLUMN, + cols: [*c as u32, 0, 0], + ..Default::default() + }, + Multiplicity::Sum(a, b) => MultiplicityDesc { + kind: logup::MULT_SUM, + cols: [*a as u32, *b as u32, 0], + ..Default::default() + }, + Multiplicity::Negated(c) => MultiplicityDesc { + kind: logup::MULT_NEGATED, + cols: [*c as u32, 0, 0], + ..Default::default() + }, + Multiplicity::Diff(a, b) => MultiplicityDesc { + kind: logup::MULT_DIFF, + cols: [*a as u32, *b as u32, 0], + ..Default::default() + }, + Multiplicity::Sum3(a, b, c) => MultiplicityDesc { + kind: logup::MULT_SUM3, + cols: [*a as u32, *b as u32, *c as u32], + ..Default::default() + }, + Multiplicity::Linear(terms) => { + let offset = pool.len() as u32; + for t in terms { + pool.push(lower_linear_term(t)); + } + MultiplicityDesc { + kind: logup::MULT_LINEAR, + cols: [0; 3], + num_linear_terms: terms.len() as u32, + linear_term_offset: offset, + ..Default::default() + } + } + } +} + +/// Serialize a pair of interactions into the shared bytecode form. +pub fn build_pair_bytecode( + interaction_a: &BusInteraction, + interaction_b: &BusInteraction, +) -> PairBytecode { + let mut linear_terms: Vec = Vec::new(); + let ops_a = encode_bus_values(&interaction_a.values, 1, &mut linear_terms); + let ops_b = encode_bus_values(&interaction_b.values, 1, &mut linear_terms); + let mult_a = encode_multiplicity(&interaction_a.multiplicity, &mut linear_terms); + let mult_b = encode_multiplicity(&interaction_b.multiplicity, &mut linear_terms); + let max_bus_elements = interaction_a + .num_bus_elements() + .max(interaction_b.num_bus_elements()); + PairBytecode { + ops_a, + ops_b, + linear_terms, + mult_a, + mult_b, + bus_id_a: interaction_a.bus_id, + bus_id_b: interaction_b.bus_id, + negate_a: !interaction_a.is_sender, + negate_b: !interaction_b.is_sender, + max_bus_elements, + } +} + +/// Flatten `main_segment_cols` into column-major u64. SAFETY: the caller +/// must have verified that `F == GoldilocksField` so that each +/// `FieldElement` is representationally a single u64. +unsafe fn flatten_main_cols(main_segment_cols: &[Vec>]) -> Vec +where + F: IsField, +{ + if main_segment_cols.is_empty() { + return Vec::new(); + } + let n = main_segment_cols[0].len(); + let num_cols = main_segment_cols.len(); + let mut out = Vec::with_capacity(num_cols * n); + for col in main_segment_cols { + debug_assert_eq!(col.len(), n); + for e in col { + out.push(unsafe { *(e.value() as *const _ as *const u64) }); + } + } + out +} + +/// Convert a Fp3 FieldElement to its raw `[u64; 3]` ext3 triple. The kernel +/// tolerates non-canonical inputs (it partial-reduces), so we skip the +/// extra canonicalization step and read raw u64 bits. +/// SAFETY: the caller must have verified `E == Degree3GoldilocksExtensionField`. +unsafe fn ext3_to_triple(e: &FieldElement) -> [u64; 3] { + let ptr = e.value() as *const _ as *const [FieldElement; 3]; + let triple = unsafe { &*ptr }; + [ + unsafe { *(triple[0].value() as *const _ as *const u64) }, + unsafe { *(triple[1].value() as *const _ as *const u64) }, + unsafe { *(triple[2].value() as *const _ as *const u64) }, + ] +} + +/// Per-pair GPU-vs-CPU threshold on `trace_len`. Below this, the per-pair +/// overhead (main-cols H2D + kernel launches + 2n×3 D2H) dominates and the +/// rayon-parallel CPU path wins. Set conservatively; override via env var +/// for experiments. +const DEFAULT_GPU_LOGUP_THRESHOLD: usize = usize::MAX; + +fn gpu_logup_threshold() -> usize { + static CACHED: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_LOGUP_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_LOGUP_THRESHOLD) + }) +} + +static GPU_LOGUP_CALLS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); + +/// Serializes GPU dispatch across the rayon-parallel table loop in Pass 1. +/// The prover runs `build_auxiliary_trace` concurrently for every table, +/// and without this lock ~12 streams compete for the GPU: H2D of each +/// table's 240 MB main_cols saturates PCIe, and kernel launches fight +/// for SM time. Serializing with a mutex trades a bit of CPU idle for +/// a clean GPU pipeline. +/// +/// Only the GPU-bound portion is under the lock — flatten_main_cols +/// (host copy) and the final triples → FieldElement reassembly both +/// run outside, so rayon still gets to overlap CPU work across tables. +static GPU_LOGUP_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +pub fn gpu_logup_calls() -> u64 { + GPU_LOGUP_CALLS.load(std::sync::atomic::Ordering::Relaxed) +} + +pub fn reset_gpu_logup_calls() { + GPU_LOGUP_CALLS.store(0, std::sync::atomic::Ordering::Relaxed); +} + +#[inline] +fn gpu_supported(trace_len: usize) -> bool { + if type_name::() != type_name::() { + return false; + } + if type_name::() != type_name::() { + return false; + } + trace_len >= gpu_logup_threshold() +} + +/// Batch-compute all committed pair term columns (and optionally the +/// absorbed virtual pair) on GPU for one table. Uploads main_cols exactly +/// once per table — this is the win vs. per-pair dispatch. +/// +/// Returns `None` if the F/E type combination isn't supported; caller +/// falls back to the rayon CPU path entirely. +pub fn try_compute_table_term_columns( + interactions: &[BusInteraction], + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Option> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + if !gpu_supported::(trace_len) { + return None; + } + + let (num_committed_pairs, absorbed_count) = + crate::lookup::split_interactions(interactions.len()); + + // ── CPU prep (no lock held — many tables can prep in parallel) ── + + let alpha = &challenges[crate::lookup::LOGUP_CHALLENGE_ALPHA]; + let z = &challenges[0]; + let z_triple = unsafe { ext3_to_triple(z) }; + let main_cols_u64 = unsafe { flatten_main_cols(main_segment_cols) }; + + enum PreppedCall { + Pair(PairBytecode, Vec), + Single(SingleBytecode, Vec), + } + let mut prepped: Vec = Vec::with_capacity(num_committed_pairs + 1); + + for i in 0..num_committed_pairs { + let a = &interactions[i * 2]; + let b = &interactions[i * 2 + 1]; + let bytecode = build_pair_bytecode(a, b); + let alpha_powers_u64 = alpha_powers_u64_vec(alpha, bytecode.max_bus_elements); + prepped.push(PreppedCall::Pair(bytecode, alpha_powers_u64)); + } + + match absorbed_count { + 0 => {} + 2 => { + let a = &interactions[interactions.len() - 2]; + let b = &interactions[interactions.len() - 1]; + let bytecode = build_pair_bytecode(a, b); + let alpha_powers_u64 = alpha_powers_u64_vec(alpha, bytecode.max_bus_elements); + prepped.push(PreppedCall::Pair(bytecode, alpha_powers_u64)); + } + 1 => { + let a = &interactions[interactions.len() - 1]; + let mut pool: Vec = Vec::new(); + let ops = encode_bus_values(&a.values, 1, &mut pool); + let mult = encode_multiplicity(&a.multiplicity, &mut pool); + let max_bus_elements = a.num_bus_elements(); + let alpha_powers_u64 = alpha_powers_u64_vec(alpha, max_bus_elements); + let bytecode = SingleBytecode { + ops, + linear_terms: pool, + mult, + bus_id: a.bus_id, + negate: !a.is_sender, + }; + prepped.push(PreppedCall::Single(bytecode, alpha_powers_u64)); + } + _ => unreachable!("absorbed_count must be 0, 1, or 2"), + }; + + // ── GPU dispatch (lock held — tables run serially on device) ── + let raw_results: Vec> = { + let _guard = GPU_LOGUP_LOCK + .lock() + .expect("GPU LogUp lock poisoned"); + + let device_main = math_cuda::logup::upload_main_cols( + &main_cols_u64, + main_segment_cols.len(), + trace_len, + ) + .ok()?; + + let mut results: Vec> = Vec::with_capacity(prepped.len()); + for call in &prepped { + match call { + PreppedCall::Pair(bc, alpha) => { + let r = math_cuda::logup::logup_pair_term_column_on_device( + &device_main, + bc.bus_id_a, + bc.bus_id_b, + &bc.ops_a, + &bc.ops_b, + &bc.linear_terms, + alpha, + &z_triple, + &bc.mult_a, + &bc.mult_b, + bc.negate_a, + bc.negate_b, + ) + .ok()?; + GPU_LOGUP_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + results.push(r); + } + PreppedCall::Single(bc, alpha) => { + let r = math_cuda::logup::logup_single_term_column_on_device( + &device_main, + bc.bus_id, + &bc.ops, + &bc.linear_terms, + alpha, + &z_triple, + &bc.mult, + bc.negate, + ) + .ok()?; + GPU_LOGUP_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + results.push(r); + } + } + } + results + // _guard drops here; device_main drops here (releases VRAM). + }; + + // ── CPU post (lock released — FieldElement reassembly parallelizable) ── + let mut committed = Vec::with_capacity(num_committed_pairs); + for r in raw_results.iter().take(num_committed_pairs) { + committed.push(triples_to_ext3_fieldelements::(r, trace_len)); + } + + let virtual_col = match absorbed_count { + 0 => None, + 1 | 2 => Some(triples_to_ext3_fieldelements::( + &raw_results[num_committed_pairs], + trace_len, + )), + _ => unreachable!(), + }; + + Some(TableTermColumns { + committed, + virtual_col, + }) +} + +/// Helper: single-interaction bytecode bundle used by the 1-absorbed branch. +struct SingleBytecode { + ops: Vec, + linear_terms: Vec, + mult: MultiplicityDesc, + bus_id: u64, + negate: bool, +} + +fn alpha_powers_u64_vec( + alpha: &FieldElement, + max_bus_elements: usize, +) -> Vec { + let fe = crate::lookup::compute_alpha_powers(alpha, max_bus_elements); + let mut out = Vec::with_capacity(max_bus_elements * 3); + for ap in &fe { + let t = unsafe { ext3_to_triple(ap) }; + out.extend_from_slice(&t); + } + out +} + +pub struct TableTermColumns { + pub committed: Vec>>, + pub virtual_col: Option>>, +} + +/// Try to run the pair on the GPU. Returns `Some(term_column)` on success +/// (3 * trace_len u64s flattened into FieldElement) or `None` if the +/// type combination isn't supported — in which case the caller falls +/// back to the CPU path. +pub fn try_compute_pair_term_column( + interaction_a: &BusInteraction, + interaction_b: &BusInteraction, + main_segment_cols: &[Vec>], + trace_len: usize, + challenges: &[FieldElement], +) -> Option>> +where + F: IsField + IsSubFieldOf, + E: IsField, +{ + if !gpu_supported::(trace_len) { + return None; + } + + // Compute alpha_powers (ext3 extension). Fallback on CPU side (cheap, + // runs once per pair, O(max_bus_elements) multiplications). + let alpha = &challenges[crate::lookup::LOGUP_CHALLENGE_ALPHA]; + let z = &challenges[0]; + + let bytecode = build_pair_bytecode(interaction_a, interaction_b); + let alpha_powers_fe = crate::lookup::compute_alpha_powers(alpha, bytecode.max_bus_elements); + + // Extract u64 views. + let main_cols_u64 = unsafe { flatten_main_cols(main_segment_cols) }; + let mut alpha_powers_u64 = Vec::with_capacity(bytecode.max_bus_elements * 3); + for ap in &alpha_powers_fe { + let t = unsafe { ext3_to_triple(ap) }; + alpha_powers_u64.extend_from_slice(&t); + } + let z_triple = unsafe { ext3_to_triple(z) }; + + let result = logup::logup_pair_term_column( + &main_cols_u64, + main_segment_cols.len(), + trace_len, + bytecode.bus_id_a, + bytecode.bus_id_b, + &bytecode.ops_a, + &bytecode.ops_b, + &bytecode.linear_terms, + &alpha_powers_u64, + &z_triple, + &bytecode.mult_a, + &bytecode.mult_b, + bytecode.negate_a, + bytecode.negate_b, + ) + .ok()?; + + GPU_LOGUP_CALLS.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Some(triples_to_ext3_fieldelements::(&result, trace_len)) +} + +/// Reassemble `trace_len` ext3 triples back into `FieldElement`. +/// SAFETY: caller must have verified E == Degree3GoldilocksExtensionField. +fn triples_to_ext3_fieldelements( + data: &[u64], + trace_len: usize, +) -> Vec> { + assert_eq!(data.len(), trace_len * 3); + let mut out = Vec::with_capacity(trace_len); + for i in 0..trace_len { + let triple: [FieldElement; 3] = [ + FieldElement::::from(data[i * 3]), + FieldElement::::from(data[i * 3 + 1]), + FieldElement::::from(data[i * 3 + 2]), + ]; + // SAFETY: type_name check at the entry point guarantees E is + // Degree3GoldilocksExtensionField, whose BaseType = [FpE; 3]. + let raw: ::BaseType = unsafe { core::mem::transmute_copy(&triple) }; + out.push(FieldElement::::from_raw(raw)); + } + out +} diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 17ba7c5ec..e0e171f7e 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -112,7 +112,7 @@ const LOGUP_CHUNK_SIZE: usize = 1024; /// Returns `(num_committed_pairs, absorbed_count)` where: /// - Committed pairs get dedicated auxiliary term columns (2 interactions per column) /// - Absorbed interactions (1 or 2) are folded into the accumulated constraint -fn split_interactions(num_interactions: usize) -> (usize, usize) { +pub(crate) fn split_interactions(num_interactions: usize) -> (usize, usize) { if num_interactions <= 2 { (0, num_interactions) } else if num_interactions % 2 == 1 { @@ -1040,30 +1040,44 @@ where // Split interactions: committed pairs get term columns, last 1-2 are absorbed (virtual) let (num_committed_pairs, absorbed_count) = split_interactions(num_interactions); - // Compute committed term columns (batched pairs only). - // With `parallel`: when `trace_len > LOGUP_CHUNK_SIZE` the chunk-internal - // parallelism inside each pair already saturates Rayon, so iterate pairs - // sequentially to keep cache locality. When `trace_len <= LOGUP_CHUNK_SIZE` - // each pair yields a single chunk, so parallelize across pairs to recover - // the throughput the per-pair dispatch used to provide for small-trace - // tables with many interactions. - // Without `parallel`: sequential over pairs, sequential over rows. - #[cfg(feature = "parallel")] - let committed_columns: Vec>> = if trace_len <= LOGUP_CHUNK_SIZE { - (0..num_committed_pairs) - .into_par_iter() - .map(|i| { - compute_logup_batched_term_column( - &self.auxiliary_trace_build_data.interactions[i * 2], - &self.auxiliary_trace_build_data.interactions[i * 2 + 1], - &main_segment_cols, - trace_len, - challenges, - ) - }) - .collect() - } else { - (0..num_committed_pairs) + let compute_cpu = || { + // Compute committed term columns (batched pairs only). + // With `parallel`: when `trace_len > LOGUP_CHUNK_SIZE` the chunk-internal + // parallelism inside each pair already saturates Rayon, so iterate pairs + // sequentially to keep cache locality. When `trace_len <= LOGUP_CHUNK_SIZE` + // each pair yields a single chunk, so parallelize across pairs to recover + // the throughput the per-pair dispatch used to provide for small-trace + // tables with many interactions. + // Without `parallel`: sequential over pairs, sequential over rows. + #[cfg(feature = "parallel")] + let committed_columns: Vec>> = if trace_len <= LOGUP_CHUNK_SIZE { + (0..num_committed_pairs) + .into_par_iter() + .map(|i| { + compute_logup_batched_term_column( + &self.auxiliary_trace_build_data.interactions[i * 2], + &self.auxiliary_trace_build_data.interactions[i * 2 + 1], + &main_segment_cols, + trace_len, + challenges, + ) + }) + .collect() + } else { + (0..num_committed_pairs) + .map(|i| { + compute_logup_batched_term_column( + &self.auxiliary_trace_build_data.interactions[i * 2], + &self.auxiliary_trace_build_data.interactions[i * 2 + 1], + &main_segment_cols, + trace_len, + challenges, + ) + }) + .collect() + }; + #[cfg(not(feature = "parallel"))] + let committed_columns: Vec>> = (0..num_committed_pairs) .map(|i| { compute_logup_batched_term_column( &self.auxiliary_trace_build_data.interactions[i * 2], @@ -1073,39 +1087,50 @@ where challenges, ) }) - .collect() - }; - #[cfg(not(feature = "parallel"))] - let committed_columns: Vec>> = (0..num_committed_pairs) - .map(|i| { + .collect(); + + // Compute virtual column for absorbed interactions (NOT written to trace) + let virtual_column = if absorbed_count == 2 { compute_logup_batched_term_column( - &self.auxiliary_trace_build_data.interactions[i * 2], - &self.auxiliary_trace_build_data.interactions[i * 2 + 1], + &self.auxiliary_trace_build_data.interactions[num_interactions - 2], + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], &main_segment_cols, trace_len, challenges, ) - }) - .collect(); + } else { + compute_logup_term_column( + &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + &main_segment_cols, + trace_len, + challenges, + _table_name, + ) + }; + (committed_columns, virtual_column) + }; - // Compute virtual column for absorbed interactions (NOT written to trace) - let virtual_column = if absorbed_count == 2 { - compute_logup_batched_term_column( - &self.auxiliary_trace_build_data.interactions[num_interactions - 2], - &self.auxiliary_trace_build_data.interactions[num_interactions - 1], - &main_segment_cols, - trace_len, - challenges, - ) - } else { - compute_logup_term_column( - &self.auxiliary_trace_build_data.interactions[num_interactions - 1], + // CUDA fast path: upload main_cols once and run every pair against + // the shared device buffer. Skipped when F/E don't match or the + // trace is below the GPU threshold, in which case we fall through + // to the rayon-parallel CPU path. + #[cfg(feature = "cuda")] + let (committed_columns, virtual_column) = + match crate::logup_gpu::try_compute_table_term_columns( + &self.auxiliary_trace_build_data.interactions, &main_segment_cols, trace_len, challenges, - _table_name, - ) - }; + ) { + Some(gpu) => ( + gpu.committed, + gpu.virtual_col + .expect("GPU path always produces a virtual column"), + ), + None => compute_cpu(), + }; + #[cfg(not(feature = "cuda"))] + let (committed_columns, virtual_column) = compute_cpu(); // Write only committed columns to trace for (col_idx, col_data) in committed_columns.iter().enumerate() { @@ -1566,6 +1591,19 @@ where F: IsFFTField + IsSubFieldOf + IsPrimeField + Send + Sync, E: IsField + Send + Sync, { + #[cfg(feature = "cuda")] + { + if let Some(out) = crate::logup_gpu::try_compute_pair_term_column( + interaction_a, + interaction_b, + main_segment_cols, + trace_len, + challenges, + ) { + return out; + } + } + let z = &challenges[0]; let alpha = &challenges[LOGUP_CHALLENGE_ALPHA]; let max_bus_elements = interaction_a diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index a5386017a..c4d79916f 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -193,6 +193,30 @@ where struct Lde { main: Vec>>, aux: Vec>>, + /// Device-side main LDE buffer, populated only when the R1 GPU fused + /// pipeline ran for this table. Kept so R2/R3/R4 GPU paths can read + /// the LDE without re-H2D. + #[cfg(feature = "cuda")] + gpu_main: Option, + #[cfg(feature = "cuda")] + gpu_aux: Option, +} + +/// Result of `commit_main_trace` / `commit_preprocessed_trace`. Wraps the +/// commitment Merkle data plus the owned LDE columns, and — when the R1 +/// fused GPU pipeline ran — the retained device LDE handle. +pub struct MainTraceCommitResult +where + FieldElement: AsBytes, +{ + tree: BatchedMerkleTree, + root: Commitment, + precomputed_tree: Option>, + precomputed_root: Option, + num_precomputed_cols: usize, + columns: Vec>>, + #[cfg(feature = "cuda")] + gpu_main: Option, } impl Round1Commitments @@ -210,7 +234,18 @@ where blowup_factor: usize, has_aux_trace: bool, ) -> Round1 { - let lde_trace = LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor); + #[allow(unused_mut)] + let mut lde_trace = + LDETraceTable::from_columns(lde.main, lde.aux, step_size, blowup_factor); + #[cfg(feature = "cuda")] + { + if let Some(h) = lde.gpu_main { + lde_trace.set_gpu_main(h); + } + if let Some(h) = lde.gpu_aux { + lde_trace.set_gpu_aux(h); + } + } let main = Round1CommitmentData:: { lde_trace_merkle_tree: Arc::clone(&self.main_merkle_tree), @@ -327,6 +362,11 @@ where pub(crate) composition_poly_merkle_tree: BatchedMerkleTree, /// The commitment to the composition polynomial parts. pub(crate) composition_poly_root: Commitment, + /// Device-side composition-poly LDE handle, retained when the R2 GPU + /// fused path produced the LDE. Lets R2 commit + R4 DEEP composition + /// skip re-H2D'ing the composition parts. + #[cfg(feature = "cuda")] + pub(crate) gpu_composition_parts: Option, } /// A container for the results of the third round of the STARK Prove protocol. @@ -517,6 +557,19 @@ pub trait IsStarkProver< return; } + // GPU batched fast path: all columns at once in one pipeline on one + // stream. Falls through to per-column rayon when the table is too + // small, the element type isn't Goldilocks, or the `cuda` feature is + // off. + #[cfg(feature = "cuda")] + if crate::gpu_lde::try_expand_columns_batched::( + columns, + domain.blowup_factor, + &twiddles.coset_weights, + ) { + return; + } + #[cfg(feature = "parallel")] let iter = columns.par_iter_mut(); #[cfg(not(feature = "parallel"))] @@ -534,29 +587,52 @@ pub trait IsStarkProver< } /// Compute main LDE, commit, and return the Merkle tree/root along with the - /// owned LDE columns (consumed later in Phase D). + /// owned LDE columns (consumed later in Phase D). When the fused GPU + /// pipeline runs, the device LDE buffer is also kept alive and returned so + /// downstream rounds can read it without a re-H2D. #[allow(clippy::type_complexity)] fn commit_main_trace( trace: &TraceTable, domain: &Domain, twiddles: &LdeTwiddles, - ) -> Result< - ( - BatchedMerkleTree, - Commitment, - Option>, - Option, - usize, - Vec>>, - ), - ProvingError, - > + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_main(lde_size); + + #[cfg(feature = "cuda")] + { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_keep::< + Field, + Field, + BatchedMerkleTreeBackend, + >(&mut columns, domain.blowup_factor, &twiddles.coset_weights) + { + #[cfg(feature = "instruments")] + let main_lde_dur = t_sub.elapsed(); + #[cfg(feature = "instruments")] + let zero = std::time::Duration::from_secs(0); + let root = tree.root; + #[cfg(feature = "instruments")] + crate::instruments::accum_r1_main(main_lde_dur, zero); + return Ok(MainTraceCommitResult { + tree, + root, + precomputed_tree: None, + precomputed_root: None, + num_precomputed_cols: 0, + columns, + gpu_main: Some(handle), + }); + } + } + #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::(&mut columns, domain, twiddles); @@ -570,7 +646,16 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_main(main_lde_dur, t_sub.elapsed()); - Ok((tree, root, None, None, 0, columns)) + Ok(MainTraceCommitResult { + tree, + root, + precomputed_tree: None, + precomputed_root: None, + num_precomputed_cols: 0, + columns, + #[cfg(feature = "cuda")] + gpu_main: None, + }) } /// Commit preprocessed trace: precomputed and multiplicity columns get separate trees. @@ -581,17 +666,7 @@ pub trait IsStarkProver< precomputed_commitment: Commitment, num_precomputed_cols: usize, twiddles: &LdeTwiddles, - ) -> Result< - ( - BatchedMerkleTree, - Commitment, - Option>, - Option, - usize, - Vec>>, - ), - ProvingError, - > + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, @@ -621,14 +696,16 @@ pub trait IsStarkProver< "Prover's precomputed commitment doesn't match hardcoded AIR commitment" ); - Ok(( - mult_tree, - mult_root, - Some(precomputed_tree), - Some(precomputed_root), + Ok(MainTraceCommitResult { + tree: mult_tree, + root: mult_root, + precomputed_tree: Some(precomputed_tree), + precomputed_root: Some(precomputed_root), num_precomputed_cols, columns, - )) + #[cfg(feature = "cuda")] + gpu_main: None, + }) } /// Recompute Round1 from the trace, reusing the Merkle trees stored in commitments. @@ -841,6 +918,18 @@ pub trait IsStarkProver< // The squared coset offset is g² (= coset_offset²). let coset_offset_squared = &domain.coset_offset * &domain.coset_offset; + // GPU fast path: batch both halves into one ext3 LDE call. Requires + // `cuda` feature and a qualifying size; falls through to CPU when not. + #[cfg(feature = "cuda")] + if let Some((lde_h0, lde_h1)) = crate::gpu_lde::try_extend_two_halves_gpu( + &h0_evals, + &h1_evals, + &coset_offset_squared, + domain, + ) { + return vec![lde_h0, lde_h1]; + } + #[cfg(feature = "parallel")] let (lde_h0, lde_h1) = rayon::join( || Self::extend_half_to_lde(&h0_evals, &coset_offset_squared, domain), @@ -920,6 +1009,8 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let t_sub = Instant::now(); + #[cfg(feature = "cuda")] + let mut gpu_comp_handle: Option = None; let lde_composition_poly_parts_evaluations = if number_of_parts == 2 { // Direct quotient decomposition: avoid full-size iFFT by algebraically // splitting H(x) = H₀(x²) + x·H₁(x²) using: @@ -936,28 +1027,72 @@ pub trait IsStarkProver< Polynomial::interpolate_offset_fft(&constraint_evaluations, &domain.coset_offset) .unwrap(); let composition_poly_parts = composition_poly.break_in_parts(number_of_parts); - composition_poly_parts - .iter() - .map(|part| { - evaluate_polynomial_on_lde_domain( - part, - domain.blowup_factor, - domain.interpolation_domain_size, - &domain.coset_offset, - ) - .unwrap() - }) - .collect() + + // GPU fast path: batch all parts' LDEs into a single call AND + // retain the device buffer so R2 commit + R4 DEEP composition + // can read it without re-H2D'ing. Falls through to CPU when + // `cuda` is off or the size is below the GPU threshold. + #[cfg(feature = "cuda")] + let gpu_result = { + let parts_slices: Vec<&[FieldElement]> = + composition_poly_parts + .iter() + .map(|p| p.coefficients.as_slice()) + .collect(); + crate::gpu_lde::try_evaluate_parts_on_lde_gpu_keep::( + &parts_slices, + domain.blowup_factor, + domain.interpolation_domain_size, + &domain.coset_offset, + ) + }; + #[cfg(not(feature = "cuda"))] + let gpu_result: Option<(Vec>>, ())> = None; + + if let Some((results, handle)) = gpu_result { + #[cfg(feature = "cuda")] + { + gpu_comp_handle = Some(handle); + } + #[cfg(not(feature = "cuda"))] + let _ = handle; + results + } else { + composition_poly_parts + .iter() + .map(|part| { + evaluate_polynomial_on_lde_domain( + part, + domain.blowup_factor, + domain.interpolation_domain_size, + &domain.coset_offset, + ) + .unwrap() + }) + .collect() + } }; #[cfg(feature = "instruments")] let fft_dur = t_sub.elapsed(); #[cfg(feature = "instruments")] let t_sub = Instant::now(); - let Some((composition_poly_merkle_tree, composition_poly_root)) = - Self::commit_composition_polynomial(&lde_composition_poly_parts_evaluations) - else { - return Err(ProvingError::EmptyCommitment); + #[cfg(feature = "cuda")] + let gpu_tree = crate::gpu_lde::try_build_comp_poly_tree_gpu::< + FieldExtension, + BatchedMerkleTreeBackend, + >(&lde_composition_poly_parts_evaluations); + #[cfg(not(feature = "cuda"))] + let gpu_tree: Option> = None; + + let (composition_poly_merkle_tree, composition_poly_root) = if let Some(tree) = gpu_tree { + let root = tree.root; + (tree, root) + } else { + match Self::commit_composition_polynomial(&lde_composition_poly_parts_evaluations) { + Some(pair) => pair, + None => return Err(ProvingError::EmptyCommitment), + } }; #[cfg(feature = "instruments")] let merkle_dur = t_sub.elapsed(); @@ -969,6 +1104,8 @@ pub trait IsStarkProver< lde_composition_poly_evaluations: lde_composition_poly_parts_evaluations, composition_poly_merkle_tree, composition_poly_root, + #[cfg(feature = "cuda")] + gpu_composition_parts: gpu_comp_handle, }) } @@ -1626,6 +1763,9 @@ pub trait IsStarkProver< let mut main_commits: Vec> = Vec::with_capacity(num_airs); let mut main_ldes: Vec>>> = Vec::with_capacity(num_airs); + #[cfg(feature = "cuda")] + let mut main_gpu_handles: Vec> = + Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1658,19 +1798,21 @@ pub trait IsStarkProver< // Sequential: append roots to shared transcript (Fiat-Shamir ordering) for result in chunk_results { - let (tree, root, pre_tree, pre_root, n_pre, cached_main) = result?; - if let Some(ref pre_r) = pre_root { + let r = result?; + if let Some(ref pre_r) = r.precomputed_root { transcript.append_bytes(pre_r); } - transcript.append_bytes(&root); + transcript.append_bytes(&r.root); main_commits.push(MainCommitData { - main_tree: Arc::new(tree), - main_root: root, - precomputed_tree: pre_tree.map(Arc::new), - precomputed_root: pre_root, - num_precomputed_cols: n_pre, + main_tree: Arc::new(r.tree), + main_root: r.root, + precomputed_tree: r.precomputed_tree.map(Arc::new), + precomputed_root: r.precomputed_root, + num_precomputed_cols: r.num_precomputed_cols, }); - main_ldes.push(cached_main); + main_ldes.push(r.columns); + #[cfg(feature = "cuda")] + main_gpu_handles.push(r.gpu_main); } } @@ -1747,13 +1889,24 @@ pub trait IsStarkProver< }) .collect(); - // Parallel aux commit in chunks of K - #[allow(clippy::type_complexity)] - let mut aux_results: Vec<( - Option>>, + // Parallel aux commit in chunks of K. Fourth field is an optional + // GPU ext3 LDE handle retained when the R1 fused pipeline fires. + #[cfg(feature = "cuda")] + type AuxResult = ( + Option>>, Option, - Vec>>, - )> = Vec::with_capacity(num_airs); + Vec>>, + Option, + ); + #[cfg(not(feature = "cuda"))] + type AuxResult = ( + Option>>, + Option, + Vec>>, + (), + ); + #[allow(clippy::type_complexity)] + let mut aux_results: Vec> = Vec::with_capacity(num_airs); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1773,6 +1926,42 @@ pub trait IsStarkProver< if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_aux(lde_size); + + // GPU combined path: ext3 LDE + Keccak-256 leaf + // hashing + Merkle tree build in one on-device + // pipeline. The fused `_keep` variant also returns + // the device LDE handle for downstream GPU rounds. + #[cfg(feature = "cuda")] + { + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + if let Some((tree, handle)) = + crate::gpu_lde::try_expand_leaf_and_tree_batched_ext3_keep::< + Field, + FieldExtension, + BatchedMerkleTreeBackend, + >( + &mut columns, + domain.blowup_factor, + &twiddles.coset_weights, + ) + { + #[cfg(feature = "instruments")] + let aux_lde_dur = t_sub.elapsed(); + #[cfg(feature = "instruments")] + let zero = std::time::Duration::from_secs(0); + let root = tree.root; + #[cfg(feature = "instruments")] + crate::instruments::accum_r1_aux(aux_lde_dur, zero); + return Ok(( + Some(Arc::new(tree)), + Some(root), + columns, + Some(handle), + )); + } + } + #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::( @@ -1789,20 +1978,28 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] crate::instruments::accum_r1_aux(aux_lde_dur, t_sub.elapsed()); - Ok((Some(Arc::new(tree)), Some(root), columns)) + #[cfg(feature = "cuda")] + let aux_gpu: Option = None; + #[cfg(not(feature = "cuda"))] + let aux_gpu: () = (); + Ok((Some(Arc::new(tree)), Some(root), columns, aux_gpu)) } else { - Ok((None, None, Vec::new())) + #[cfg(feature = "cuda")] + let aux_gpu: Option = None; + #[cfg(not(feature = "cuda"))] + let aux_gpu: () = (); + Ok((None, None, Vec::new(), aux_gpu)) } }) .collect(); // Sequential: append aux roots to forked transcripts for (j, result) in chunk_aux.into_iter().enumerate() { - let (aux_tree, aux_root, cached_aux) = result?; + let (aux_tree, aux_root, cached_aux, aux_gpu) = result?; if let Some(ref root) = aux_root { table_transcripts[chunk_start + j].append_bytes(root); } - aux_results.push((aux_tree, aux_root, cached_aux)); + aux_results.push((aux_tree, aux_root, cached_aux, aux_gpu)); } } @@ -1811,12 +2008,25 @@ pub trait IsStarkProver< let mut commitments: Vec> = Vec::with_capacity(num_airs); let mut cached_ldes: Vec> = Vec::with_capacity(num_airs); - for (((main_commit, main_lde), (aux_tree, aux_root, cached_aux)), bus_public_inputs) in - main_commits - .into_iter() - .zip(main_ldes) - .zip(aux_results) - .zip(bus_inputs_vec) + // Zip in the optional GPU handles so the Lde constructor always + // has a value for its gpu_main/gpu_aux. Under `cfg(not(cuda))` the + // handles are `()` (see AuxResult type alias) — we just discard them. + #[cfg(feature = "cuda")] + let main_gpu_iter: Box>> = + Box::new(main_gpu_handles.into_iter()); + #[cfg(not(feature = "cuda"))] + let main_gpu_iter: Box> = + Box::new(std::iter::repeat_with(|| ()).take(num_airs)); + + for ( + (((main_commit, main_lde), main_gpu_h), (aux_tree, aux_root, cached_aux, aux_gpu_h)), + bus_public_inputs, + ) in main_commits + .into_iter() + .zip(main_ldes) + .zip(main_gpu_iter) + .zip(aux_results) + .zip(bus_inputs_vec) { commitments.push(Round1Commitments { main_merkle_tree: main_commit.main_tree, @@ -1829,10 +2039,22 @@ pub trait IsStarkProver< rap_challenges: lookup_challenges.clone(), bus_public_inputs, }); + #[cfg(feature = "cuda")] cached_ldes.push(Lde { main: main_lde, aux: cached_aux, + gpu_main: main_gpu_h, + gpu_aux: aux_gpu_h, }); + #[cfg(not(feature = "cuda"))] + { + let _ = main_gpu_h; + let _ = aux_gpu_h; + cached_ldes.push(Lde { + main: main_lde, + aux: cached_aux, + }); + } } #[cfg(feature = "instruments")] diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index ef6ee7833..f47aa1346 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -193,6 +193,16 @@ where pub(crate) aux_columns: Vec>>, pub(crate) lde_step_size: usize, pub(crate) blowup_factor: usize, + /// If the main trace was LDE'd on the GPU via the fused pipeline, + /// the device buffer is retained here so downstream GPU rounds can + /// read the LDE without a re-H2D. `None` when the GPU LDE didn't + /// run (small tables, cuda feature off, fallback path). + #[cfg(feature = "cuda")] + pub(crate) gpu_main: Option, + /// Same as `gpu_main` but for the aux trace (ext3 de-interleaved + /// layout on device). + #[cfg(feature = "cuda")] + pub(crate) gpu_aux: Option, } impl LDETraceTable @@ -215,9 +225,37 @@ where aux_columns, lde_step_size, blowup_factor, + #[cfg(feature = "cuda")] + gpu_main: None, + #[cfg(feature = "cuda")] + gpu_aux: None, } } + /// Attach an already-populated device LDE handle for the main columns. + /// Only set when the GPU fused pipeline produced the LDE — callers that + /// ran the CPU path should leave this alone. + #[cfg(feature = "cuda")] + pub fn set_gpu_main(&mut self, h: math_cuda::lde::GpuLdeBase) { + self.gpu_main = Some(h); + } + + /// Attach an already-populated device LDE handle for the aux columns. + #[cfg(feature = "cuda")] + pub fn set_gpu_aux(&mut self, h: math_cuda::lde::GpuLdeExt3) { + self.gpu_aux = Some(h); + } + + #[cfg(feature = "cuda")] + pub fn gpu_main(&self) -> Option<&math_cuda::lde::GpuLdeBase> { + self.gpu_main.as_ref() + } + + #[cfg(feature = "cuda")] + pub fn gpu_aux(&self) -> Option<&math_cuda::lde::GpuLdeExt3> { + self.gpu_aux.as_ref() + } + /// Consume self and return the owned column vectors. #[allow(clippy::type_complexity)] pub fn into_columns(self) -> (Vec>>, Vec>>) { @@ -406,58 +444,114 @@ where let vanishing = z_pow_n.sub_subfield(&dc.offset_pow_n); let vanishing_factor = &n_inv_g_n_inv * &vanishing; - // Precompute inv_denoms = 1/(eval_point - coset_point_i) — shared across all columns + // Precompute inv_denoms = 1/(eval_point - coset_point_i) — shared across all columns. + // Stays on CPU: batch-invert cost at this scale (n × num_eval_points ≈ 3 × 2^18 per + // table) is already rayon-parallelised across 7 tables, and a GPU port regressed + // wall time in a 2×15-trial A/B due to stream contention from 21 concurrent launches. let inv_denoms = barycentric_inv_denoms(eval_point, &dc.points); - // Precompute col_scale[i] = point[i] * inv_denom[i] — shared across ALL columns. - // This eliminates N redundant F×E multiplies per column. - let col_scale: Vec> = dc - .points - .iter() - .zip(inv_denoms.iter()) - .map(|(point, inv_d)| point * inv_d) - .collect(); - - // Evaluate all main columns directly from LDE (no extraction copy). - // For main columns (base field F): sum = Σ col_scale[i] * lde_col[i*bf] - // lde_col[i*bf] is F, col_scale[i] is E; use F×E → E mixed arithmetic. - #[cfg(feature = "parallel")] - let main_iter = (0..num_main_cols).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let main_iter = 0..num_main_cols; - let main_evals: Vec> = main_iter - .map(|col_idx| { - let lde_col = &lde_trace.main_columns[col_idx]; - let sum = col_scale - .iter() - .enumerate() - .fold(FieldElement::::zero(), |acc, (i, scale)| { - acc + &lde_col[i * bf] * scale - }); - &vanishing_factor * &sum - }) - .collect(); + // GPU fast path: batched strided barycentric over the main-trace + // LDE already on device. Falls through if the GPU LDE handles + // aren't populated (small tables, cuda feature off, or the CPU + // path filled the LDE). + #[cfg(feature = "cuda")] + let main_gpu = crate::gpu_lde::try_barycentric_base_on_handle::( + lde_trace, + bf, + &dc.points, + &dc.offset_pow_n, + &dc.size_inv, + &dc.offset_pow_n_inv, + &z_pow_n, + &inv_denoms, + ); + #[cfg(not(feature = "cuda"))] + let main_gpu: Option>> = None; + + let main_evals: Vec> = if let Some(v) = main_gpu { + v + } else { + // Precompute col_scale[i] = point[i] * inv_denom[i] — shared across ALL columns. + // This eliminates N redundant F×E multiplies per column. + let col_scale: Vec> = dc + .points + .iter() + .zip(inv_denoms.iter()) + .map(|(point, inv_d)| point * inv_d) + .collect(); + + // Evaluate all main columns directly from LDE (no extraction copy). + // For main columns (base field F): sum = Σ col_scale[i] * lde_col[i*bf] + // lde_col[i*bf] is F, col_scale[i] is E; use F×E → E mixed arithmetic. + #[cfg(feature = "parallel")] + let main_iter = (0..num_main_cols).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let main_iter = 0..num_main_cols; + let main_evals: Vec> = main_iter + .map(|col_idx| { + let lde_col = &lde_trace.main_columns[col_idx]; + let sum = + col_scale + .iter() + .enumerate() + .fold(FieldElement::::zero(), |acc, (i, scale)| { + acc + &lde_col[i * bf] * scale + }); + &vanishing_factor * &sum + }) + .collect(); + main_evals + }; table_data.extend(main_evals); - // Evaluate all aux columns directly from LDE (no extraction copy). - // For aux columns (extension field E): sum = Σ col_scale[i] * lde_col[i*bf] - // Both col_scale and lde_col are in E, so each multiply is E×E → E. - #[cfg(feature = "parallel")] - let aux_iter = (0..num_aux_cols).into_par_iter(); - #[cfg(not(feature = "parallel"))] - let aux_iter = 0..num_aux_cols; - let aux_evals: Vec> = aux_iter - .map(|col_idx| { - let lde_col = &lde_trace.aux_columns[col_idx]; - let sum = col_scale - .iter() - .enumerate() - .fold(FieldElement::::zero(), |acc, (i, scale)| { - acc + scale * &lde_col[i * bf] - }); - &vanishing_factor * &sum - }) - .collect(); + // GPU fast path for aux columns. + #[cfg(feature = "cuda")] + let aux_gpu = crate::gpu_lde::try_barycentric_ext3_on_handle::( + lde_trace, + bf, + &dc.points, + &dc.offset_pow_n, + &dc.size_inv, + &dc.offset_pow_n_inv, + &z_pow_n, + &inv_denoms, + ); + #[cfg(not(feature = "cuda"))] + let aux_gpu: Option>> = None; + + let aux_evals: Vec> = if let Some(v) = aux_gpu { + v + } else { + // Precompute col_scale[i] = point[i] * inv_denom[i] — shared across all aux columns. + let col_scale: Vec> = dc + .points + .iter() + .zip(inv_denoms.iter()) + .map(|(point, inv_d)| point * inv_d) + .collect(); + + // Evaluate all aux columns directly from LDE (no extraction copy). + // For aux columns (extension field E): sum = Σ col_scale[i] * lde_col[i*bf] + // Both col_scale and lde_col are in E, so each multiply is E×E → E. + #[cfg(feature = "parallel")] + let aux_iter = (0..num_aux_cols).into_par_iter(); + #[cfg(not(feature = "parallel"))] + let aux_iter = 0..num_aux_cols; + let aux_evals: Vec> = aux_iter + .map(|col_idx| { + let lde_col = &lde_trace.aux_columns[col_idx]; + let sum = + col_scale + .iter() + .enumerate() + .fold(FieldElement::::zero(), |acc, (i, scale)| { + acc + scale * &lde_col[i * bf] + }); + &vanishing_factor * &sum + }) + .collect(); + aux_evals + }; table_data.extend(aux_evals); } diff --git a/prover/Cargo.toml b/prover/Cargo.toml index dac711002..8bbad714d 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [features] default = ["parallel"] parallel = ["stark/parallel", "math/parallel", "crypto/parallel", "dep:rayon"] +cuda = ["stark/cuda"] debug-checks = ["stark/debug-checks"] instruments = ["stark/instruments"] @@ -20,6 +21,7 @@ rayon = { version = "1.8.0", optional = true } [dev-dependencies] env_logger = "*" criterion = { version = "0.5", default-features = false } +stark = { path = "../crypto/stark" } [[bench]] name = "vm_prover_benchmark" diff --git a/prover/tests/bench_gpu.rs b/prover/tests/bench_gpu.rs new file mode 100644 index 000000000..fa225c54b --- /dev/null +++ b/prover/tests/bench_gpu.rs @@ -0,0 +1,74 @@ +//! End-to-end timing probe: prove `fib_iterative_1M` (≈1M instructions) once +//! and print wall-clock time. Intended to be run twice — once with the `cuda` +//! feature, once without — so the caller can compare. Ignored by default. +//! +//! Usage: +//! cargo test -p lambda-vm-prover --release --test bench_gpu -- --ignored --nocapture +//! cargo test -p lambda-vm-prover --release --features cuda --test bench_gpu -- --ignored --nocapture + +use std::time::Instant; + +use lambda_vm_prover::test_utils::asm_elf_bytes; + +fn bench_prove(name: &str, trials: u32) { + let elf = asm_elf_bytes(name); + // Warm up — first prove pays lazy one-time costs (PTX load on the GPU side, + // buffer pool warm-up on the CPU side). + let _ = lambda_vm_prover::prove(&elf).expect("warm-up prove"); + + #[cfg(feature = "cuda")] + stark::gpu_lde::reset_gpu_lde_calls(); + + let t0 = Instant::now(); + for _ in 0..trials { + let _ = lambda_vm_prover::prove(&elf).expect("prove"); + } + let elapsed = t0.elapsed().as_secs_f64() / trials as f64; + + let gpu = if cfg!(feature = "cuda") { "gpu" } else { "cpu" }; + println!("prove({name}) [{gpu}]: {elapsed:.3}s avg over {trials} trials"); + + #[cfg(feature = "cuda")] + { + let calls = stark::gpu_lde::gpu_lde_calls(); + let eh = stark::gpu_lde::gpu_extend_halves_calls(); + let r4 = stark::gpu_lde::gpu_r4_lde_calls(); + let parts = stark::gpu_lde::gpu_parts_lde_calls(); + let leaf = stark::gpu_lde::gpu_leaf_hash_calls(); + let bary = stark::gpu_lde::gpu_bary_calls(); + let mtree = stark::gpu_lde::gpu_merkle_tree_calls(); + let deep = stark::gpu_lde::gpu_deep_calls(); + println!(" GPU LDE calls across {trials} proves: {calls}"); + println!(" GPU deep-composition calls: {deep}"); + println!(" GPU extend_two_halves calls: {eh}"); + println!(" GPU R4 deep-poly LDE calls: {r4}"); + println!(" GPU R2 parts LDE calls: {parts}"); + println!(" GPU leaf-hash calls: {leaf}"); + println!(" GPU barycentric OOD calls: {bary}"); + println!(" GPU Merkle inner-tree calls: {mtree}"); + } +} + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn bench_prove_fib_1m() { + bench_prove("fib_iterative_1M", 5); +} + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn bench_prove_fib_1m_long() { + bench_prove("fib_iterative_1M", 15); +} + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn bench_prove_fib_2m() { + bench_prove("fib_iterative_2M", 5); +} + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn bench_prove_fib_4m() { + bench_prove("fib_iterative_4M", 3); +} diff --git a/prover/tests/bench_single.rs b/prover/tests/bench_single.rs new file mode 100644 index 000000000..947f0fddf --- /dev/null +++ b/prover/tests/bench_single.rs @@ -0,0 +1,12 @@ +//! Single-prove bench for profiling with nsys / ncu. +use lambda_vm_prover::test_utils::asm_elf_bytes; + +#[test] +#[ignore = "bench; run with --ignored --nocapture"] +fn prove_fib_1m_once() { + let elf = asm_elf_bytes("fib_iterative_1M"); + // Warm-up pays one-time costs (PTX load, pool warm-up). + let _ = lambda_vm_prover::prove(&elf).expect("warm-up"); + // The profiled run: + let _ = lambda_vm_prover::prove(&elf).expect("prove"); +}