Cuda Exp 11 #575

Draft
MauroToscano wants to merge 38 commits into main from cuda/exp-11-merge

Conversation

@MauroToscano
Contributor

@MauroToscano MauroToscano commented May 2, 2026

This is the latest CUDA experimental version. It currently gives the following speedup results:

[image: speedup results]

It is missing constraint evaluation and trace expansion, and the benchmark was not using GPU LogUp.

Benchmark machine:

24 cores
128 GiB of RAM

GPU

  • NVIDIA GeForce RTX 5090, 32 GiB VRAM

Detailed breakdown:

[image: detailed breakdown]

Mauro Toscano and others added 30 commits May 2, 2026 15:37
New `math-cuda` crate carries a CUDA backend for the per-table coset LDE:
 - Goldilocks field arithmetic on device (bit-identical to CPU).
 - Radix-2 DIT NTT with shared-memory fusion of the first 8 levels.
 - Batched variant: one kernel launch handles all M columns of a table.
 - Single shared pinned host staging buffer, grows to max LDE seen.
 - Outputs written directly into caller-provided slices.

Wired in at stark::prover::expand_columns_to_lde behind a `cuda` feature
flag; Goldilocks-base tables above the LDE-size threshold route to the
GPU batched path, others fall through to the existing rayon CPU path.

Bench (RTX 5090, 46-core CPU, blowup=4, warm):
 - 64 cols, n=2^16 (LDE 2^18): CPU 95ms, GPU 18ms (~5x)
 - 20 cols, n=2^20 (LDE 2^22): CPU 512ms, GPU 266ms (~2x)
 - 1M fib end-to-end: CPU ~16.5s, CUDA ~15s (warm)

Feature is opt-in (`stark/cuda`, `lambda-vm-prover/cuda`). CPU-only builds
are byte-identical to before and pay zero overhead.
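
A minimal CPU-side sketch of the two building blocks named above, Goldilocks arithmetic and a radix-2 DIT NTT. This is an illustration, not the `math-cuda` code: the function names are hypothetical, the reduction uses a plain `u128` remainder rather than the fast reduction a device kernel would want, and the root of unity is derived from the generator 7.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

// Canonical arithmetic: inputs and outputs are always in [0, P).
fn add(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % P as u128) as u64 }
fn sub(a: u64, b: u64) -> u64 { ((a as u128 + P as u128 - b as u128) % P as u128) as u64 }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }

fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

// Swap a[i] with a[bitrev(i)] so the DIT butterflies can run in place.
fn bit_reverse_permute(a: &mut [u64]) {
    let bits = a.len().trailing_zeros();
    if bits == 0 { return; }
    for i in 0..a.len() {
        let j = i.reverse_bits() >> (usize::BITS - bits);
        if i < j { a.swap(i, j); }
    }
}

// In-place radix-2 decimation-in-time NTT, natural order in and out.
fn ntt_dit(a: &mut [u64]) {
    let n = a.len();
    assert!(n.is_power_of_two());
    bit_reverse_permute(a);
    let mut len = 2;
    while len <= n {
        // Primitive len-th root of unity (7 is a quadratic non-residue mod P).
        let w_len = pow(7, (P - 1) / len as u64);
        for start in (0..n).step_by(len) {
            let mut w = 1u64;
            for k in 0..len / 2 {
                let u = a[start + k];
                let v = mul(a[start + k + len / 2], w);
                a[start + k] = add(u, v);
                a[start + k + len / 2] = sub(u, v);
                w = mul(w, w_len);
            }
        }
        len *= 2;
    }
}
```

A parity check in the spirit of the commit would compare `ntt_dit` against a naive O(n²) DFT on small inputs.
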
The batched-LDE host pack was a single-threaded memcpy from caller Vecs
into the pinned staging buffer. At prover scale (20 cols x 1M rows, 640
MB) that ran in 27 ms on one core while the GPU sat idle. Parallelising
with rayon par_iter saturates DRAM across cores and drops pack to ~8 ms.
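
The idea can be sketched with the standard library alone. This is an assumption-laden illustration: it uses `std::thread::scope` with one thread per column as a stand-in for the rayon `par_iter` the commit describes, a plain slice as a stand-in for the pinned staging buffer, and a hypothetical function name.

```rust
use std::thread;

/// Copy every source column into its own slot of one contiguous staging
/// buffer, one worker per column, so several cores drive DRAM at once.
fn pack_columns_parallel(columns: &[Vec<u64>], staging: &mut [u64]) {
    let rows = columns.first().map_or(0, |c| c.len());
    if rows == 0 {
        return;
    }
    assert!(columns.iter().all(|c| c.len() == rows));
    assert_eq!(staging.len(), rows * columns.len());

    thread::scope(|s| {
        // chunks_mut hands out disjoint slots, so no locking is needed.
        for (col, slot) in columns.iter().zip(staging.chunks_mut(rows)) {
            s.spawn(move || slot.copy_from_slice(col));
        }
    });
}
```

With a work-stealing pool the per-column tasks would be split further, but the disjoint-slot structure stays the same.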

Microbench impact (RTX 5090, prover-scale 20 cols, log_n=20, blowup=4):
 - Before: host pack 27 ms
 - After:  host pack  8 ms

Also switched bench_quick to median-of-10 trials for stable measurements
(prior single-trial numbers were 10-50% noisy).
An NTT over Goldilocks cubic-extension columns is algebraically
equivalent to three independent base-field NTTs over the component
slabs, because the DIT butterfly multiplies by a base twiddle and
`base * ext3` acts componentwise. Exploit this to route the aux-trace
LDE (previously the biggest remaining FFT chunk on the CPU path) to
the existing base-field batched kernels with no new CUDA:

 - `math_cuda::lde::coset_lde_batch_ext3_into` de-interleaves each
   ext3 column into three base slabs in the pinned staging buffer,
   runs the batched NTT over 3M logical slabs, then re-interleaves
   three slabs back per output column.
 - Stark's `gpu_lde::try_expand_columns_batched` now dispatches to
   this path when `E == Degree3GoldilocksExtensionField`. Base-field
   tables still go through the 1-col kernel as before.
 - Parity-tested in `tests/lde_batch_ext3.rs` vs CPU coset_lde_full_expand.

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores): 17.02s
 - CUDA before this change:  16.97s  (~tied)
 - CUDA after this change:   16.15s  (5.1% faster than CPU)

Instruments breakdown (aggregate over rayon threads):
 - Main LDE:   3.3s CPU -> 2.1s GPU
 - Aux LDE:    2.9s CPU -> 2.4s GPU (newly GPU-accelerated)

Also added `NOTES.md` with a running log of what's been tried and the
remaining path to a larger (10x-class) speedup.
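
The componentwise argument can be checked on the CPU. The sketch below is illustrative only: it uses a naive O(n²) DFT instead of the batched DIT kernels, models an ext3 element as `[u64; 3]`, and all function names are hypothetical. Because every twiddle is a base-field element, transforming the three de-interleaved slabs separately gives the same result as transforming the ext3 column directly.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

fn add(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % P as u128) as u64 }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
/// Primitive n-th root of unity for power-of-two n.
fn root_of_unity(n: usize) -> u64 { pow(7, (P - 1) / n as u64) }

type Ext3 = [u64; 3]; // coordinates over the base field

/// Naive DFT over the base field, natural order in and out.
fn dft_base(x: &[u64]) -> Vec<u64> {
    let w = root_of_unity(x.len());
    (0..x.len())
        .map(|k| {
            let wk = pow(w, k as u64);
            let (mut tw, mut acc) = (1u64, 0u64);
            for &v in x {
                acc = add(acc, mul(v, tw));
                tw = mul(tw, wk);
            }
            acc
        })
        .collect()
}

/// Same DFT on ext3 elements: `base * ext3` scales each coordinate.
fn dft_ext3_direct(x: &[Ext3]) -> Vec<Ext3> {
    let w = root_of_unity(x.len());
    (0..x.len())
        .map(|k| {
            let wk = pow(w, k as u64);
            let (mut tw, mut acc) = (1u64, [0u64; 3]);
            for v in x {
                for c in 0..3 {
                    acc[c] = add(acc[c], mul(v[c], tw));
                }
                tw = mul(tw, wk);
            }
            acc
        })
        .collect()
}

/// De-interleave into three base slabs, transform each, re-interleave.
fn dft_ext3_via_slabs(x: &[Ext3]) -> Vec<Ext3> {
    let slabs: Vec<Vec<u64>> =
        (0..3usize).map(|c| x.iter().map(|v| v[c]).collect()).collect();
    let out: Vec<Vec<u64>> = slabs.iter().map(|s| dft_base(s)).collect();
    (0..x.len()).map(|i| [out[0][i], out[1][i], out[2][i]]).collect()
}
```

Note that the ext3 irreducible polynomial never appears: the equivalence only needs base-by-extension multiplication, which is why no new CUDA arithmetic is required.
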
…ales up)

Adds `try_extend_two_halves_gpu` which batches the two rayon::join()ed
`extend_half_to_lde` calls inside `decompose_and_extend_d2` into a single
GPU ext3 LDE call. Weights are `g^(-k) / N` so the `(g²)^(-k)` input-
coset undo from `interpolate_offset_fft` and the `g^k` output-coset shift
from `evaluate_polynomial_on_lde_domain` combine to a single multiply.

In the current VM config the big tables hit the `number_of_parts > 2`
branch in `round_2_compute_composition_polynomial` (interpolate_offset_fft
+ break_in_parts + evaluate_polynomial_on_lde_domain) and only tiny
tables (h0.len == 16) reach `decompose_and_extend_d2`; those land below
the GPU LDE threshold, so this path currently fires 0 times per proof.
The infrastructure is correct and parity-tested, and will pick up work
automatically when AIRs land with `degree_bound(N)/N == 2` at prover
scale.

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores): 17.010s
 - CUDA:                  15.665s  (7.9% faster, stable across runs)
Round 4 extends the DEEP composition polynomial from N trace-coset
evaluations to `domain_size` LDE-coset evaluations via
`interpolate_fft + evaluate_fft(poly, 1, Some(domain_size))` — that's
the no-coset ext3 LDE pattern with uniform `1/N` weights, which our
existing `coset_lde_batch_ext3_into` already implements.

Added `try_r4_deep_poly_lde_gpu` that routes the ext3 LDE through the
batched GPU path; prover falls back to CPU when the feature is off or
size is below threshold. Caller keeps its trailing `bit_reverse_permute`
so output order is unchanged.

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores): 16.907s
 - CUDA after this change:  14.971s  (11.5% faster end-to-end)

R4 deep-poly LDE fires ~8-9 times per proof at prover scale (one per
big table).
The `number_of_parts > 2` branch of round_2_compute_composition_polynomial
does `interpolate_offset_fft(2N evals) -> break_in_parts -> K calls to
evaluate_polynomial_on_lde_domain`. At 1M-fib scale each of those K
evaluations is a 2^20 -> 2^22 ext3 FFT on the g-coset — the single
biggest FFT chunk in the proof after the main-trace LDE.

Added `math_cuda::lde::evaluate_poly_coset_batch_ext3_into` which skips
the iFFT stage (input is coefficients, not evaluations) and applies just
the `offset^k` coset scaling + padded forward NTT. Parity-tested
against `Polynomial::evaluate_offset_fft`.
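
The coefficient-input path described above can be sketched on the CPU as follows. This is a hedged illustration over the base field (the commit works on ext3 columns), with a naive DFT in place of the NTT kernel and hypothetical function names: scale coefficient `k` by `offset^k`, zero-pad to the LDE size, then run one forward transform.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

fn add(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % P as u128) as u64 }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn root_of_unity(n: usize) -> u64 { pow(7, (P - 1) / n as u64) }

/// Reference evaluator used only for checking.
fn horner(coeffs: &[u64], x: u64) -> u64 {
    coeffs.iter().rev().fold(0u64, |acc, &c| add(mul(acc, x), c))
}

/// Naive forward DFT, natural order in and out.
fn dft(x: &[u64]) -> Vec<u64> {
    let w = root_of_unity(x.len());
    (0..x.len())
        .map(|k| horner(x, pow(w, k as u64)))
        .collect()
}

/// Evaluate `coeffs` on the coset {offset * w^i} of size `domain_size`.
/// No inverse transform is needed because the input is already coefficients.
fn evaluate_on_coset(coeffs: &[u64], offset: u64, domain_size: usize) -> Vec<u64> {
    assert!(domain_size.is_power_of_two() && coeffs.len() <= domain_size);
    let mut scaled = vec![0u64; domain_size]; // zero padding
    let mut offset_pow = 1u64;
    for (dst, &c) in scaled.iter_mut().zip(coeffs) {
        *dst = mul(c, offset_pow); // c_k * offset^k
        offset_pow = mul(offset_pow, offset);
    }
    dft(&scaled)
}
```

The scaling works because p(offset · x) has coefficients c_k · offset^k, so a plain transform of the scaled coefficients yields the coset evaluations.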

Stark prover now batches all K parts into a single GPU call via
`try_evaluate_parts_on_lde_gpu`; CPU interpolate_offset_fft still runs
once per table (smaller, and reusing it unchanged avoids scaffolding).

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores):             17.641s
 - CUDA after this change:            13.460s  (23.7% faster end-to-end)

GPU R2 parts LDE fires 42 times (~8 big tables per proof * 5 trials).
End-to-end on RTX 5090 vs 46-core rayon CPU:
 - fib_iterative_1M:  17.64s CPU -> 13.46s CUDA  (1.31x, 24% faster)
 - fib_iterative_4M:  41.40s CPU -> 35.14s CUDA  (1.18x, 15% faster)

All 28 math-cuda parity tests + 121 stark cuda tests pass.
Add a Keccak-f1600 kernel and two batched leaf-hash kernels
(`keccak256_leaves_base_batched`, `keccak256_leaves_ext3_batched`) that
read canonical u64 values directly from the device LDE buffer, byte-swap
into Keccak lanes, absorb, and squeeze 32-byte digests. Matches the CPU
reference path (`canonical_u64().to_be_bytes()` → Keccak-256 with 0x01
padding) bit-for-bit — parity-tested in `tests/keccak_leaves.rs` across
base + ext3 and a sweep of `log_n` / column counts.

Introduce `coset_lde_batch_base_into_with_leaf_hash` which runs the
full NTT pipeline + Merkle leaf hash in one on-device sequence, then
D2Hs LDE columns into the existing pinned staging AND hashed leaves
into a new dedicated pinned staging — same stream so the two transfers
queue back to back at pinned PCIe rate.

Stark prover's `commit_main_trace` calls a new
`try_expand_and_leaf_hash_batched` helper that routes the whole
expand+commit chain through the combined GPU path; Merkle tree is built
on CPU from the GPU-computed hashed leaves via
`BatchedMerkleTree::build_from_hashed_leaves`. Falls through to CPU
when `cuda` is off or size is below threshold.

Block size dropped to 128 threads for the Keccak kernels — the 25-lane
state + auxiliary arrays push per-thread register usage past the sm_120
block register budget at 256 threads.

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores):             17.658s
 - CUDA before this change:           13.460s  (23.7% faster)
 - CUDA after this change:            12.959s  (26.6% faster)

Aggregate instrument numbers for the main-trace commit phase:
 - Main Merkle before (CPU Keccak):   ~5.79 s aggregate
 - Main Merkle after  (GPU Keccak):   ~1.26 s aggregate  (tree build only)
Extend the combined LDE+leaf-hash pipeline to the aux-trace commit.
`coset_lde_batch_ext3_into_with_leaf_hash` runs the ext3 LDE over the
three de-interleaved base slabs and invokes the
`keccak256_leaves_ext3_batched` kernel directly on the same device
buffer, re-interleaves the LDE output for Round 2-4 reuse, and returns
hashed leaves via the pinned hash staging.

Stark prover wires it into `multi_prove`'s aux-commit chunk so each
RAP-table's aux-trace LDE + Merkle commit run as one GPU pipeline.

End-to-end on fib_iterative_1M (median of 5 trials):
 - CPU (rayon, 46 cores):             18.269s
 - CUDA (main-only Keccak, prev):     12.959s  (26.6% faster)
 - CUDA (main + aux Keccak, now):     13.127s  (28.1% faster)

Counter shows ~15 leaf-hash calls per proof (main + aux across 8 big
tables).
Current speedup on fib_iterative_1M vs 46-core rayon CPU: 1.39x
(28.1% faster, 18.27s -> 13.13s).

Documents what's still on CPU (R3 OOD, R2 evaluate, deep composition,
R2/R4 Merkle commits) and what it would take to reach ~2x.
Adds GPU infrastructure for barycentric point-evaluation of ext3 columns
at a single evaluation point — the primitive behind R3 OOD and R4 DEEP
composition. Parity-tested against the CPU reference but kept unwired
in the prover: benchmarking showed R3 OOD is already rayon-parallelised
and takes negligible wall time on a 46-core host while the GPU is busy with
LDE/Merkle on other streams, so routing R3 to the GPU regresses the
end-to-end proof (fib_1M 13.09s → 14.20s, fib_4M 33.67s → 36.03s).
The kernels remain as a building block for single-table or very-large-
trace workloads where the GPU has idle windows during R3.

New:
- kernels/ext3.cuh: full ext3 arithmetic (add/sub/neg/mul_base/mul).
  Uses a dot3 helper that fuses 3 u128 products into a single reduce128
  to cut ext3 multiplication cost.
- kernels/barycentric.cu: batched kernels over M columns, one CUDA block
  per column, shared-memory tree reduction, 256 threads per block. Two
  variants: base-field and ext3 columns (de-interleaved 3-slab layout).
  Returns the unscaled sum; the caller applies the ext3 scalar on host.
- src/barycentric.rs: Rust launchers for both kernels.
- tests/ext3.rs, tests/barycentric.rs: parity against CPU Degree3 ops.

Plumbing:
- build.rs compiles the new PTX.
- device.rs registers the four new kernel handles.
- gpu_lde.rs adds wrappers + counter (dead-coded until re-wired).
- bin/cli exposes a `cuda` feature so the CLI picks up the GPU build.
- bench_gpu prints the bary-call counter alongside the other GPU counters.
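
For orientation, here is a CPU sketch of barycentric evaluation with the same split the commit describes: a reduction that returns an unscaled sum, and a scalar applied afterwards by the caller. The assumptions are mine, not the commit's: a base-field column, the plain (non-coset) domain of n-th roots of unity, the formula f(z) = (zⁿ − 1)/n · Σ fᵢ·ωⁱ/(z − ωⁱ), and hypothetical function names.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

fn add(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % P as u128) as u64 }
fn sub(a: u64, b: u64) -> u64 { ((a as u128 + P as u128 - b as u128) % P as u128) as u64 }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat inverse, a != 0
fn root_of_unity(n: usize) -> u64 { pow(7, (P - 1) / n as u64) }

/// Reference evaluator used only for checking.
fn horner(coeffs: &[u64], x: u64) -> u64 {
    coeffs.iter().rev().fold(0u64, |acc, &c| add(mul(acc, x), c))
}

/// The part a reduction kernel would compute: sum_i f_i * x_i / (z - x_i).
/// `z` must not be one of the domain points.
fn barycentric_unscaled_sum(evals: &[u64], points: &[u64], z: u64) -> u64 {
    evals
        .iter()
        .zip(points)
        .fold(0u64, |acc, (&f, &x)| add(acc, mul(mul(f, x), inv(sub(z, x)))))
}

/// Host-side wrapper: apply the scalar (z^n - 1) / n to the unscaled sum.
fn barycentric_eval(evals: &[u64], z: u64) -> u64 {
    let n = evals.len();
    let w = root_of_unity(n);
    let points: Vec<u64> = (0..n).map(|i| pow(w, i as u64)).collect();
    let sum = barycentric_unscaled_sum(evals, &points, z);
    let scalar = mul(sub(pow(z, n as u64), 1), inv(n as u64));
    mul(scalar, sum)
}
```

A production path would batch-invert the denominators instead of calling `inv` per point; the per-point inverse is kept here only to keep the sketch short.
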
Adds `keccak_merkle_level` CUDA kernel that does one level of pair-hash
in the standard Merkle node layout (matches the CPU
`build_from_hashed_leaves` node order). A Rust wrapper
`math_cuda::merkle::build_merkle_tree_on_device` drives it
layer-by-layer to build the full tree.
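
A layer-by-layer build over a flat node buffer can be sketched like this. Everything specific here is an assumption for illustration: the layout (root at index 0, leaves in the last N slots of a `2N - 1` buffer), the function names, and above all the hash, which is the standard library's `DefaultHasher` standing in for Keccak-256.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

/// Stand-in pair hash (NOT Keccak): digest of left || right.
fn hash_pair(left: u64, right: u64) -> u64 {
    let mut h = DefaultHasher::new();
    h.write_u64(left);
    h.write_u64(right);
    h.finish()
}

/// Build all `2N - 1` nodes from N already-hashed leaves, one level per
/// loop iteration (on a GPU, each iteration would be one kernel launch).
fn build_tree_from_hashed_leaves(leaves: &[u64]) -> Vec<u64> {
    let n = leaves.len();
    assert!(n.is_power_of_two());
    let mut nodes = vec![0u64; 2 * n - 1];
    nodes[n - 1..].copy_from_slice(leaves); // leaves occupy the tail

    let mut width = n; // number of nodes in the child level
    while width > 1 {
        let child_start = width - 1;
        let parent_start = width / 2 - 1;
        for j in 0..width / 2 {
            nodes[parent_start + j] = hash_pair(
                nodes[child_start + 2 * j],
                nodes[child_start + 2 * j + 1],
            );
        }
        width /= 2;
    }
    nodes
}
```

Each level depends only on the level below it, so the inner loop is the part that parallelises across threads while the outer loop stays sequential.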

Exposes `MerkleTree::from_precomputed_nodes` in crypto/crypto so callers
can hand the GPU-built node buffer straight to the prover.

Also adds a stark-crate helper `try_build_merkle_tree_gpu` that
bridges a host `Vec<[u8; 32]>` leaves into the GPU kernel and back.

Not wired into the prover: a bench-against-baseline showed the 50-80 ms
of CPU tree-build time per table is already small enough that the
H2D-of-leaves + D2H-of-tree round-trip erases the gain (leaves come back
from `try_expand_and_leaf_hash_batched` in a pageable Vec, so the re-H2D
is slow). A real win needs an end-to-end fused LDE → leaf-hash → tree
pipeline where the leaf buffer never leaves the device — left as future
work. Kernel + parity tests land as infrastructure for that fusion.

Tests: `cargo test -p math-cuda --test merkle_tree` covers
log₂(N) ∈ {1..6, 10, 12, 14, 18} against a pure-CPU Keccak reference.
…eline

Adds `coset_lde_batch_{base,ext3}_into_with_merkle_tree` variants of the
with_leaf_hash entry points. Same LDE/Keccak path, but the leaf hashes
stay on device and feed straight into `keccak_merkle_level` so the full
`2*lde_size - 1` node buffer is built on the same stream and only the
final tree (not the intermediate leaves) crosses PCIe.

Wired into `commit_main_trace` and the aux trace commit via new
`try_expand_leaf_and_tree_batched{,_ext3}` helpers. The prover now gets
a finished `MerkleTree` back from one GPU call instead of
  H2D cols → LDE → leaf-hash → D2H(leaves, pageable) → CPU rayon tree build.

Benchmark on fib_iterative_{1M, 4M}, 3 runs × 5 trials:

  fib_1M: 13.552 s → 12.906 s  (−4.8%, was 28.1% faster than CPU,
                                 now 29.4%)
  fib_4M: 33.669 s → 32.931 s  (−2.2%)

Correctness: 121 stark cuda tests + all math-cuda parity tests pass.

Savings come from (a) skipping the 128 MB pinned→pageable memcpy that
the leaves round-trip needed, and (b) skipping the pageable H2D that a
separate GPU tree build would pay on re-upload. The remaining tree
kernel runtime is <10 ms per call (microseconds per level × log₂(N)
levels) — well inside what PCIe was previously spending on the
unnecessary leaf D2H.
Adds a row-pair Keccak leaf kernel `keccak_comp_poly_leaves_ext3`:
each thread hashes two bit-reversed rows × `num_parts` ext3 values in
the same byte order as `commit_composition_polynomial`. Reuses the
existing `keccak_merkle_level` for the inner tree.

Two device-side entry points:
 - `evaluate_poly_coset_batch_ext3_into_with_merkle_tree`: fused
   coefficient → LDE → tree (future wire site for number_of_parts > 2;
   currently unwired while we benchmark the H2D overhead of the
   separate path below).
 - `build_comp_poly_tree_from_evals_ext3`: takes already-computed LDE
   parts (from any of the three R2 branches — `== 1`, `== 2`,
   `> 2`) and runs just leaves + tree. Used by the prover.

Parity test (`tests/comp_poly_tree.rs`) checks the whole LDE + tree
pipeline against a CPU pipeline on log_n ∈ {2..5, 10, 12, 14} and
blowup ∈ {2, 4, 8} and parts ∈ {1, 2, 4}. All green.

Bench on `cargo test bench_gpu`, 3 runs each:
  fib_1M : 12.906 s → 12.951 s  (within noise; small trees hit the
                                 threshold guard and fall back to CPU)
  fib_4M : 32.931 s → 32.094 s  (−2.5 %; bigger trees benefit)

The neutral fib_1M number says the H2D of composition-poly LDE parts
is eating what should be a win — the real fix is to keep the LDE on
device after `try_evaluate_parts_on_lde_gpu` produces it (a future
change; the fused device path is already written against that day).
`keccak_fri_leaves_ext3` hashes 2 consecutive ext3 evals (48 BE bytes) per
leaf — matches `FriLayerMerkleTreeBackend::hash_data`. Reuses
`keccak_merkle_level` for the inner tree. Parity-tested on log_num_leaves
in {1..6, 10, 12, 14, 18}.

Wired into `commit_phase_from_evaluations` then reverted after A/B: the
per-layer H2D of the folded-evals slab (each layer is a fresh pageable
Vec from `fold_evaluations_in_place`) eats the tree-build savings, so
net is noise-to-slightly-negative on fib_1M and fib_4M. The real win
needs the FRI state to stay on device across layers (fold + leaves +
tree all GPU-side) — deferred to the "LDE GPU-resident across rounds"
item. The kernel stays here as a building block for that fusion.
Item 4 (architectural unlock) + item 3 together. The R1 fused pipeline
now optionally keeps the LDE device buffer alive and exposes it on
`LDETraceTable` via new `GpuLdeBase` / `GpuLdeExt3` handles. Downstream
GPU rounds can read the main/aux LDE directly from device without
paying a re-H2D of ~500 MB per call.

Concretely this ships item 3 (R4 deep_composition_poly_evaluations on
GPU). New kernel `deep_composition_ext3_row` in `kernels/deep.cu`: one
thread per trace-size row, sums ~(num_parts + num_total_cols ×
num_eval_points) ext3 FMA contributions, reading main LDE (base) and
aux LDE (ext3 de-interleaved) from the device handles plus
composition-parts LDE + scalar arrays H2D'd fresh each call.

Parity test `tests/deep.rs` covers small/medium/no-aux shapes against
a direct CPU port of the prover's row-wise loop.

Plumbing:
  - `coset_lde_batch_{base,ext3}_into_with_merkle_tree_keep` —
    variants that return `Arc<CudaSlice<u64>>` instead of dropping it.
  - `try_expand_leaf_and_tree_batched{,_ext3}_keep` — thin stark-side
    wrappers that propagate the handle.
  - `MainTraceCommitResult` struct replaces the 6-tuple return of
    `commit_main_trace` / `commit_preprocessed_trace`; adds a 7th
    `gpu_main: Option<GpuLdeBase>` field.
  - `Lde` struct gets matching `gpu_main` / `gpu_aux` fields.
  - `LDETraceTable` gains cfg-gated `gpu_main` / `gpu_aux` fields and
    set/get accessors.

Benchmark (cargo test bench_gpu, median of 3 runs × 5 trials):
  fib_1M: 13.02 s → 12.27 s  (−5.8 %, 1.49× vs CPU 18.27 s)
  fib_4M: 32.09 s → 29.75 s  (−7.3 %)

Correctness: 121 stark cuda tests + 43 math-cuda parity tests pass.
Adds strided variants of the barycentric kernels —
  barycentric_base_batched_strided,
  barycentric_ext3_batched_strided
— that take an extra `row_stride` and read every `row_stride`-th row
from each column. Lets R3 OOD operate directly on the LDE device
buffer from R1 (stride = blowup_factor for the trace-size coset) with
no H2D of column data at all.

Wired into `get_trace_evaluations_from_lde`: when `LDETraceTable.gpu_main`
/ `gpu_aux` are populated by the R1 fused pipeline, main + aux OOD
runs GPU-side per eval point; otherwise falls back to the rayon CPU
path. Host side still does the ~200 ms CPU prelude (inv_denoms batch
inverse + coset-points setup).

Parity test `tests/barycentric_strided.rs` checks the strided kernels
against the non-strided ones fed pre-strided buffers (log_trace ∈
{4, 8, 10, 12}, blowup ∈ {2, 4}, base and ext3).

Benchmark (median of 3×5 trials):
  fib_1M: 12.66 s → 12.38 s  (−2.2 %, 1.48× vs CPU 18.27 s)
  fib_4M: 29.75 s → 28.83 s  (−3.1 %)

Correctness: 121 stark cuda tests + all math-cuda parity tests pass.
get_trace_evaluations_from_lde used to unconditionally extract
trace-size Vec<FieldElement> slabs from LDETraceTable before looping
over eval points. With R3 OOD now running against device handles via
the strided barycentric kernels, those slabs are pure waste when the
GPU path fires — ~num_main_cols × n × 8 B per table of pageable Vec
alloc + populate.

Gate each extraction on `gpu_{main,aux}_available`: skip when the
R1 fused pipeline set the corresponding device handle on LDETraceTable.

Benchmark (fib_1M, median of 3×5 trials): 12.24 s → 11.93 s (−2.5 %).
New speedup 1.53× vs CPU 18.27 s (was 1.49×).

Correctness: 121 stark cuda tests + all math-cuda parity tests pass.
Adds `fri_fold_ext3` (one thread per output: `out[j] = (lo+hi) +
inv_tw[j]*zeta*(lo-hi)`) and `fri_update_twiddles` (`new[j] =
old[2j]²`). Not wired into `commit_phase_from_evaluations` yet — the
current CPU fold is ~0.1-0.2 s wall so the win is smaller than the
LDE-resident + barycentric optimisations that just landed. These
kernels are infrastructure for a future fully-on-device FRI commit
(fold + leaves + tree + root D2H per layer, keeping evals GPU-resident
across log(N) iterations, zisk pattern).
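
The two formulas can be exercised on the CPU. The sketch below makes assumptions the commit does not state: it works over the base field rather than ext3, stores evaluations in bit-reversed order so that `lo = evals[2j]` and `hi = evals[2j + 1]` sit at `x` and `-x`, and takes `inv_tw[j] = 1/x` for that pair. Function names are hypothetical. Under those assumptions the fold returns `2·(f_even + zeta·f_odd)` at `x²`, and squaring every other inverse twiddle gives the next layer's table.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

fn add(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % P as u128) as u64 }
fn sub(a: u64, b: u64) -> u64 { ((a as u128 + P as u128 - b as u128) % P as u128) as u64 }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) }
fn root_of_unity(n: usize) -> u64 { pow(7, (P - 1) / n as u64) }

/// Reverse the low `bits` bits of `i`.
fn bit_reverse(i: usize, bits: u32) -> usize {
    if bits == 0 { 0 } else { i.reverse_bits() >> (usize::BITS - bits) }
}

/// Reference evaluator used only for checking.
fn horner(coeffs: &[u64], x: u64) -> u64 {
    coeffs.iter().rev().fold(0u64, |acc, &c| add(mul(acc, x), c))
}

/// out[j] = (lo + hi) + inv_tw[j] * zeta * (lo - hi), one output per pair.
fn fri_fold(evals: &[u64], inv_tw: &[u64], zeta: u64) -> Vec<u64> {
    evals
        .chunks_exact(2)
        .zip(inv_tw)
        .map(|(pair, &tw)| {
            let (lo, hi) = (pair[0], pair[1]);
            add(add(lo, hi), mul(mul(tw, zeta), sub(lo, hi)))
        })
        .collect()
}

/// new[j] = old[2j]^2: inverse twiddles for the half-size layer.
fn fri_update_twiddles(old: &[u64]) -> Vec<u64> {
    (0..old.len() / 2).map(|j| mul(old[2 * j], old[2 * j])).collect()
}
```

The factor of 2 is left in the folded values, matching the formula as written; whether the prover normalises it elsewhere is not stated in the commit.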

Also updates NOTES with the new 1.51× baseline.
Replaces per-element u64 copy loops (~1M u64 writes serially) with
slice-cast + copy_nonoverlapping. inv_t outer loop now runs in
parallel via rayon.

Bench: fib_1M median 12.13s → 11.88s (−2.0 %, 1.54× vs CPU).
`FriCommitState` in math-cuda owns two ping-pong ext3 eval buffers and
the base-field inv_twiddles buffer; each `fold_and_commit_layer(zeta)`
call launches fri_fold_ext3 → keccak_fri_leaves_ext3 →
keccak_merkle_level × log(n), plus fri_update_twiddles for the next
layer — all on the same stream, no cross-layer host round-trips.

Wired into `commit_phase_from_evaluations` via `try_fri_commit_gpu`:
the host loop still samples each layer's zeta from the transcript and
appends the root, but the folded evals, twiddles, and per-layer trees
never leave the device between iterations. Per-layer D2H is only the
32 B root + the layer's evals + its tree nodes (needed by
query_phase). Falls back to CPU when `cuda` off, type mismatch, or
domain below threshold.

The CPU `compute_coset_twiddles_inv` is still done on host (bit-reverse
permute + batch_inverse on n/2 base-field entries) — cheap vs. the
pattern of kernel launches we just avoided. Moving that to GPU too is
a follow-up.

Benchmark (median of 3×5):
  fib_1M : 12.04 s → 11.77 s  (−2.3 %, 1.55× vs CPU 18.27 s)
  fib_4M : 29.05 s → 28.34 s  (−2.4 %)

Correctness: 121 stark cuda tests pass end-to-end (prove/verify
round-trip is the ultimate parity gate).
Added `evaluate_poly_coset_batch_ext3_into_keep` that retains the LDE
device buffer as a GpuLdeExt3 handle. R2
`round_2_compute_composition_polynomial` now threads the handle into
`Round2::gpu_composition_parts` (cfg-gated). R4 deep_composition picks
it up via `deep_composition_ext3_with_dev_parts` which skips the
`num_parts * 3 * lde_size` u64 H2D of the composition-parts LDE.

Measured (mean of 3×15 trials on fib_1M): 11.64 s → 11.61 s. Neutral
within noise because the `number_of_parts > 2` branch that fires the
GPU parts LDE only triggers on a subset of AIRs; most fib_1M tables
have `number_of_parts == 2` and use `decompose_and_extend_d2` (no
handle populated). The plumbing still ships as architecturally clean
infrastructure for AIRs / programs that do hit the > 2 branch.
Each tier-3 item (stream overlap via cudaEvents, warp-level bary
reduction, GPU Montgomery batch inverse) was scoped and rejected:
either the per-call payoff is below the ~0.4 s run-to-run variance
on fib_1M, or the scope is larger than tier-3 intent.

Best candidate (GPU batch inverse) needs a parallel Blelloch scan
over ext3 to beat CPU's 7-way rayon parallelism across tables; the
single-thread variant I prototyped net-regresses. Deferred to tier-1.

Perf sits at tier-2's 1.57× on fib_1M. Branch pinned as the
traceable record of the investigation.
Ran nsys profile over 2 fib_1M proves (1 warmup + 1 measured). Out of
12 s wall-clock, ~2.6 s is CUDA activity (kernels + memcpy); 635 ms
of that is actual kernel compute. The rest (~9.4 s) is CPU work —
trace build, aux trace build, constraint eval, query openings.

Biggest kernel-time consumers per proof:
  ntt_dit_level_batched      243 ms / 1176 invocations (9.5 % CUDA)
  barycentric_ext3_strided    74 ms /   28 invocations (2.9 %)
  keccak_merkle_level         66 ms / 3312 invocations (2.6 %)
  bit_reverse_permute         56 ms /   98 invocations (2.2 %)
  keccak256_leaves_ext3       53 ms /   14 invocations (2.1 %)
  — all others < 50 ms each

Memcpy dominates CUDA activity:
  D2H 1275 ms, 16.3 GB (490 invocations)
  H2D  639 ms, 10.3 GB (1674 invocations)

Implications for the open optimisation list:
- Tile-based NTT layout (previously the tier-3 candidate) rejected —
  even a 2× speedup on all NTT kernels saves <100 ms wall because
  NTT compute is ~320 ms per proof and mostly overlapped.
- GPU Montgomery batch inverse still viable (~50-100 ms wall) but
  marginal.
- Constraint eval interpreter (item 5a, ~0.5-0.8 s wall) remains the
  biggest remaining GPU-side lever.
- Aux-trace-build + trace-build ports (~4.8 s wall combined) are the
  only path to 2× on fib_1M, and they require per-AIR / per-executor
  program logic porting. Multi-day scope.

Profile artefacts in /tmp/profile/fib_1m_nsys.{nsys-rep,sqlite} +
the per-kernel CSV analysis reproduced in PROFILE.md.

Also adds `prover/tests/bench_single.rs` — a single-prove bench used
as the nsys target (bench_gpu's 5-trial loop isn't ideal for
profiling).
Lambda Dev and others added 7 commits May 2, 2026 15:37
Adds three new GPU kernels in `kernels/inverse.cu`:
 - `compute_denoms_ext3`: pointwise kernel that writes
   `denoms[k*n+i] = x[i*stride] - z[k]` for all (k,i), base−ext3.
 - `chunk_{prefix,suffix}_scan_ext3` + totals-scan + apply-offsets +
   combine — the 6-kernel pipeline for parallel Montgomery batch
   inverse over ext3 elements. Chunk-based scan with K=256 threads,
   each thread owns C=ceil(N/K) elements serially; inter-chunk scan
   runs on one thread; final combine computes
   `inv[i] = prefix_incl[i-1] * suffix_incl[i+1] * inv_total`.
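
The final combine step is easiest to see in a sequential form. This sketch is an illustration under simplifying assumptions: base-field elements instead of ext3, a single thread instead of the chunked scans, hypothetical function names, and all inputs nonzero. It computes inclusive prefix and suffix products, performs exactly one field inversion, and then applies the combine formula from the list above.

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime 2^64 - 2^32 + 1

fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn pow(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 { acc = mul(acc, base); }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}
fn inv(a: u64) -> u64 { pow(a, P - 2) } // Fermat inverse, a != 0

/// Montgomery-style batch inverse via prefix and suffix products.
/// Every element of `a` must be nonzero.
fn batch_inverse(a: &[u64]) -> Vec<u64> {
    let n = a.len();
    if n == 0 {
        return Vec::new();
    }
    // prefix[i] = a[0] * ... * a[i]
    let mut prefix = vec![0u64; n];
    let mut acc = 1u64;
    for i in 0..n {
        acc = mul(acc, a[i]);
        prefix[i] = acc;
    }
    // suffix[i] = a[i] * ... * a[n-1]
    let mut suffix = vec![0u64; n];
    acc = 1;
    for i in (0..n).rev() {
        acc = mul(acc, a[i]);
        suffix[i] = acc;
    }
    let inv_total = inv(prefix[n - 1]); // the only inversion
    (0..n)
        .map(|i| {
            // Product of everything except a[i], times 1 / (product of all).
            let left = if i == 0 { 1 } else { prefix[i - 1] };
            let right = if i + 1 == n { 1 } else { suffix[i + 1] };
            mul(mul(left, right), inv_total)
        })
        .collect()
}
```

The last `map` has no cross-element dependency, which is what makes the combine step a pointwise kernel once the two scans are available.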

Wraps these in `math_cuda::inverse` with:
 - `batch_inverse_ext3(a)` — host input, full pipeline.
 - `batch_inverse_ext3_dev(a_dev, n)` — device input, reuses buffer.
 - `compute_and_invert_denoms_ext3(...)` — fused B.1 + B.2 (no
   intermediate D2H/H2D of the denoms buffer).

Wired into `compute_deep_composition_poly_evaluations` (R4 DEEP
prelude): replaces the sequential `for i in 0..n { denoms.push(...) }`
loop + `inplace_batch_inverse` with a single
`compute_and_invert_denoms_ext3` call. Falls back to CPU when
`cuda` is off, types aren't Goldilocks+Ext3, or domain_size < 1024.

R3 OOD prelude was prototyped using the same batch inverse but
regressed wall time by ~200 ms in a 2×15-trial A/B — 21 concurrent
batch-inverse kernel launches from 7 rayon-parallel tables × 3
eval-points contend on the stream pool, negating the CPU savings.
Kept on CPU for now; note left in trace.rs.

Parity tests in `tests/batch_inverse.rs` cover n ∈ {2, 3, 5, 16, 63,
255, 256, 257, 1024, 4096, 8192, 2^18, 2^20} against
`FieldElement::inplace_batch_inverse`. Also gated by the 121 stark
prove+verify round-trip tests.

Benchmark (fib_1M, mean of 5×15 trials):
  before:  11.64 s  (1.57× vs CPU 18.27 s)
  after:   11.25 s  (1.62×, −3.4 %)
…ed_term_column

Adds a complete device-side LogUp term-column pipeline: fingerprint
compute (supporting every Packing variant + OP_LINEAR), parallel
Montgomery batch inverse, and term assembly (every Multiplicity
variant). A bytecode format serialises BusInteraction into a
data-driven kernel input so we don't need per-air kernel variants.

Dispatch happens at the table level via `try_compute_table_term_columns`,
which uploads main_cols once per aux-build and walks all pairs
against the shared device buffer — the per-pair upload version
regressed fib_1M by ~5s from redundant ~240 MB H2Ds.

Current perf on fib_iterative_1M (15-trial mean):
  CPU (default):                            11.17s
  GPU table-batched (threshold=0 env):      11.81s
  GPU per-pair (earlier iteration):         16.06s

Still ~640 ms behind CPU because ~12 tables run build_auxiliary_trace
in parallel and each contends for the GPU. Gated off by default
(LAMBDA_VM_GPU_LOGUP_THRESHOLD=usize::MAX), so no regression to the
default CPU build or the shipping GPU path. Opt-in for experiments.

Parity: all 121 stark prove+verify tests pass with the GPU path
forced on. Verifier and constraints untouched.

See crypto/math-cuda/NOTES_LOGUP.md for detail + follow-up paths
(cross-table batching, fused multi-pair kernel, device-resident trace).
Captures perf data and a pil2-proofman (Zisk) kernel-by-kernel
comparison on top of exp-7-logup-gpu. No code changes.

Scale numbers (fib_iterative, LAMBDA_VM_GPU_LOGUP_THRESHOLD=1048576,
15-trial config — 5 for 1M/2M, 3 for 4M):
  fib_1M : 12.52 s  (1.00×)
  fib_2M : 20.33 s  (1.62×)
  fib_4M : 32.30 s  (2.58×)

Per-row cost drops 35% at 4M vs 1M — the GPU-favored regime; future
work should bench at fib_4M too.

Zisk comparison finds three concrete gaps vs pil2-proofman:
  1. Unified expression/constraint evaluator on device
  2. FRI query phase on device (genMerkleProof + getTreeTracePols)
  3. Device-resident trace with in-place writes (insertTracePol style)

NOTES_SCALE.md lays out the wall-time breakdown; ZISK_COMPARISON.md
maps their kernels against ours. `profiles/scale_bench.log` is the
raw instruments output from this run.
Restructure try_compute_table_term_columns into three phases:
  1. CPU prep (parallel across tables): flatten main_cols, serialize
     bytecode, compute alpha_powers
  2. GPU dispatch (under a global Mutex): upload main_cols, launch
     all pair/single kernels, D2H term columns as Vec<u64>
  3. CPU post (parallel across tables): reassemble FieldElement<E>
     vectors

Rationale: the prover's Pass 1 runs build_auxiliary_trace in a rayon
par_iter over ~12 tables. Without the mutex, each table fires its
own GPU stream, so H2D of each table's ~240 MB main_cols competes
on PCIe and kernel launches fight for SM time. The mutex serializes
only the GPU-bound portion; rayon still overlaps CPU work.
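
The three-phase shape can be sketched with standard-library threads. All names are hypothetical and the "GPU" is a stub that doubles each value; scoped threads stand in for the rayon `par_iter`. The two atomics exist only so the sketch can observe that at most one table is inside the GPU phase at a time.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::thread;

static GPU_LOCK: Mutex<()> = Mutex::new(()); // global GPU dispatch lock
static IN_GPU: AtomicUsize = AtomicUsize::new(0); // tables currently dispatching
static MAX_IN_GPU: AtomicUsize = AtomicUsize::new(0); // high-water mark

/// Stub for upload + kernel launches + download.
fn fake_gpu_dispatch(flat: &[u64]) -> Vec<u64> {
    let now = IN_GPU.fetch_add(1, Ordering::SeqCst) + 1;
    MAX_IN_GPU.fetch_max(now, Ordering::SeqCst);
    let out: Vec<u64> = flat.iter().map(|v| v.wrapping_mul(2)).collect();
    IN_GPU.fetch_sub(1, Ordering::SeqCst);
    out
}

fn process_table(cols: &[Vec<u64>]) -> Vec<Vec<u64>> {
    let rows = cols[0].len();
    assert!(rows > 0);
    // Phase 1, CPU prep: runs concurrently across tables.
    let flat: Vec<u64> = cols.iter().flatten().copied().collect();
    // Phase 2, GPU dispatch: serialized; the guard drops at the end of the block.
    let raw = {
        let _guard = GPU_LOCK.lock().unwrap();
        fake_gpu_dispatch(&flat)
    };
    // Phase 3, CPU post: runs concurrently again.
    raw.chunks(rows).map(|c| c.to_vec()).collect()
}

fn process_all(tables: &[Vec<Vec<u64>>]) -> Vec<Vec<Vec<u64>>> {
    thread::scope(|s| {
        let handles: Vec<_> = tables
            .iter()
            .map(|t| s.spawn(move || process_table(t)))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}
```

Keeping the lock scope tight around phase 2 is the design point: holding it across phase 1 or 3 would serialize CPU work that can otherwise overlap.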

Bench (fib_iterative_1M, 15-trial mean, LAMBDA_VM_GPU_LOGUP_THRESHOLD
=1048576):
  exp-7 baseline (parallel GPU contention):  ~11.00 s
  exp-9 (serialized GPU dispatch):             10.96 s
  Per-trial Aux-trace-build wall:
    exp-7 ~2.66 s
    exp-9 ~1.90 s  (≈700 ms saved)

The aux-build improvement is real per-trial; total speedup is
smaller because Round 1 phases overlap with aux-build in rayon.
All 121 stark prove+verify tests pass with LAMBDA_VM_GPU_LOGUP_
THRESHOLD=0. No default-path regression; mutex only engages when
the GPU LogUp threshold is met.
Detailed plan for eliminating the redundant main-trace H2D between
Phase A (GPU LDE) and Pass 1 (GPU LogUp). Shipping as a doc-only
checkpoint — matches the exp-4 / exp-6 precedent — because the plumbing
touches 5 files and ~600-900 LOC, more than fits in one session.

Summary of the plan:
  1. Preserve the uploaded main_cols inside the fused LDE kernel via a
     device→device copy before iNTT (basically free at ~1 TB/s VRAM).
  2. Bubble a new Arc<DeviceMainCols> handle up through MainTraceCommit.
  3. Cache in logup_gpu.rs keyed by trace pointer, consult from aux-build.

Expected win on fib_1M: 200-300 ms wall (10.96s → ~10.6s), scaling
with trace size (600-800 ms on fib_4M).

See DESIGN_EXP11.md for the full breakdown, including the cache-key
stability caveat and the VRAM budget analysis for fib_4M.
Reconciles the GPU experiment branch with main's CPU optimizations
(#522 direct 2N DEEP, #566 R4 bit-reverse skip, #573 DomainConstants,

4 files needed manual conflict resolution: bin/cli/Cargo.toml,
crypto/stark/src/{lookup,trace,prover}.rs.

Resolution strategy:
- lookup.rs: take main's chunk-parallel CPU body, wrap in compute_cpu()
  closure so branch's table-level CUDA dispatcher falls back to it.
  Per-pair CUDA shortcut at the head of compute_logup_batched_term_column
  preserved; table_name parameter dropped to match main.
- trace.rs: take main's DomainConstants signature and inlined direct-from-LDE
  formula. Branch's gpu_main/gpu_aux device handles preserved on
  LDETraceTable; barycentric GPU fast paths now thread DomainConstants
  fields (dc.points, dc.size_inv, dc.offset_pow_n, dc.offset_pow_n_inv)
  to the unchanged gpu_lde:: function signatures.
- prover.rs:
  R3 OOD: take main (DomainConstants-based, calls trace::get_trace_evaluations_from_lde
    with &dc).
  R4: take main (no iFFT+FFT extension; deep_evals already at 2N points).
    Drops branch's try_r4_deep_poly_lde_gpu call site - kernel left in
    math-cuda for future use, dead-code warning expected.
  compute_deep_composition_poly_evaluations: take main's column-compressed
    Plonky3 path. Drops branch's try_deep_composition_gpu and the GPU
    denom port; kernels remain in math-cuda. Re-port to the new lde_size
    loop is a follow-up if benchmarks warrant.
  R1/R2 fused-pipeline GPU work, FRI commit device-resident, multi_prove
    GPU handle plumbing all preserved unchanged.
- bin/cli/Cargo.toml: union of both feature additions.

Branch kernels obsoleted by this merge (call sites removed, kernels kept
for possible reuse): try_r4_deep_poly_lde_gpu, try_deep_composition_gpu,
try_evaluate_parts_on_lde_gpu, gpu_composition_parts handle.

CPU optimizations on main that overlap GPU work in the branch (re-bench
needed): #561 parallel batch inverse vs branch GPU Montgomery batch
inverse; #548 chunk-parallel LogUp vs branch GPU LogUp port; #522 direct
2N DEEP + stride reads delivers branch's "skip CPU trace-slab extraction"
intent on CPU.

Original branch tip preserved at tag exp-11-archive (b5e11e3f).

cargo check --workspace and cargo check -p {math-cuda, stark --features
cuda, lambda-vm-prover --features cuda} all pass.
Switches the cudarc feature from cuda-13010 to cuda-12080. Reasons:

- cuda-13010 makes cudarc resolve newer CUDA 13.x symbols (e.g.
  cuDevSmResourceSplit) at init time. On systems whose libcuda.so was
  shipped with a slightly older CUDA 13 minor — including this box's
  Driver 580.105.08 — the dlsym fails and the barycentric tests panic.
- math-cuda only references long-stable APIs (cuMemFreeHost, etc.); we
  use no CUDA-13-only symbols. Targeting cuda-12080 with dynamic-loading
  enabled (already on) makes the binary load only CUDA 12.8.0+ symbols,
  which are present on every CUDA 12.8+ and 13.x driver.

Trade-off: any future cudarc API surface that requires CUDA 13 is
gated out. We don't currently use any.
@github-actions

github-actions Bot commented May 2, 2026

Codex Code Review

Findings

  1. High (Security, memory safety) - crypto/stark/src/gpu_lde.rs:131
    The CUDA fast paths rely on debug_assert!(col.capacity() >= lde_size) before unsafe { col.set_len(lde_size) }. In release builds this check disappears, so any caller passing normal Vecs with insufficient capacity can trigger out-of-bounds writes/UB through safe public prover APIs. This pattern repeats in the same file. Use a real runtime check and either reserve/resize safely or fall back to CPU.

  2. High (Build bug) - Cargo.toml:8, crypto/math-cuda/build.rs:32
    crypto/math-cuda is added as a normal workspace member, but its build script unconditionally invokes nvcc. That makes ordinary workspace builds/checks fail on non-CUDA machines even though stark gates the dependency behind the cuda feature. Keep it out of default workspace builds, gate the build script, or make the crate opt-in for CUDA-only CI/jobs.

  3. Medium (Build bug) - crypto/stark/src/prover.rs:193, crypto/stark/src/prover.rs:740
    With both cuda and debug-checks enabled, Lde has gpu_main/gpu_aux fields, but reconstruct_round1 constructs Lde { main, aux } without them. That feature combination will not compile. Add gpu_main: None, gpu_aux: None under #[cfg(feature = "cuda")] there.
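
    For finding 1, one possible shape of the runtime check is sketched below. It is an assumption-based illustration, not the code in `gpu_lde.rs`: the function name is hypothetical and the column is modelled as a plain `Vec<u64>`. The capacity test runs in release builds too, and the length is adjusted with safe `resize` instead of `unsafe { set_len }`.

```rust
/// Returns the output slice for the GPU path, or `None` when the caller's
/// buffer is too small and the CPU fallback should be used instead.
fn gpu_output_slice(col: &mut Vec<u64>, lde_size: usize) -> Option<&mut [u64]> {
    // Real runtime check: unlike debug_assert!, this survives release builds.
    if col.capacity() < lde_size {
        return None;
    }
    // Within capacity, resize never reallocates; new slots are zeroed,
    // so no uninitialised memory is ever exposed.
    col.resize(lde_size, 0);
    Some(&mut col[..])
}
```

    The zero fill costs one extra pass over the buffer; if that matters, reserving inside the helper and falling back only on allocation failure would be an alternative.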

  Removes 74k lines of checkpoint patch dumps (artifacts/checkpoint-*/)
  and bench logs (profiles/*.log) that were committed during the GPU
  experiment journey. They duplicate the actual code changes as .patch
  files and bloat the PR with no review value. Adds .gitignore rules so
  they don't leak in again.