diff --git a/.github/workflows/pr_main.yaml b/.github/workflows/pr_main.yaml index 518cb7aa6..68fae4fb0 100644 --- a/.github/workflows/pr_main.yaml +++ b/.github/workflows/pr_main.yaml @@ -105,16 +105,18 @@ jobs: test: name: Test if: always() - needs: [test-executor, test-prover] + needs: [test-executor, test-prover, test-disk-spill] runs-on: ubuntu-latest steps: - name: Check results run: | executor="${{ needs.test-executor.result }}" prover="${{ needs.test-prover.result }}" + disk_spill="${{ needs.test-disk-spill.result }}" echo "test-executor: $executor" echo "test-prover: $prover" + echo "test-disk-spill: $disk_spill" # Allow "success" or "skipped" (skipped on merge queue pushes) if [[ "$executor" != "success" && "$executor" != "skipped" ]]; then @@ -123,6 +125,68 @@ jobs: if [[ "$prover" != "success" && "$prover" != "skipped" ]]; then exit 1 fi + if [[ "$disk_spill" != "success" && "$disk_spill" != "skipped" ]]; then + exit 1 + fi + + test-disk-spill: + name: Disk-spill tests + runs-on: ubuntu-latest + if: github.event_name != 'push' || github.actor != 'github-merge-queue[bot]' + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Setup Rust Environment + uses: ./.github/actions/setup-rust + + - name: Cache cargo build artifacts + uses: Swatinem/rust-cache@v2 + with: + shared-key: "lambda-vm-disk-spill" + cache-all-crates: "true" + + - name: Cache compiled ASM ELF artifacts + id: cache-asm-elfs + uses: actions/cache@v4 + with: + path: executor/program_artifacts/asm + key: asm-elf-artifacts-${{ hashFiles('executor/programs/asm/**') }} + + - name: Install clang and lld + if: steps.cache-asm-elfs.outputs.cache-hit != 'true' + run: sudo apt-get update && sudo apt-get install -y clang lld + + - name: Compile ASM programs to ELF + if: steps.cache-asm-elfs.outputs.cache-hit != 'true' + run: | + make compile-programs-asm + + - name: Cache compiled Rust ELF artifacts and build cache + id: cache-rust-elfs + uses: actions/cache@v4 + with: + path: | 
+ executor/program_artifacts/rust + executor/shared_target + key: rust-elf-artifacts-${{ hashFiles('executor/programs/rust/**', 'executor/programs/riscv64im-lambda-vm-elf.json', 'syscalls/**', 'Makefile') }} + restore-keys: | + rust-elf-artifacts- + + - name: Compile Rust programs to ELF + if: steps.cache-rust-elfs.outputs.cache-hit != 'true' + run: | + make compile-programs-rust + + - name: Run stark disk-spill tests + run: | + cargo test --release -p stark --features disk-spill disk_spill + + - name: Run prover disk-spill tests + env: + FORCE_DISK_SPILL: "1" + run: | + cargo test --release -p lambda-vm-prover --features disk-spill -- disk_spill count_table_lengths build-prover-tests: name: Build prover tests diff --git a/Cargo.lock b/Cargo.lock index 0f01bf090..70b4071e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -567,6 +567,7 @@ version = "0.1.0" dependencies = [ "bincode", "clap 4.5.53", + "env_logger", "executor", "lambda-vm-prover", "stark", @@ -771,14 +772,18 @@ checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" name = "crypto" version = "0.1.0" dependencies = [ + "bincode", "digest", + "libc", "math", + "memmap2", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", "serde", "sha2", "sha3", + "tempfile", ] [[package]] @@ -1623,7 +1628,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core", + "windows-core 0.62.2", ] [[package]] @@ -1947,14 +1952,19 @@ dependencies = [ name = "lambda-vm-prover" version = "0.1.0" dependencies = [ + "bincode", "criterion 0.5.1", "crypto", "env_logger", "executor", + "log", "math", "rayon", "serde", "stark", + "sysinfo", + "tikv-jemalloc-ctl", + "tikv-jemallocator", "tiny-keccak", ] @@ -2145,6 +2155,15 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "memmap2" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" +dependencies = [ + "libc", +] + [[package]] name = "munge" version = "0.4.7" @@ -2165,6 +2184,15 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "ntapi" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -3200,18 +3228,22 @@ checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" name = "stark" version = "0.1.0" dependencies = [ + "bincode", "criterion 0.4.0", "crypto", "env_logger", "itertools 0.11.0", + "libc", "log", "math", "math-cuda", + "memmap2", "rayon", "serde", "serde-wasm-bindgen", "serde_cbor", "sha3", + "tempfile", "test-log", "thiserror 1.0.69", "wasm-bindgen", @@ -3290,6 +3322,19 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "sysinfo" +version = "0.31.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be" +dependencies = [ + "core-foundation-sys", + "libc", + "memchr", + "ntapi", + "windows", +] + [[package]] name = "tap" version = "1.0.1" @@ -3826,19 +3871,52 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +dependencies = [ + "windows-core 0.57.0", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +dependencies = [ + "windows-implement 0.57.0", + "windows-interface 
0.57.0", + "windows-result 0.1.2", + "windows-targets", +] + [[package]] name = "windows-core" version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.60.2", + "windows-interface 0.59.3", "windows-link", - "windows-result", + "windows-result 0.4.1", "windows-strings", ] +[[package]] +name = "windows-implement" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -3850,6 +3928,17 @@ dependencies = [ "syn 2.0.111", ] +[[package]] +name = "windows-interface" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.111", +] + [[package]] name = "windows-interface" version = "0.59.3" @@ -3867,6 +3956,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-result" version = "0.4.1" @@ -3894,6 +3992,70 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + 
"windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.7.14" diff --git a/Makefile b/Makefile index aadd3d961..fcde68e9c 100644 --- a/Makefile +++ 
b/Makefile @@ -1,7 +1,7 @@ .PHONY: deps deps-linux deps-macos prepare-test-data compile-programs-asm compile-programs-rust compile-bench \ compile-programs clean-asm clean-rust clean-bench clean-shared clean test test-asm test-no-compile \ test-asm-no-compile test-rust test-rust-no-compile test-executor flamegraph-prover \ -test-fast test-prover test-prover-all test-math-cuda bench-math-cuda build check clippy fmt lint +test-fast test-prover test-prover-all test-disk-spill test-math-cuda bench-math-cuda build check clippy fmt lint UNAME := $(shell uname) @@ -185,6 +185,11 @@ test-prover-all: test-prover-debug: cargo test -p lambda-vm-prover --features debug-checks -- --nocapture +# Disk-spill tests (stark + prover). FORCE_DISK_SPILL is required by the prover tests. +test-disk-spill: + cargo test --release -p stark --features disk-spill disk_spill + FORCE_DISK_SPILL=1 cargo test --release -p lambda-vm-prover --features disk-spill -- disk_spill count_table_lengths + # math-cuda parity tests (requires NVIDIA GPU + nvcc) test-math-cuda: cargo test -p math-cuda --release @@ -208,6 +213,7 @@ check: clippy: cargo clippy --workspace --all-targets -- -D warnings -A clippy::op_ref cargo clippy --workspace --all-targets --no-default-features --features lambda-vm-prover/debug-checks -- -D warnings -A clippy::op_ref + cargo clippy --workspace --all-targets --features lambda-vm-prover/disk-spill -- -D warnings -A clippy::op_ref fmt: cargo fmt --all @@ -217,6 +223,7 @@ lint: cargo fmt --check --all cargo clippy --workspace --all-targets -- -D warnings -A clippy::op_ref cargo clippy --workspace --all-targets --no-default-features --features lambda-vm-prover/debug-checks -- -D warnings -A clippy::op_ref + cargo clippy --workspace --all-targets --features lambda-vm-prover/disk-spill -- -D warnings -A clippy::op_ref flamegraph-prover: cd crypto/stark && samply record cargo bench --bench profile_prover --features parallel diff --git a/bin/cli/Cargo.toml b/bin/cli/Cargo.toml index 
fdc8eab8c..e6f16582f 100644 --- a/bin/cli/Cargo.toml +++ b/bin/cli/Cargo.toml @@ -11,7 +11,9 @@ clap = { version = "4.3.10", features = ["derive"] } bincode = "1" tikv-jemallocator = "0.6" tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } +env_logger = "0.11" [features] jemalloc-stats = ["dep:tikv-jemalloc-ctl"] +disk-spill = ["prover/disk-spill"] instruments = ["prover/instruments", "stark/instruments"] diff --git a/bin/cli/src/main.rs b/bin/cli/src/main.rs index f166e751d..bdcea9518 100644 --- a/bin/cli/src/main.rs +++ b/bin/cli/src/main.rs @@ -174,6 +174,7 @@ enum Commands { } fn main() -> ExitCode { + env_logger::init(); let cli = Cli::parse(); match cli.command { diff --git a/crypto/crypto/Cargo.toml b/crypto/crypto/Cargo.toml index ff91bae63..89e314c25 100644 --- a/crypto/crypto/Cargo.toml +++ b/crypto/crypto/Cargo.toml @@ -18,12 +18,16 @@ serde = { version = "1.0", default-features = false, features = [ rayon = { version = "1.8.0", optional = true } rand = { version = "0.8.5", default-features = false } rand_chacha = { version = "0.3.1", default-features = false } +memmap2 = { version = "0.9", optional = true } +tempfile = { version = "3", optional = true } +libc = { version = "0.2", optional = true } [dev-dependencies] math = { path = "../math", features = ["test-utils"] } rand = "0.8.5" rand_chacha = "0.3.1" sha2 = { version = "0.10", default-features = false } +bincode = "1" [features] default = ["asm", "std"] @@ -31,4 +35,5 @@ asm = ["sha3/asm"] std = ["math/std", "sha3/std", "serde?/std"] serde = ["dep:serde"] parallel = ["dep:rayon"] +disk-spill = ["std", "dep:memmap2", "dep:tempfile", "dep:libc"] alloc = [] \ No newline at end of file diff --git a/crypto/crypto/src/lib.rs b/crypto/crypto/src/lib.rs index 20462a407..d7a273d62 100644 --- a/crypto/crypto/src/lib.rs +++ b/crypto/crypto/src/lib.rs @@ -1,11 +1,17 @@ #![allow(clippy::op_ref)] #![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(all(target_arch = "wasm32", feature 
= "disk-spill"))] +compile_error!("the `disk-spill` feature requires memmap2, which does not compile on wasm32"); + #[macro_use] extern crate alloc; pub mod fiat_shamir; pub mod hash; pub mod merkle_tree; +#[cfg(feature = "disk-spill")] +pub mod mmap_util; #[cfg(test)] pub mod tests; diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index b702a846e..4ea0e5411 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -4,6 +4,8 @@ use crate::merkle_tree::proof::BatchProof; use super::{proof::Proof, traits::IsMerkleTreeBackend, utils::*}; use alloc::{collections::BTreeSet, vec::Vec}; +#[cfg(feature = "disk-spill")] +use math::spill_safe::SpillSafe; #[derive(Debug)] pub enum Error { @@ -22,6 +24,16 @@ impl Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} +/// File-backed mmap storage for Merkle tree nodes. +/// +/// After `spill_nodes_to_disk()`, the in-memory node vector is freed and +/// node access goes through this mmap instead. +#[cfg(feature = "disk-spill")] +pub(crate) struct MmapNodeBacking { + mmap: memmap2::Mmap, + node_count: usize, +} + /// The struct for the Merkle tree, consisting of the root and the nodes. /// A typical tree would look like this /// root @@ -31,11 +43,68 @@ impl std::error::Error for Error {} /// leaf 1 leaf 2 leaf 3 leaf 4 /// The bottom leafs correspond to the hashes of the elements, while each upper /// layer contains the hash of the concatenation of the daughter nodes. 
-#[derive(Clone)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(not(feature = "disk-spill"), derive(Clone))] +#[cfg_attr( + all(feature = "serde", not(feature = "disk-spill")), + derive(serde::Serialize, serde::Deserialize) +)] +#[cfg_attr( + all(feature = "serde", feature = "disk-spill"), + derive(serde::Deserialize) +)] pub struct MerkleTree { pub root: B::Node, nodes: Vec, + #[cfg(feature = "disk-spill")] + #[cfg_attr(feature = "serde", serde(skip))] + mmap_backing: Option, +} + +// `mmap_backing` is `#[serde(skip)]` and `spill_nodes_to_disk` empties `nodes`, +// so the default derive would emit `{root, nodes: []}` and lose the tree. +// +// Output matches the non-disk-spill derive byte-for-byte, so a proof from either +// storage mode deserializes with the same `Deserialize` impl. +#[cfg(all(feature = "serde", feature = "disk-spill"))] +impl serde::Serialize for MerkleTree +where + B::Node: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let mut s = serializer.serialize_struct("MerkleTree", 2)?; + s.serialize_field("root", &self.root)?; + if self.mmap_backing.is_some() { + s.serialize_field("nodes", &MmapNodesSeq(self))?; + } else { + s.serialize_field("nodes", &self.nodes)?; + } + s.end() + } +} + +#[cfg(all(feature = "serde", feature = "disk-spill"))] +struct MmapNodesSeq<'a, B: IsMerkleTreeBackend>(&'a MerkleTree); + +#[cfg(all(feature = "serde", feature = "disk-spill"))] +impl serde::Serialize for MmapNodesSeq<'_, B> +where + B::Node: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeSeq; + let backing = self + .0 + .mmap_backing + .as_ref() + .expect("MmapNodesSeq is only constructed when mmap_backing is Some"); + let n = backing.node_count; + let mut seq = serializer.serialize_seq(Some(n))?; + for i in 0..n { + seq.serialize_element(self.0.node_get(i).expect("index in bounds"))?; + } + seq.end() + } } const 
ROOT: usize = 0; @@ -78,9 +147,44 @@ where Some(MerkleTree { root: nodes[ROOT].clone(), nodes, + #[cfg(feature = "disk-spill")] + mmap_backing: None, }) } + /// Total number of nodes in the tree (inner + leaves). + fn node_count(&self) -> usize { + #[cfg(feature = "disk-spill")] + if let Some(ref backing) = self.mmap_backing { + return backing.node_count; + } + self.nodes.len() + } + + /// Access a node by index, returning a reference. + /// + /// Returns `None` if `idx` is out of bounds. + fn node_get(&self, idx: usize) -> Option<&B::Node> { + #[cfg(feature = "disk-spill")] + if let Some(ref backing) = self.mmap_backing { + if idx < backing.node_count { + // SAFETY: spill_nodes_to_disk is the only function that populates + // mmap_backing, and its where-clause requires B::Node: SpillSafe. + // Reaching this branch means that bound was checked at construction, + // so B::Node carries no padding and every bit pattern is valid. + // + // Alignment: the mmap base is page-aligned (>= 4096), spill_slice_to_mmap + // asserts align_of::() <= 4096, and Rust guarantees + // size_of:: is a multiple of align_of::, so every + // offset idx * node_size lands on an aligned address. + let ptr = unsafe { backing.mmap.as_ptr().add(idx * size_of::()) }; + return Some(unsafe { &*(ptr as *const B::Node) }); + } + return None; + } + self.nodes.get(idx) + } + /// Read-only access to the full node buffer in standard layout: /// `nodes[0..leaves_len - 1]` are inner nodes (root at index 0) and /// `nodes[leaves_len - 1..]` are the leaves. 
@@ -92,7 +196,7 @@ where /// For example, give me an inclusion proof for the 3rd element in the /// Merkle tree pub fn get_proof_by_pos(&self, pos: usize) -> Option> { - let pos = pos + self.nodes.len() / 2; + let pos = pos + self.node_count() / 2; let Ok(merkle_path) = self.build_merkle_path(pos) else { return None; }; @@ -108,12 +212,12 @@ where /// Returns the Merkle path for the element/s for the leaf at position pos fn build_merkle_path(&self, pos: usize) -> Result, Error> { // Pre-allocate based on tree depth (log2 of tree size) - let tree_depth = (self.nodes.len() + 1).ilog2() as usize; + let tree_depth = (self.node_count() + 1).ilog2() as usize; let mut merkle_path = Vec::with_capacity(tree_depth); let mut pos = pos; while pos != ROOT { - let Some(node) = self.nodes.get(sibling_index(pos)) else { + let Some(node) = self.node_get(sibling_index(pos)) else { // out of bounds, exit returning the current merkle_path return Err(Error::OutOfBounds); }; @@ -148,7 +252,7 @@ where return Err(Error::EmptyPositionList); } - let num_leaves = (self.nodes.len() + 1).div_ceil(2); + let num_leaves = (self.node_count() + 1).div_ceil(2); // Validate all positions are within bounds for &pos in pos_list { @@ -161,7 +265,7 @@ where // of the leaves. let leaf_positions = pos_list .iter() - .map(|pos| pos + self.nodes.len() / 2) + .map(|pos| pos + self.node_count() / 2) .collect::>(); // We get the positions of the nodes for the batch proof. let batch_auth_path_positions = self.get_batch_auth_path_positions(&leaf_positions); @@ -169,7 +273,11 @@ where // We get the nodes for the batch proof. 
let batch_auth_path_nodes = batch_auth_path_positions .iter() - .map(|pos| self.nodes[*pos].clone()) + .map(|pos| { + self.node_get(*pos) + .expect("batch auth path position in bounds") + .clone() + }) .collect(); Ok(BatchProof { @@ -195,7 +303,7 @@ where let mut obtainable: BTreeSet = leaf_positions.iter().cloned().collect(); // Number of levels in tree - let num_levels = (self.nodes.len() + 1).ilog2(); + let num_levels = (self.node_count() + 1).ilog2(); // Iter lefevel-by-level from leaves to root. for _ in 0..num_levels - 1 { @@ -224,4 +332,59 @@ where // This makes the proof ordered from bottom (nodes closer to leaves) to top (nodes loser to root). auth_path_set.into_iter().rev().collect() } + + /// Spill the node vector to a temp-file-backed mmap and free the heap + /// allocation. Node access methods read from the mmap after this call. + #[cfg(feature = "disk-spill")] + pub fn spill_nodes_to_disk(&mut self) -> std::io::Result<()> + where + B::Node: SpillSafe, + { + if self.nodes.is_empty() || self.mmap_backing.is_some() { + return Ok(()); + } + + let node_count = self.nodes.len(); + let mmap = crate::mmap_util::spill_slice_to_mmap(&self.nodes)?; + self.nodes = Vec::new(); + self.mmap_backing = Some(MmapNodeBacking { mmap, node_count }); + + Ok(()) + } +} + +#[cfg(all(test, feature = "serde", feature = "disk-spill"))] +mod disk_spill_serde_tests { + use super::*; + use crate::merkle_tree::backends::field_element::FieldElementBackend; + use math::field::{element::FieldElement, goldilocks::GoldilocksField}; + use sha3::Keccak256; + + type F = GoldilocksField; + type FE = FieldElement; + type Backend = FieldElementBackend; + + /// Serializing a spilled MerkleTree must produce identical bytes to + /// serializing the same tree before spilling, and round-trip back to an + /// equal tree. 
+ #[test] + fn test_serialize_spilled_merkle_tree_matches_unspilled() { + let values: Vec = (1..17).map(FE::from).collect(); + let unspilled = MerkleTree::::build(&values).expect("build merkle tree"); + let unspilled_bytes = bincode::serialize(&unspilled).expect("serialize unspilled"); + + let mut spilled = MerkleTree::::build(&values).expect("build merkle tree"); + spilled.spill_nodes_to_disk().expect("spill_nodes_to_disk"); + let spilled_bytes = bincode::serialize(&spilled).expect("serialize spilled"); + + assert_eq!( + spilled_bytes, unspilled_bytes, + "spilled and unspilled trees must serialize to identical bytes" + ); + + let restored: MerkleTree = + bincode::deserialize(&spilled_bytes).expect("deserialize spilled bytes"); + assert!(restored.mmap_backing.is_none()); + assert_eq!(restored.root, unspilled.root); + } } diff --git a/crypto/crypto/src/mmap_util.rs b/crypto/crypto/src/mmap_util.rs new file mode 100644 index 000000000..c5600c5d9 --- /dev/null +++ b/crypto/crypto/src/mmap_util.rs @@ -0,0 +1,71 @@ +use core::slice; +use math::spill_safe::SpillSafe; +use memmap2::{Mmap, MmapOptions}; +use std::fs::File; +use std::io::{Error, ErrorKind, Result}; + +/// Mmap a fresh temp file, copy `slice` into the mapping, downgrade to +/// read-only, and return it. +/// +/// Alignment: the mmap base is page-aligned (>= 4096), this function +/// asserts `align_of::() <= 4096`, and Rust guarantees `size_of::()` +/// is a multiple of `align_of::()`, so every element offset is aligned. 
+pub fn spill_slice_to_mmap(slice: &[T]) -> Result { + const { + assert!( + align_of::() <= 4096, + "T alignment must fit within mmap page alignment" + ) + } + + let elem_size = size_of::(); + let total_bytes = (slice.len() as u64) + .checked_mul(elem_size as u64) + .ok_or_else(|| { + Error::new( + ErrorKind::InvalidInput, + "spill_slice_to_mmap: byte count overflows u64", + ) + })?; + + let file = tempfile::tempfile()?; + reserve_file_blocks(&file, total_bytes)?; + + // SAFETY: tempfile() creates an anonymous file with no filesystem path, + // so no other process can open or modify it. + let mut mmap_mut = unsafe { MmapOptions::new().map_mut(&file)? }; + // SAFETY: SpillSafe's safety contract requires no padding on T, so + // `slice`'s bytes are initialized and reading them as &[u8] is sound. + let bytes: &[u8] = + unsafe { slice::from_raw_parts(slice.as_ptr() as *const u8, size_of_val(slice)) }; + mmap_mut.copy_from_slice(bytes); + mmap_mut.make_read_only() +} + +/// Reserve disk blocks up front so this call fails on a full disk. +/// Without reservation, the kernel sends SIGBUS during the later mmap write. +/// +/// Linux only, using `posix_fallocate`. On other platforms we only call +/// `set_len` and skip reservation, so the kernel can still send SIGBUS if +/// the disk fills mid-write. +/// +/// `/tmp` is often tmpfs (RAM-backed) on systemd-default distros; set +/// `TMPDIR` to a disk-backed path so spill files actually live on disk. 
+fn reserve_file_blocks(file: &File, total_bytes: u64) -> Result<()> { + file.set_len(total_bytes)?; + #[cfg(target_os = "linux")] + { + use std::os::unix::io::AsRawFd; + let len = i64::try_from(total_bytes).map_err(|_| { + Error::new( + ErrorKind::InvalidInput, + "spill file too large for posix_fallocate", + ) + })?; + let ret = unsafe { libc::posix_fallocate(file.as_raw_fd(), 0, len) }; + if ret != 0 { + return Err(Error::from_raw_os_error(ret)); + } + } + Ok(()) +} diff --git a/crypto/math/src/field/element.rs b/crypto/math/src/field/element.rs index a94886763..0eb0aef96 100644 --- a/crypto/math/src/field/element.rs +++ b/crypto/math/src/field/element.rs @@ -39,7 +39,13 @@ use serde::ser::{Serialize, SerializeStruct, Serializer}; use super::traits::{IsPrimeField, IsSubFieldOf, LegendreSymbol}; /// A field element with operations algorithms defined in `F` +/// +/// `#[repr(transparent)]` makes `FieldElement` byte-identical to +/// `F::BaseType`, which [`SpillSafe`](crate::spill_safe::SpillSafe) +/// requires. Changing the `repr` or adding fields breaks this and +/// is UB in any function that requires `T: SpillSafe`. #[allow(clippy::derived_hash_with_manual_eq)] +#[repr(transparent)] #[derive(Debug, Clone, Hash, Copy)] pub struct FieldElement { value: F::BaseType, diff --git a/crypto/math/src/lib.rs b/crypto/math/src/lib.rs index 2f2f1fccb..9d5e6dd97 100644 --- a/crypto/math/src/lib.rs +++ b/crypto/math/src/lib.rs @@ -6,6 +6,7 @@ extern crate alloc; pub mod errors; pub mod field; pub mod helpers; +pub mod spill_safe; pub mod traits; pub mod unsigned_integer; diff --git a/crypto/math/src/spill_safe.rs b/crypto/math/src/spill_safe.rs new file mode 100644 index 000000000..7bcbdf103 --- /dev/null +++ b/crypto/math/src/spill_safe.rs @@ -0,0 +1,34 @@ +//! Marker trait for types whose in-memory bytes can be reinterpreted as the +//! same type without UB: no padding, every bit pattern valid, no indirection. +//! +//! 
Stricter than `Copy`, which permits types with restricted bit patterns +//! (e.g. `bool`, `NonZeroU32`). +//! +//! `unsafe impl` puts the layout invariants on the implementer. The +//! compiler does not check. + +use crate::field::{element::FieldElement, traits::IsField}; + +/// # Safety +/// Implementer asserts `Self`'s memory representation contains no padding, +/// every bit pattern is a valid value of `Self`, and `Self` carries no +/// indirection (heap pointers, references, etc.). Adding this `unsafe impl` +/// for a type that violates these invariants is UB at any byte cast. +pub unsafe trait SpillSafe: Copy + 'static {} + +unsafe impl SpillSafe for u8 {} +unsafe impl SpillSafe for u16 {} +unsafe impl SpillSafe for u32 {} +unsafe impl SpillSafe for u64 {} +unsafe impl SpillSafe for u128 {} +unsafe impl SpillSafe for i8 {} +unsafe impl SpillSafe for i16 {} +unsafe impl SpillSafe for i32 {} +unsafe impl SpillSafe for i64 {} +unsafe impl SpillSafe for i128 {} + +unsafe impl SpillSafe for [T; N] {} + +// `FieldElement` is `#[repr(transparent)]` over `F::BaseType`, so its +// layout matches the base type's exactly. SpillSafe propagates through. +unsafe impl SpillSafe for FieldElement where F::BaseType: SpillSafe {} diff --git a/crypto/stark/Cargo.toml b/crypto/stark/Cargo.toml index 4d1f2cbca..b342fbdc0 100644 --- a/crypto/stark/Cargo.toml +++ b/crypto/stark/Cargo.toml @@ -22,6 +22,9 @@ itertools = "0.11.0" # Parallelization crates rayon = { version = "1.8.0", optional = true } +memmap2 = { version = "0.9", optional = true } +tempfile = { version = "3", optional = true } +libc = { version = "0.2", optional = true } # GPU backend for trace LDE — only linked when `cuda` is enabled. 
math-cuda = { path = "../math-cuda", optional = true } @@ -35,6 +38,7 @@ serde_cbor = { version = "0.11.1" } criterion = { version = "0.4", default-features = false } env_logger = "*" test-log = { version = "0.2.11", features = ["log"] } +bincode = "1" [features] test-utils = [] @@ -44,6 +48,7 @@ debug-checks = [] # Enables v parallel = ["dep:rayon", "crypto/parallel"] cuda = ["dep:math-cuda"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] +disk-spill = ["dep:memmap2", "dep:tempfile", "dep:libc", "crypto/disk-spill"] [package.metadata.wasm-pack.profile.dev] diff --git a/crypto/stark/src/fri/fri_commitment.rs b/crypto/stark/src/fri/fri_commitment.rs index 7eb530452..b0b3188b2 100644 --- a/crypto/stark/src/fri/fri_commitment.rs +++ b/crypto/stark/src/fri/fri_commitment.rs @@ -4,7 +4,7 @@ use math::{ traits::AsBytes, }; -#[derive(Clone)] +#[cfg_attr(not(feature = "disk-spill"), derive(Clone))] pub struct FriLayer where F: IsField, diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index 09ca16ed4..acc8420f4 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -1,3 +1,8 @@ +// `StorageMode::Disk` uses `memmap2`, which does not build on wasm32. +// Fail at the crate root rather than as a transitive memmap2 error. 
+#[cfg(all(target_arch = "wasm32", feature = "disk-spill"))] +compile_error!("the `disk-spill` feature requires memmap2, which does not compile on wasm32"); + #[cfg(feature = "debug-checks")] pub mod bus_debug; pub mod constraints; @@ -14,11 +19,15 @@ pub mod instruments; pub mod lookup; pub mod proof; pub mod prover; +#[cfg(feature = "disk-spill")] +pub mod storage_mode; pub mod table; pub mod trace; pub mod traits; pub mod verifier; +#[cfg(test)] +pub mod test_utils; #[cfg(test)] pub mod tests; diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 8a4577360..68f50ea53 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -10,6 +10,7 @@ use math::fft::errors::FFTError; use log::info; use math::field::traits::{IsField, IsSubFieldOf}; +use math::spill_safe::SpillSafe; use math::traits::{AsBytes, ByteConversion}; use math::{ field::{element::FieldElement, traits::IsFFTField}, @@ -28,6 +29,8 @@ use crate::domain::new_domain; use crate::fri; use crate::lookup::LOGUP_NUM_CHALLENGES; use crate::proof::stark::{DeepPolynomialOpenings, PolynomialOpenings}; +#[cfg(feature = "disk-spill")] +use crate::storage_mode::StorageMode; use crate::table::Table; use crate::trace::LDETraceTable; @@ -100,6 +103,10 @@ where pub enum ProvingError { WrongParameter(String), EmptyCommitment, + /// I/O failure while spilling prover state (traces, LDE, Merkle trees) to disk: + /// out of disk space, fd exhaustion, or mmap failure. + #[cfg(feature = "disk-spill")] + DiskSpill(String), } /// A container for the intermediate results of the commitments to a trace table, main or auxiliary in case of RAP, @@ -296,7 +303,7 @@ impl LdeTwiddles { /// Number of tables to process concurrently in `multi_prove`. /// Default: num_cores / 3 (benchmarked optimal on both M3 Pro and EPYC 9454P). /// Override with `TABLE_PARALLELISM` env var. 
-fn table_parallelism() -> usize { +pub fn table_parallelism() -> usize { #[cfg(feature = "parallel")] { std::env::var("TABLE_PARALLELISM") @@ -605,6 +612,7 @@ pub trait IsStarkProver< trace: &TraceTable, domain: &Domain, twiddles: &LdeTwiddles, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result< ( BatchedMerkleTree, @@ -622,6 +630,10 @@ pub trait IsStarkProver< { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_main(lde_size); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + trace.main_table.advise_drop_cache(); + } #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::(&mut columns, domain, twiddles); @@ -630,11 +642,18 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let t_sub = Instant::now(); - let (tree, root) = + #[allow(unused_mut)] + let (mut tree, root) = Self::commit_columns_bit_reversed(&columns).ok_or(ProvingError::EmptyCommitment)?; #[cfg(feature = "instruments")] crate::instruments::accum_r1_main(main_lde_dur, t_sub.elapsed()); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + tree.spill_nodes_to_disk() + .map_err(|e| ProvingError::DiskSpill(format!("main Merkle tree: {e}")))?; + } + Ok((tree, root, None, None, 0, columns)) } @@ -646,6 +665,7 @@ pub trait IsStarkProver< precomputed_commitment: Commitment, num_precomputed_cols: usize, twiddles: &LdeTwiddles, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result< ( BatchedMerkleTree, @@ -663,6 +683,10 @@ pub trait IsStarkProver< { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_main(lde_size); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + trace.main_table.advise_drop_cache(); + } #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::(&mut columns, domain, twiddles); @@ -671,11 
+695,13 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let t_sub = Instant::now(); - let (precomputed_tree, precomputed_root) = + #[allow(unused_mut)] + let (mut precomputed_tree, precomputed_root) = Self::commit_columns_bit_reversed(&columns[..num_precomputed_cols]) .ok_or(ProvingError::EmptyCommitment)?; - let (mult_tree, mult_root) = + #[allow(unused_mut)] + let (mut mult_tree, mult_root) = Self::commit_columns_bit_reversed(&columns[num_precomputed_cols..]) .ok_or(ProvingError::EmptyCommitment)?; #[cfg(feature = "instruments")] @@ -686,6 +712,16 @@ pub trait IsStarkProver< "Prover's precomputed commitment doesn't match hardcoded AIR commitment" ); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + precomputed_tree + .spill_nodes_to_disk() + .map_err(|e| ProvingError::DiskSpill(format!("precomputed Merkle tree: {e}")))?; + mult_tree + .spill_nodes_to_disk() + .map_err(|e| ProvingError::DiskSpill(format!("mult Merkle tree: {e}")))?; + } + Ok(( mult_tree, mult_root, @@ -1570,11 +1606,16 @@ pub trait IsStarkProver< fn multi_prove( mut air_trace_pairs: Vec>, transcript: &mut (impl IsStarkTranscript + Clone + Send), + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, PI: Send + Sync + Clone, + Field: Copy + 'static, + FieldExtension: Copy + 'static, + ::BaseType: SpillSafe, + ::BaseType: SpillSafe, { info!("Started proof generation..."); @@ -1638,6 +1679,21 @@ pub trait IsStarkProver< let k = table_parallelism().min(num_airs).max(1); + // Spill main traces to mmap before Round 1 LDE. 
+ #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + #[cfg(feature = "parallel")] + let spill_iter = air_trace_pairs.par_iter_mut(); + #[cfg(not(feature = "parallel"))] + let mut spill_iter = air_trace_pairs.iter_mut(); + spill_iter.try_for_each(|(_, trace, _)| { + trace + .main_table + .spill_to_disk() + .map_err(|e| ProvingError::DiskSpill(format!("early main: {e}"))) + })?; + } + #[cfg(feature = "instruments")] let prepass_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -1679,9 +1735,17 @@ pub trait IsStarkProver< air.precomputed_commitment(), air.num_precomputed_columns(), twiddles, + #[cfg(feature = "disk-spill")] + storage_mode, ) } else { - Self::commit_main_trace(*trace, domain, twiddles) + Self::commit_main_trace( + *trace, + domain, + twiddles, + #[cfg(feature = "disk-spill")] + storage_mode, + ) } }) .collect(); @@ -1754,6 +1818,23 @@ pub trait IsStarkProver< }) .collect(); + // Spill all aux trace tables to mmap before any Round 1 aux LDE work. 
+ #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + #[cfg(feature = "parallel")] + let spill_iter = air_trace_pairs.par_iter_mut(); + #[cfg(not(feature = "parallel"))] + let mut spill_iter = air_trace_pairs.iter_mut(); + spill_iter.try_for_each(|(air, trace, _)| { + if air.has_aux_trace() { + trace + .spill_aux_to_disk() + .map_err(|e| ProvingError::DiskSpill(format!("aux trace: {e}")))?; + } + Ok(()) + })?; + } + #[cfg(feature = "instruments")] let aux_build_elapsed = phase_start.elapsed(); #[cfg(feature = "instruments")] @@ -1803,6 +1884,10 @@ pub trait IsStarkProver< if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_aux(lde_size); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + trace.aux_table.advise_drop_cache(); + } #[cfg(feature = "instruments")] let t_sub = Instant::now(); Self::expand_columns_to_lde::( @@ -1814,11 +1899,19 @@ pub trait IsStarkProver< let aux_lde_dur = t_sub.elapsed(); #[cfg(feature = "instruments")] let t_sub = Instant::now(); - let (tree, root) = Self::commit_columns_bit_reversed(&columns) + #[allow(unused_mut)] + let (mut tree, root) = Self::commit_columns_bit_reversed(&columns) .ok_or(ProvingError::EmptyCommitment)?; #[cfg(feature = "instruments")] crate::instruments::accum_r1_aux(aux_lde_dur, t_sub.elapsed()); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + tree.spill_nodes_to_disk().map_err(|e| { + ProvingError::DiskSpill(format!("aux Merkle tree: {e}")) + })?; + } + Ok((Some(Arc::new(tree)), Some(root), columns)) } else { Ok((None, None, Vec::new())) @@ -2007,10 +2100,19 @@ pub trait IsStarkProver< FieldElement: AsBytes, FieldElement: AsBytes, PI: Send + Sync + Clone, + Field: Copy + 'static, + FieldExtension: Copy + 'static, + ::BaseType: SpillSafe, + ::BaseType: SpillSafe, { let air_trace_pairs = vec![(air, trace, pub_inputs)]; - Self::multi_prove(air_trace_pairs, 
transcript) - .map(|mut multi_proof| multi_proof.proofs.remove(0)) + Self::multi_prove( + air_trace_pairs, + transcript, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, + ) + .map(|mut multi_proof| multi_proof.proofs.remove(0)) } // TODO: propagate errors instead of unwrap() in open_deep_composition_poly and FRI operations diff --git a/crypto/stark/src/storage_mode.rs b/crypto/stark/src/storage_mode.rs new file mode 100644 index 000000000..0ba4b32b5 --- /dev/null +++ b/crypto/stark/src/storage_mode.rs @@ -0,0 +1,8 @@ +/// Storage backend for intermediate prover state: `Ram` (heap) or `Disk` (mmap). +/// Disk trades wall time for peak RAM. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum StorageMode { + #[default] + Ram, + Disk, +} diff --git a/crypto/stark/src/table.rs b/crypto/stark/src/table.rs index 352b9e6f7..f91be2f5e 100644 --- a/crypto/stark/src/table.rs +++ b/crypto/stark/src/table.rs @@ -1,24 +1,149 @@ use crate::frame::Frame; +#[cfg(feature = "disk-spill")] +use crypto::mmap_util::spill_slice_to_mmap; use math::field::{ element::FieldElement, traits::{IsField, IsSubFieldOf}, }; +#[cfg(feature = "disk-spill")] +use math::spill_safe::SpillSafe; #[cfg(feature = "parallel")] use rayon::prelude::*; +/// Mmap-backed storage for a spilled Table. +/// +/// Access goes through pointer arithmetic on the mmap, matching the +/// original `data[row * width + col]` layout. +#[cfg(feature = "disk-spill")] +struct TableMmapBacking { + mmap: memmap2::Mmap, + /// Number of columns per row. + width: usize, + /// Number of rows. + height: usize, + /// Size in bytes of a single element. 
+ elem_size: usize, +} + +#[cfg(feature = "disk-spill")] +impl std::fmt::Debug for TableMmapBacking { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TableMmapBacking") + .field("width", &self.width) + .field("height", &self.height) + .field("elem_size", &self.elem_size) + .finish() + } +} + /// A two-dimensional Table holding field elements, arranged in a row-major order. /// This is the basic underlying data structure used for any two-dimensional component in the /// the STARK protocol implementation, such as the `TraceTable` and the `EvaluationFrame`. /// Since this struct is a representation of a two-dimensional table, all rows should have the same /// length. -#[derive(Clone, Default, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Default, Debug, serde::Deserialize)] +#[cfg_attr( + not(feature = "disk-spill"), + derive(serde::Serialize, Clone, PartialEq, Eq) +)] #[serde(bound = "")] pub struct Table { pub data: Vec>, pub width: usize, pub height: usize, + #[cfg(feature = "disk-spill")] + #[serde(skip)] + mmap_backing: Option, +} + +#[cfg(feature = "disk-spill")] +impl serde::Serialize for Table +where + FieldElement: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let mut s = serializer.serialize_struct("Table", 3)?; + if self.mmap_backing.is_some() { + s.serialize_field("data", &MmapDataSeq(self))?; + } else { + s.serialize_field("data", &self.data)?; + } + s.serialize_field("width", &self.width)?; + s.serialize_field("height", &self.height)?; + s.end() + } } +#[cfg(feature = "disk-spill")] +struct MmapDataSeq<'a, F: IsField>(&'a Table); + +#[cfg(feature = "disk-spill")] +impl serde::Serialize for MmapDataSeq<'_, F> +where + FieldElement: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeSeq; + let table = self.0; + let mut seq = serializer.serialize_seq(Some(table.width * table.height))?; + 
for r in 0..table.height { + for elem in table.get_row(r) { + seq.serialize_element(elem)?; + } + } + seq.end() + } +} + +/// Cloning a spilled table copies its mmap bytes into a fresh heap `Vec` +/// and returns an unspilled clone. +#[cfg(feature = "disk-spill")] +impl Clone for Table { + fn clone(&self) -> Self { + if self.mmap_backing.is_some() { + let mut data = Vec::with_capacity(self.width * self.height); + for row in 0..self.height { + for col in 0..self.width { + data.push(self.get(row, col).clone()); + } + } + return Self { + data, + width: self.width, + height: self.height, + mmap_backing: None, + }; + } + Self { + data: self.data.clone(), + width: self.width, + height: self.height, + mmap_backing: None, + } + } +} + +#[cfg(feature = "disk-spill")] +impl PartialEq for Table { + fn eq(&self, other: &Self) -> bool { + if self.width != other.width || self.height != other.height { + return false; + } + for row in 0..self.height { + for col in 0..self.width { + if self.get(row, col) != other.get(row, col) { + return false; + } + } + } + true + } +} + +#[cfg(feature = "disk-spill")] +impl Eq for Table {} + impl Table { /// Crates a new Table instance from a one-dimensional array in row major order /// and the intended width of the table. @@ -29,6 +154,8 @@ impl Table { data: Vec::new(), width, height: 0, + #[cfg(feature = "disk-spill")] + mmap_backing: None, }; } @@ -40,6 +167,8 @@ impl Table { data, width, height, + #[cfg(feature = "disk-spill")] + mmap_backing: None, } } @@ -67,6 +196,27 @@ impl Table { /// Given a row index, returns a reference to that row as a slice of field elements. pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { + #[cfg(feature = "disk-spill")] + if let Some(ref backing) = self.mmap_backing { + // Ensures the unsafe block's read stays within the mmap. 
+ assert!( + row_idx < backing.height, + "Table::get_row out of bounds: row={row_idx}, height={}", + backing.height + ); + let offset = row_idx * backing.width * backing.elem_size; + // SAFETY: spill_to_disk writes the table in row-major layout, so + // width elements at this offset are contiguous. FieldElement + // is #[repr(transparent)] over F::BaseType, and spill_to_disk + // requires F::BaseType: SpillSafe (no padding, all bit patterns + // valid). + return unsafe { + std::slice::from_raw_parts( + backing.mmap.as_ptr().add(offset) as *const FieldElement, + backing.width, + ) + }; + } let row_offset = row_idx * self.width; &self.data[row_offset..row_offset + self.width] } @@ -77,7 +227,7 @@ impl Table { (0..self.width) .map(|col_idx| { (0..self.height) - .map(|row_idx| self.data[row_idx * self.width + col_idx].clone()) + .map(|row_idx| self.get(row_idx, col_idx).clone()) .collect() }) .collect() @@ -97,7 +247,7 @@ impl Table { iter.map(|col_idx| { let mut buf = Vec::with_capacity(capacity); for row_idx in 0..self.height { - buf.push(self.data[row_idx * self.width + col_idx].clone()); + buf.push(self.get(row_idx, col_idx).clone()); } buf }) @@ -106,15 +256,89 @@ impl Table { /// Given row and column indexes, returns the stored field element in that position of the table. pub fn get(&self, row: usize, col: usize) -> &FieldElement { + #[cfg(feature = "disk-spill")] + if let Some(ref backing) = self.mmap_backing { + // Ensures the unsafe block's read stays within the mmap. + assert!( + row < backing.height && col < backing.width, + "Table::get out of bounds: row={row}, col={col}, height={}, width={}", + backing.height, + backing.width + ); + // Row-major layout: offset = (row * width + col) * elem_size + let offset = (row * backing.width + col) * backing.elem_size; + // SAFETY: FieldElement is #[repr(transparent)] over F::BaseType. + // The mmap is page-aligned and elements are contiguously packed. 
+ // The data was written from identical types on the same machine, + // and spill_to_disk requires F::BaseType: SpillSafe (no padding, + // all bit patterns valid). + return unsafe { &*(backing.mmap.as_ptr().add(offset) as *const FieldElement) }; + } let idx = row * self.width + col; &self.data[idx] } pub fn set(&mut self, row: usize, col: usize, value: FieldElement) { + #[cfg(feature = "disk-spill")] + assert!( + self.mmap_backing.is_none(), + "Table::set on a spilled table — backing mmap is read-only" + ); let idx = row * self.width + col; self.data[idx] = value; } + /// Spill the table's row-major data to a temp file and mmap it back. + /// Frees the heap `data` Vec while preserving access through + /// [`Self::get`], [`Self::get_row`], and [`Self::columns`]. + /// + /// No-op if the table is empty or already spilled. + #[cfg(feature = "disk-spill")] + pub fn spill_to_disk(&mut self) -> std::io::Result<()> + where + F: Copy + 'static, + F::BaseType: SpillSafe, + { + if self.data.is_empty() || self.mmap_backing.is_some() { + return Ok(()); + } + + let mmap = spill_slice_to_mmap(&self.data)?; + self.mmap_backing = Some(TableMmapBacking { + mmap, + width: self.width, + height: self.height, + elem_size: size_of::>(), + }); + self.data = Vec::new(); + + Ok(()) + } + + /// Hint the kernel to drop mmap pages from the page cache. + /// Call after reading spilled data into pool buffers so the same + /// data doesn't occupy RAM in both places. + /// + /// Reliable on Linux for clean file-backed mappings; on other Unix + /// (macOS/BSD) the hint may be a no-op. No-op on non-Unix targets. + #[cfg(all(feature = "disk-spill", unix))] + pub fn advise_drop_cache(&self) { + if let Some(ref backing) = self.mmap_backing { + // SAFETY: pointer and length are from a valid mmap. + // MADV_DONTNEED is advisory and cannot cause UB. 
+ unsafe { + libc::madvise( + backing.mmap.as_ptr() as *mut libc::c_void, + backing.mmap.len(), + libc::MADV_DONTNEED, + ); + } + } + } + + #[cfg(all(feature = "disk-spill", not(unix)))] + pub fn advise_drop_cache(&self) {} + /// Given a step size, converts the given table into a `Frame`. /// Clones row data into owned Vecs (only used by verifier on small OOD tables). pub fn into_frame(&self, main_trace_columns: usize, step_size: usize) -> Frame { @@ -172,3 +396,122 @@ where &self.aux_data[row][col] } } + +#[cfg(all(test, feature = "disk-spill"))] +mod disk_spill_tests { + use super::*; + use math::field::goldilocks::GoldilocksField; + + type F = GoldilocksField; + + #[test] + fn test_table_spill_roundtrip() { + let width = 4; + let height = 8; + let data: Vec> = (0..width * height) + .map(|i| FieldElement::::from(i as u64)) + .collect(); + + let mut table = Table::new(data.clone(), width); + assert!(table.mmap_backing.is_none()); + + // Snapshot values before spill + let pre_spill: Vec>> = (0..height) + .map(|r| (0..width).map(|c| *table.get(r, c)).collect()) + .collect(); + + table.spill_to_disk().expect("spill_to_disk failed"); + assert!(table.mmap_backing.is_some()); + assert!( + table.data.is_empty(), + "heap data should be freed after spill" + ); + + // Verify get() returns the same values + for (r, pre_row) in pre_spill.iter().enumerate() { + for (c, pre_val) in pre_row.iter().enumerate() { + assert_eq!(table.get(r, c), pre_val, "mismatch at ({r}, {c})"); + } + } + + // Verify get_row() returns the same values + for (r, pre_row) in pre_spill.iter().enumerate() { + let row = table.get_row(r); + assert_eq!(row.len(), width); + for (c, pre_val) in pre_row.iter().enumerate() { + assert_eq!(&row[c], pre_val, "get_row mismatch at ({r}, {c})"); + } + } + } + + #[test] + fn test_table_spill_empty_is_noop() { + let mut table = Table::::new(Vec::new(), 0); + table + .spill_to_disk() + .expect("spill_to_disk on empty table failed"); + 
assert!(table.mmap_backing.is_none()); + } + + #[test] + fn test_table_spill_idempotent() { + let data: Vec> = + (0..16).map(|i| FieldElement::::from(i as u64)).collect(); + let mut table = Table::new(data, 4); + + table.spill_to_disk().expect("first spill failed"); + assert!(table.mmap_backing.is_some()); + + table.spill_to_disk().expect("second spill should be no-op"); + assert!(table.mmap_backing.is_some()); + + // Still readable + assert_eq!(table.get(0, 0), &FieldElement::::from(0u64)); + assert_eq!(table.get(3, 3), &FieldElement::::from(15u64)); + } + + #[test] + fn test_clone_spilled_table_materializes_to_heap() { + let width = 4; + let height = 8; + let data: Vec> = (0..width * height) + .map(|i| FieldElement::::from(i as u64)) + .collect(); + + let mut table = Table::new(data, width); + table.spill_to_disk().expect("spill_to_disk failed"); + assert!(table.mmap_backing.is_some()); + + let cloned = table.clone(); + assert!(cloned.mmap_backing.is_none(), "clone should not be spilled"); + assert_eq!(cloned.width, width); + assert_eq!(cloned.height, height); + assert_eq!(cloned, table, "clone must equal source element-wise"); + } + + #[test] + fn test_serialize_spilled_table_matches_unspilled() { + let width = 4; + let height = 8; + let data: Vec> = (0..width * height) + .map(|i| FieldElement::::from(i as u64)) + .collect(); + + let unspilled = Table::new(data.clone(), width); + let unspilled_bytes = bincode::serialize(&unspilled).expect("serialize unspilled"); + + let mut spilled = Table::new(data, width); + spilled.spill_to_disk().expect("spill_to_disk failed"); + let spilled_bytes = bincode::serialize(&spilled).expect("serialize spilled"); + + assert_eq!( + spilled_bytes, unspilled_bytes, + "spilled and unspilled tables must serialize to identical bytes" + ); + + let restored: Table = + bincode::deserialize(&spilled_bytes).expect("deserialize spilled bytes"); + assert!(restored.mmap_backing.is_none()); + assert_eq!(restored, unspilled); + } +} diff --git 
a/crypto/stark/src/test_utils.rs b/crypto/stark/src/test_utils.rs new file mode 100644 index 000000000..f5cd19f80 --- /dev/null +++ b/crypto/stark/src/test_utils.rs @@ -0,0 +1,38 @@ +//! Shared test helpers for the stark crate. + +use crate::proof::stark::MultiProof; +use crate::prover::{IsStarkProver, Prover, ProvingError}; +use crate::trace::TraceTable; +use crate::traits::AIR; +use crypto::fiat_shamir::is_transcript::IsStarkTranscript; +use math::field::element::FieldElement; +use math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; +use math::spill_safe::SpillSafe; +use math::traits::{AsBytes, ByteConversion}; + +type AirTracePair<'a, Field, FieldExtension, PI> = ( + &'a dyn AIR, + &'a mut TraceTable, + &'a PI, +); + +pub fn multi_prove_ram( + air_trace_pairs: Vec>, + transcript: &mut (impl IsStarkTranscript + Clone + Send), +) -> Result, ProvingError> +where + Field: IsSubFieldOf + IsFFTField + Send + Sync + Copy + 'static, + FieldExtension: IsField + Send + Sync + Copy + 'static, + PI: Send + Sync + Clone, + FieldElement: AsBytes + ByteConversion, + FieldElement: AsBytes + ByteConversion, + ::BaseType: SpillSafe, + ::BaseType: SpillSafe, +{ + Prover::::multi_prove( + air_trace_pairs, + transcript, + #[cfg(feature = "disk-spill")] + crate::storage_mode::StorageMode::Ram, + ) +} diff --git a/crypto/stark/src/tests/air_tests.rs b/crypto/stark/src/tests/air_tests.rs index 11d356ccf..8e20f303e 100644 --- a/crypto/stark/src/tests/air_tests.rs +++ b/crypto/stark/src/tests/air_tests.rs @@ -31,6 +31,7 @@ type Felt = FieldElement; use crate::examples::read_only_memory_logup::{ LogReadOnlyPublicInputs, LogReadOnlyRAP, read_only_logup_trace, }; +use crate::test_utils::multi_prove_ram; #[test_log::test] fn test_prove_fib() { @@ -400,7 +401,7 @@ fn test_multi_prove_fib_3_tables() { (&air_3, &mut trace_3, &pub_inputs_3), ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut 
DefaultTranscript::::new(&[])).unwrap(); let airs: Vec< &dyn AIR< @@ -500,7 +501,7 @@ fn test_multi_prove_2_tables_small_field() { (&air_2, &mut trace_2, &pub_inputs_2), ]; - let multi_proof = Prover::multi_prove( + let multi_proof = multi_prove_ram( air_trace_pairs, &mut DefaultTranscript::::new(&[]), ) @@ -538,7 +539,7 @@ fn test_multi_prove_different_airs() { )> = vec![(&air_1, &mut trace_1, &()), (&air_2, &mut trace_2, &())]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec< &dyn AIR, diff --git a/crypto/stark/src/tests/bus_tests/completeness_tests.rs b/crypto/stark/src/tests/bus_tests/completeness_tests.rs index 7ca124fe1..83f8ac391 100644 --- a/crypto/stark/src/tests/bus_tests/completeness_tests.rs +++ b/crypto/stark/src/tests/bus_tests/completeness_tests.rs @@ -16,7 +16,7 @@ use crate::lookup::{ NullBoundaryConstraintBuilder, Packing, }; use crate::proof::options::ProofOptions; -use crate::prover::{IsStarkProver, Prover}; +use crate::test_utils::multi_prove_ram; use crate::trace::TraceTable; use crate::traits::AIR; use crate::verifier::{IsStarkVerifier, Verifier}; @@ -122,7 +122,7 @@ fn test_multi_table_proof() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -185,7 +185,7 @@ fn test_all_padding() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -248,7 +248,7 @@ fn test_single_operation() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + 
multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -311,7 +311,7 @@ fn test_duplicate_operations() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -374,7 +374,7 @@ fn test_serialization_roundtrip() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Serialize and deserialize let serialized = serde_cbor::to_vec(&multi_proof).expect("serialization failed"); @@ -519,7 +519,7 @@ fn test_bus_value_features() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; diff --git a/crypto/stark/src/tests/bus_tests/multiplicity_tests.rs b/crypto/stark/src/tests/bus_tests/multiplicity_tests.rs index d4ef1aee9..7e4d632dd 100644 --- a/crypto/stark/src/tests/bus_tests/multiplicity_tests.rs +++ b/crypto/stark/src/tests/bus_tests/multiplicity_tests.rs @@ -15,7 +15,7 @@ use crate::lookup::{ NullBoundaryConstraintBuilder, Packing, }; use crate::proof::options::ProofOptions; -use crate::prover::{IsStarkProver, Prover}; +use crate::test_utils::multi_prove_ram; use crate::trace::TraceTable; use crate::traits::AIR; use crate::verifier::{IsStarkVerifier, Verifier}; @@ -113,7 +113,7 @@ fn test_multiplicity_one() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -223,7 +223,7 @@ fn 
test_multiplicity_sum() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -331,7 +331,7 @@ fn test_multiplicity_negated() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; diff --git a/crypto/stark/src/tests/bus_tests/soundness_tests.rs b/crypto/stark/src/tests/bus_tests/soundness_tests.rs index e1994ef6a..fc718bf7c 100644 --- a/crypto/stark/src/tests/bus_tests/soundness_tests.rs +++ b/crypto/stark/src/tests/bus_tests/soundness_tests.rs @@ -14,6 +14,7 @@ use crate::examples::multi_table_lookup::{ }; use crate::proof::options::ProofOptions; use crate::prover::{IsStarkProver, Prover}; +use crate::test_utils::multi_prove_ram; use crate::trace::TraceTable; use crate::traits::AIR; use crate::verifier::{IsStarkVerifier, Verifier}; @@ -79,7 +80,7 @@ fn test_wrong_result_value() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -142,7 +143,7 @@ fn test_off_by_one() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -205,7 +206,7 @@ fn test_swapped_operands() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, 
&mul_air]; @@ -268,7 +269,7 @@ fn test_single_column_wrong() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -335,7 +336,7 @@ fn test_over_report_multiplicity() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -398,7 +399,7 @@ fn test_under_report_multiplicity() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -461,7 +462,7 @@ fn test_zero_multiplicity_skip() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -528,7 +529,7 @@ fn test_phantom_receive() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -591,7 +592,7 @@ fn test_missing_receiver() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -662,7 +663,7 @@ fn test_tampered_table_contribution() { ]; let mut multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + 
multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Corrupt table_contribution in the ADD table's bus public inputs. // This changes the per-row offset L/N used in the circular constraint, @@ -742,7 +743,7 @@ fn test_tampered_acc_ood_evaluation() { ]; let mut multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Corrupt the acc column OOD evaluation in the ADD table proof. // With batching + absorption, ADD has 4 main columns and 1 aux column @@ -827,7 +828,7 @@ fn test_missing_bus_public_inputs_rejected() { ]; let mut multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Remove bus_public_inputs from the ADD table proof entirely. multi_proof.proofs[1].bus_public_inputs = None; @@ -948,7 +949,7 @@ fn test_zeroed_table_contribution_rejected() { ]; let mut multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Zero out table_contribution for the ADD table. 
let add_proof = &mut multi_proof.proofs[1]; @@ -1026,7 +1027,7 @@ fn test_one_of_many_wrong() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -1134,7 +1135,7 @@ fn test_full_scenario_wrong_add() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -1208,7 +1209,7 @@ fn test_wrong_table_consumes_value_rejected() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; @@ -1324,7 +1325,7 @@ fn test_packing_mismatch_direct_vs_word2l() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -1429,7 +1430,7 @@ fn test_packing_mismatch_element_count() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -1531,7 +1532,7 @@ fn test_packing_mismatch_shift_constant() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -1634,7 +1635,7 @@ fn test_compound_mismatch_dwordhhw_vs_dwordwhh() { ]; let multi_proof = - 
Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -1727,7 +1728,7 @@ fn test_compound_equals_primitive_expansion() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender, &receiver]; @@ -1843,7 +1844,7 @@ fn test_full_scenario_wrong_mul() { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&cpu_air, &add_air, &mul_air]; diff --git a/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs b/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs index 9e3c60091..4059ed481 100644 --- a/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs +++ b/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs @@ -16,11 +16,9 @@ use crate::lookup::{ }; use crate::proof::options::ProofOptions; use crate::proof::stark::MultiProof; +use crate::test_utils::multi_prove_ram; use crate::traits::AIR; -use crate::{ - prover::{IsStarkProver, Prover}, - verifier::{IsStarkVerifier, Verifier}, -}; +use crate::verifier::{IsStarkVerifier, Verifier}; type F = GoldilocksField; type E = Degree3GoldilocksExtensionField; @@ -137,7 +135,7 @@ fn test_verify_serialized_multi_table_proofs() { (&mul_air, &mut mul_trace, &()), ]; - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap() + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap() }; // ========================================================================= diff --git a/crypto/stark/src/tests/prover_tests.rs b/crypto/stark/src/tests/prover_tests.rs index b1e403b36..1355b363d 100644 --- 
a/crypto/stark/src/tests/prover_tests.rs +++ b/crypto/stark/src/tests/prover_tests.rs @@ -8,6 +8,7 @@ use crate::{ }, proof::options::ProofOptions, prover::{IsStarkProver, Prover, domain_cache_stats, evaluate_polynomial_on_lde_domain}, + test_utils::multi_prove_ram, trace::{LDETraceTable, get_trace_evaluations, get_trace_evaluations_from_lde}, traits::AIR, verifier::{IsStarkVerifier, Verifier}, @@ -287,7 +288,7 @@ fn test_multi_prove_mixed_coset_offsets() { (&air_2, &mut trace_2, &pub_inputs), ]; - let multi_proof = Prover::multi_prove( + let multi_proof = multi_prove_ram( air_trace_pairs, &mut DefaultTranscript::::new(&[]), ) @@ -353,7 +354,7 @@ fn test_multi_prove_dedups_shared_domain_params() { (&air_3, &mut trace_3, &pub_inputs), ]; - let multi_proof = Prover::multi_prove( + let multi_proof = multi_prove_ram( air_trace_pairs, &mut DefaultTranscript::::new(&[]), ) diff --git a/crypto/stark/src/trace.rs b/crypto/stark/src/trace.rs index ef6ee7833..6f5896eb7 100644 --- a/crypto/stark/src/trace.rs +++ b/crypto/stark/src/trace.rs @@ -4,6 +4,8 @@ use itertools::Itertools; use math::fft::errors::FFTError; use math::field::traits::{IsField, IsSubFieldOf}; use math::polynomial::barycentric_inv_denoms; +#[cfg(feature = "disk-spill")] +use math::spill_safe::SpillSafe; use math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -147,6 +149,26 @@ where self.num_aux_columns = num_aux_columns; } + /// Write main trace data to a temp file and free the in-memory vector. + /// Accessors read from the mmap after this call. 
+ #[cfg(feature = "disk-spill")] + pub fn spill_main_to_disk(&mut self) -> std::io::Result<()> + where + F: Copy + 'static, + F::BaseType: SpillSafe, + { + self.main_table.spill_to_disk() + } + + #[cfg(feature = "disk-spill")] + pub fn spill_aux_to_disk(&mut self) -> std::io::Result<()> + where + E: Copy + 'static, + E::BaseType: SpillSafe, + { + self.aux_table.spill_to_disk() + } + pub fn compute_trace_polys_main(&self) -> Vec>> where S: IsFFTField + IsSubFieldOf, @@ -255,9 +277,8 @@ where /// Gather a full main-trace row into an owned Vec. /// Used by `open_trace_polys` (called ~30 times per table, allocation is negligible). pub fn gather_main_row(&self, row_idx: usize) -> Vec> { - self.main_columns - .iter() - .map(|col| col[row_idx].clone()) + (0..self.num_main_cols()) + .map(|col| self.get_main(row_idx, col).clone()) .collect() } @@ -269,17 +290,15 @@ where col_start: usize, col_end: usize, ) -> Vec> { - self.main_columns[col_start..col_end] - .iter() - .map(|col| col[row_idx].clone()) + (col_start..col_end) + .map(|col| self.get_main(row_idx, col).clone()) .collect() } /// Gather a full aux-trace row into an owned Vec. 
pub fn gather_aux_row(&self, row_idx: usize) -> Vec> { - self.aux_columns - .iter() - .map(|col| col[row_idx].clone()) + (0..self.num_aux_cols()) + .map(|col| self.get_aux(row_idx, col).clone()) .collect() } diff --git a/executor/programs/asm/fib_iterative_16M.s b/executor/programs/asm/fib_iterative_16M.s new file mode 100644 index 000000000..1ede85aaf --- /dev/null +++ b/executor/programs/asm/fib_iterative_16M.s @@ -0,0 +1,24 @@ + .attribute 5, "rv64i2p1_m2p0" + .globl main +main: + # Iterative Fibonacci - pure register arithmetic + # ~16M steps + # + # Loop body: 5 instructions per iteration + # 3200000 iterations × 5 = 16000000 + setup/teardown + + li t0, 0 # a = fib(0) = 0 + li t1, 1 # b = fib(1) = 1 + li a0, 3200000 # iteration count + +.loop: + add t2, t0, t1 # t2 = a + b + mv t0, t1 # a = b + mv t1, t2 # b = t2 + addi a0, a0, -1 # n-- + bnez a0, .loop # loop if n != 0 + + mv a0, t1 # result = b + li a0, 0 + li a7, 93 + ecall # halt with a0 = 0 (the result written to a0 above is overwritten) diff --git a/executor/programs/asm/fib_iterative_32M.s b/executor/programs/asm/fib_iterative_32M.s new file mode 100644 index 000000000..df6644193 --- /dev/null +++ b/executor/programs/asm/fib_iterative_32M.s @@ -0,0 +1,24 @@ + .attribute 5, "rv64i2p1_m2p0" + .globl main +main: + # Iterative Fibonacci - pure register arithmetic + # ~32M steps + # + # Loop body: 5 instructions per iteration + # 6400000 iterations × 5 = 32000000 + setup/teardown + + li t0, 0 # a = fib(0) = 0 + li t1, 1 # b = fib(1) = 1 + li a0, 6400000 # iteration count + +.loop: + add t2, t0, t1 # t2 = a + b + mv t0, t1 # a = b + mv t1, t2 # b = t2 + addi a0, a0, -1 # n-- + bnez a0, .loop # loop if n != 0 + + mv a0, t1 # result = b + li a0, 0 + li a7, 93 + ecall # halt with a0 = 0 (the result written to a0 above is overwritten) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 60ed39c0c..82ca8bfe0 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -8,6 +8,7 @@ default = ["parallel"] parallel = ["stark/parallel", "math/parallel", "crypto/parallel", "dep:rayon"] debug-checks = 
["stark/debug-checks"] instruments = ["stark/instruments"] +disk-spill = ["stark/disk-spill"] [dependencies] stark = { path = "../crypto/stark" } @@ -16,10 +17,15 @@ math = { path = "../crypto/math" } executor = { path = "../executor" } serde = { version = "1.0", features = ["derive"] } rayon = { version = "1.8.0", optional = true } +sysinfo = { version = "0.31", default-features = false, features = ["system"] } +log = "0.4" [dev-dependencies] env_logger = "*" criterion = { version = "0.5", default-features = false } +bincode = "1" +tikv-jemallocator = "0.6" +tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] } tiny-keccak = { version = "2.0", features = ["keccak"] } [[bench]] diff --git a/prover/src/auto_storage.rs b/prover/src/auto_storage.rs new file mode 100644 index 000000000..a28bcd498 --- /dev/null +++ b/prover/src/auto_storage.rs @@ -0,0 +1,402 @@ +//! Automatic `StorageMode` selection from an analytical peak-RAM estimate. +//! +//! `FORCE_DISK_SPILL` env var forces `StorageMode::Disk` regardless of the +//! estimate. 
+ +use crate::tables::bitwise::{ + NUM_ROWS as BITWISE_ROWS, bus_interactions as bitwise_buses, cols::NUM_COLUMNS as BITWISE_COLS, +}; +use crate::tables::branch::{bus_interactions as branch_buses, cols::NUM_COLUMNS as BRANCH_COLS}; +use crate::tables::commit::{bus_interactions as commit_buses, cols::NUM_COLUMNS as COMMIT_COLS}; +use crate::tables::cpu::{bus_interactions as cpu_buses, cols::NUM_COLUMNS as CPU_COLS}; +use crate::tables::decode::{bus_interactions as decode_buses, cols::NUM_COLUMNS as DECODE_COLS}; +use crate::tables::dvrm::{bus_interactions as dvrm_buses, cols::NUM_COLUMNS as DVRM_COLS}; +use crate::tables::halt::{bus_interactions as halt_buses, cols::NUM_COLUMNS as HALT_COLS}; +use crate::tables::load::{bus_interactions as load_buses, cols::NUM_COLUMNS as LOAD_COLS}; +use crate::tables::lt::{bus_interactions as lt_buses, cols::NUM_COLUMNS as LT_COLS}; +use crate::tables::memw::{bus_interactions as memw_buses, cols::NUM_COLUMNS as MEMW_COLS}; +use crate::tables::memw_aligned::{ + bus_interactions as memw_a_buses, cols::NUM_COLUMNS as MEMW_A_COLS, +}; +use crate::tables::memw_register::{ + bus_interactions as memw_r_buses, cols::NUM_COLUMNS as MEMW_R_COLS, +}; +use crate::tables::mul::{bus_interactions as mul_buses, cols::NUM_COLUMNS as MUL_COLS}; +use crate::tables::page::{ + DEFAULT_PAGE_SIZE as PAGE_SIZE, bus_interactions as page_buses, cols::NUM_COLUMNS as PAGE_COLS, +}; +use crate::tables::register::{ + NUM_REGISTER_ADDRESSES, bus_interactions as register_buses, cols::NUM_COLUMNS as REGISTER_COLS, +}; +use crate::tables::shift::{bus_interactions as shift_buses, cols::NUM_COLUMNS as SHIFT_COLS}; +use crate::tables::trace_builder::TableLengths; +use stark::prover::table_parallelism; +use stark::storage_mode::StorageMode; +use sysinfo::System; + +const GOLDILOCKS_BYTES: u64 = 8; +const CUBIC_EXT_BYTES: u64 = 24; +const KECCAK_NODE_BYTES: u64 = 32; +const LOG_STRUCT_BYTES: u64 = 40; +const MEMORY_CELL_BYTES: u64 = 32; +const 
INSTRUCTION_MAP_BYTES_PER_ROW: u64 = 32; + +/// 9/10 budget headroom for OS, other processes, and allocator slack. +pub const SAFETY_FRACTION_NUM: u64 = 9; +pub const SAFETY_FRACTION_DEN: u64 = 10; + +/// `(rows, main_cols, aux_cols, num_main_merkle_trees)` for a single table. +type TableSpec = (u64, u64, u64, u64); + +/// Bytes alive for the duration of phase D (LDE columns + main/aux Merkle). +fn persistent_per_table(spec: TableSpec, blowup: u64) -> u64 { + let (rows, main_cols, aux_cols, main_trees) = spec; + let main_lde = rows + .saturating_mul(main_cols) + .saturating_mul(GOLDILOCKS_BYTES) + .saturating_mul(1 + blowup); + let aux_lde = rows + .saturating_mul(aux_cols) + .saturating_mul(CUBIC_EXT_BYTES) + .saturating_mul(1 + blowup); + let main_merkle = main_trees + .saturating_mul(2) + .saturating_mul(rows) + .saturating_mul(blowup) + .saturating_mul(KECCAK_NODE_BYTES); + let aux_merkle = if aux_cols > 0 { + 2u64.saturating_mul(rows) + .saturating_mul(blowup) + .saturating_mul(KECCAK_NODE_BYTES) + } else { + 0 + }; + main_lde + .saturating_add(aux_lde) + .saturating_add(main_merkle) + .saturating_add(aux_merkle) +} + +/// Bytes (constraint evals, composition, FRI) alive during rounds 2-4 for one chunk. +fn transient_per_table(spec: TableSpec, blowup: u64) -> u64 { + let (rows, _, _, _) = spec; + let lde_size = rows.saturating_mul(blowup); + let constraint_evals = lde_size.saturating_mul(CUBIC_EXT_BYTES); + let composition_lde = lde_size.saturating_mul(2).saturating_mul(CUBIC_EXT_BYTES); + let composition_merkle = lde_size.saturating_mul(KECCAK_NODE_BYTES); + let fri_evals = lde_size.saturating_mul(CUBIC_EXT_BYTES); + let fri_merkle = lde_size.saturating_mul(KECCAK_NODE_BYTES); + constraint_evals + .saturating_add(composition_lde) + .saturating_add(composition_merkle) + .saturating_add(fri_evals) + .saturating_add(fri_merkle) +} + +/// Bytes for one Domain/LdeTwiddles cache entry. 
+fn domain_cache_bytes(rows: u64, blowup: u64) -> u64 { + rows.saturating_mul(3 + 2 * blowup) + .saturating_mul(GOLDILOCKS_BYTES) +} + +fn aux_cols(bus_count: usize) -> u64 { + bus_count.div_ceil(2) as u64 +} + +/// Per-table specs in the same order as `air_trace_pairs` in `prove`. +fn table_specs(lengths: &TableLengths) -> Vec { + let bitwise_rows = BITWISE_ROWS as u64; + let register_rows = NUM_REGISTER_ADDRESSES.next_power_of_two() as u64; + let halt_rows = 1u64; + let page_rows = PAGE_SIZE as u64; + + let mut specs = vec![ + ( + lengths.cpu_padded_rows, + CPU_COLS as u64, + aux_cols(cpu_buses().len()), + 1, + ), + ( + lengths.memw_padded_rows, + MEMW_COLS as u64, + aux_cols(memw_buses().len()), + 1, + ), + ( + lengths.memw_aligned_padded_rows, + MEMW_A_COLS as u64, + aux_cols(memw_a_buses().len()), + 1, + ), + ( + lengths.memw_register_padded_rows, + MEMW_R_COLS as u64, + aux_cols(memw_r_buses().len()), + 1, + ), + ( + lengths.load_padded_rows, + LOAD_COLS as u64, + aux_cols(load_buses().len()), + 1, + ), + ( + lengths.lt_padded_rows, + LT_COLS as u64, + aux_cols(lt_buses().len()), + 1, + ), + ( + lengths.shift_padded_rows, + SHIFT_COLS as u64, + aux_cols(shift_buses().len()), + 1, + ), + ( + lengths.mul_padded_rows, + MUL_COLS as u64, + aux_cols(mul_buses().len()), + 1, + ), + ( + lengths.dvrm_padded_rows, + DVRM_COLS as u64, + aux_cols(dvrm_buses().len()), + 1, + ), + ( + lengths.branch_padded_rows, + BRANCH_COLS as u64, + aux_cols(branch_buses().len()), + 1, + ), + ( + lengths.commit_padded_rows, + COMMIT_COLS as u64, + aux_cols(commit_buses().len()), + 1, + ), + // BITWISE / DECODE / PAGE / REGISTER take the preprocessed-trace commit + // path: it extracts ALL columns into the LDE and builds two Merkle trees + // (precomputed_tree + mult_tree), so main_cols = full NUM_COLUMNS and + // main_trees = 2. 
+ ( + bitwise_rows, + BITWISE_COLS as u64, + aux_cols(bitwise_buses().len()), + 2, + ), + ( + lengths.decode_rows, + DECODE_COLS as u64, + aux_cols(decode_buses().len()), + 2, + ), + (halt_rows, HALT_COLS as u64, aux_cols(halt_buses().len()), 1), + ( + register_rows, + REGISTER_COLS as u64, + aux_cols(register_buses().len()), + 2, + ), + ]; + // Each unique 256 KB page → its own PAGE table at PAGE_SIZE rows. + for _ in 0..lengths.unique_page_count { + specs.push(( + page_rows, + PAGE_COLS as u64, + aux_cols(page_buses(0).len()), + 2, + )); + } + specs +} + +/// Estimates heap from `lengths` and `blowup_factor`. Picks `Disk` if the +/// estimate exceeds 9/10 of available RAM, else `Ram`. `FORCE_DISK_SPILL` env +/// var forces `Disk`. +pub fn decide(lengths: &TableLengths, blowup_factor: u8) -> StorageMode { + if std::env::var("FORCE_DISK_SPILL").is_ok() { + log::info!("storage_mode: Disk (forced via FORCE_DISK_SPILL)"); + return StorageMode::Disk; + } + let estimated = peak_bytes(lengths, blowup_factor, table_parallelism()); + let mode = select_storage_mode(estimated, available_ram_bytes()); + log::info!("estimated_peak_bytes: {estimated}, storage_mode: {mode:?}"); + mode +} + +/// Peak RAM estimate in bytes for a proof whose trace shape matches `lengths`. +pub fn peak_bytes(lengths: &TableLengths, blowup_factor: u8, table_parallelism: usize) -> u64 { + let blowup = blowup_factor as u64; + let k = table_parallelism.max(1); + let specs = table_specs(lengths); + + // Persistent: every table's LDE + main/aux Merkle is alive across phase D. + let persistent_total: u64 = specs + .iter() + .map(|s| persistent_per_table(*s, blowup)) + .fold(0u64, u64::saturating_add); + + // Transient: only k tables run round 2-4 in parallel. Conservative bound is + // the top-k tables by transient bytes (worst possible chunk assignment). 
+ let mut transient_per: Vec = specs + .iter() + .map(|s| transient_per_table(*s, blowup)) + .collect(); + transient_per.sort_unstable_by(|a, b| b.cmp(a)); + let transient_total: u64 = transient_per + .iter() + .take(k) + .copied() + .fold(0u64, u64::saturating_add); + + // Domain + LdeTwiddles cache: one entry per unique padded-row count + // (blowup_factor and coset_offset are constant across tables in this + // codebase, so the unique key collapses to `rows`). + let mut unique_rows: Vec = specs.iter().map(|s| s.0).collect(); + unique_rows.sort_unstable(); + unique_rows.dedup(); + let domain_total: u64 = unique_rows + .iter() + .map(|&r| domain_cache_bytes(r, blowup)) + .fold(0u64, u64::saturating_add); + + // State alive across the prove call (memory cells, log Vec, instruction + // map). Independent of trace shape. + let state_total = lengths + .unique_byte_count + .saturating_mul(MEMORY_CELL_BYTES) + .saturating_add(lengths.cycle_count.saturating_mul(LOG_STRUCT_BYTES)) + .saturating_add( + lengths + .decode_rows + .saturating_mul(INSTRUCTION_MAP_BYTES_PER_ROW), + ); + + persistent_total + .saturating_add(transient_total) + .saturating_add(domain_total) + .saturating_add(state_total) +} + +/// `Disk` if `estimated` exceeds `available` minus a safety margin, else +/// `Ram`. Defaults to `Disk` when `available` is `None`. +fn select_storage_mode(estimated: u64, available: Option) -> StorageMode { + let Some(available) = available else { + log::warn!("Auto disk-spill: sysinfo could not read system memory, defaulting to Disk."); + return StorageMode::Disk; + }; + let threshold = available.saturating_mul(SAFETY_FRACTION_NUM) / SAFETY_FRACTION_DEN; + if estimated > threshold { + StorageMode::Disk + } else { + StorageMode::Ram + } +} + +/// OS-available RAM, or `None` if sysinfo can't read it. +fn available_ram_bytes() -> Option { + let mut sys = System::new(); + sys.refresh_memory(); + // total_memory == 0 means sysinfo can't read; otherwise available is real. 
+ if sys.total_memory() == 0 { + None + } else { + Some(sys.available_memory()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const GB: u64 = 1_000_000_000; + /// Larger than the table count, so every table lands in the top-k and the + /// per-table delta in `peak_bytes_per_table_increment_is_exact` is purely + /// additive. + const ALL_TABLES: usize = 1_000; + + fn empty_lengths() -> TableLengths { + TableLengths::default() + } + + /// Adding rows to a single chunked table must increase `peak_bytes` by + /// exactly the per-row contribution from the per-table formulas in this module. + /// Verifies the per-table breakdown is exact rather than averaged. + #[test] + fn peak_bytes_per_table_increment_is_exact() { + let blowup = 2u8; + let b = blowup as u64; + + let baseline = peak_bytes(&empty_lengths(), blowup, ALL_TABLES); + + let mut lengths = empty_lengths(); + lengths.cpu_padded_rows = 4; + let bumped = peak_bytes(&lengths, blowup, ALL_TABLES); + + let cpu_main = CPU_COLS as u64; + let cpu_aux = cpu_buses().len().div_ceil(2) as u64; + let per_row_persistent = cpu_main * GOLDILOCKS_BYTES * (1 + b) + + cpu_aux * CUBIC_EXT_BYTES * (1 + b) + + 2 * b * KECCAK_NODE_BYTES // main Merkle (1 tree) + + 2 * b * KECCAK_NODE_BYTES; // aux Merkle + let per_row_transient = b * CUBIC_EXT_BYTES // constraint_evaluations + + 2 * b * CUBIC_EXT_BYTES // composition LDE (2 parts, d=2) + + b * KECCAK_NODE_BYTES // composition Merkle (PairKeccak) + + b * CUBIC_EXT_BYTES // FRI evals (geometric ≈ 1) + + b * KECCAK_NODE_BYTES; // FRI Merkle (geometric ≈ 1) + let per_row_domain = (3 + 2 * b) * GOLDILOCKS_BYTES; + + // CPU adds 4 rows of persistent + transient (top-k by ALL_TABLES) + + // its 4-row Domain entry (a fresh unique key not previously present). + assert_eq!( + bumped - baseline, + 4 * (per_row_persistent + per_row_transient + per_row_domain) + ); + } + + /// Higher blowup_factor should produce a strictly larger estimate. 
+ #[test] + fn peak_bytes_scales_with_blowup() { + let lengths = empty_lengths(); + let two = peak_bytes(&lengths, 2, ALL_TABLES); + let four = peak_bytes(&lengths, 4, ALL_TABLES); + let eight = peak_bytes(&lengths, 8, ALL_TABLES); + assert!(two < four); + assert!(four < eight); + } + + /// Lower table_parallelism caps the transient sum to fewer tables, so the + /// estimate must be monotone in `k`. + #[test] + fn peak_bytes_monotone_in_table_parallelism() { + let lengths = empty_lengths(); + let k1 = peak_bytes(&lengths, 2, 1); + let k4 = peak_bytes(&lengths, 2, 4); + let k_all = peak_bytes(&lengths, 2, ALL_TABLES); + assert!(k1 < k4); + assert!(k4 <= k_all); + } + + #[test] + fn select_ram_when_estimate_below_threshold() { + // 10 GB estimated, 32 GB available → threshold 28.8 GB → Ram. + let mode = select_storage_mode(10 * GB, Some(32 * GB)); + assert_eq!(mode, StorageMode::Ram); + } + + #[test] + fn select_disk_when_estimate_exceeds_threshold() { + // 30 GB estimated, 32 GB available → threshold 28.8 GB → Disk. + let mode = select_storage_mode(30 * GB, Some(32 * GB)); + assert_eq!(mode, StorageMode::Disk); + } + + #[test] + fn unknown_available_defaults_to_disk() { + let mode = select_storage_mode(peak_bytes(&empty_lengths(), 2, ALL_TABLES), None); + assert_eq!(mode, StorageMode::Disk); + } +} diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 254c37834..dbe13d20b 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -10,6 +10,8 @@ //! assert!(lambda_vm_prover::verify(&vm_proof, &elf_bytes).unwrap()); //! 
``` +#[cfg(feature = "disk-spill")] +pub mod auto_storage; pub mod constraints; #[cfg(feature = "debug-checks")] mod debug_report; @@ -28,6 +30,8 @@ use executor::elf::Elf; use executor::vm::execution::Executor; use math::field::element::FieldElement; use stark::prover::{IsStarkProver, Prover}; +#[cfg(feature = "disk-spill")] +use stark::storage_mode::StorageMode; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; @@ -37,6 +41,8 @@ use crate::tables::decode; use crate::tables::page; use crate::tables::register; use crate::tables::trace_builder::Traces; +#[cfg(feature = "disk-spill")] +use crate::tables::trace_builder::count_table_lengths; use crate::tables::types::BusId; use crate::test_utils::{ E, F, VmAir, create_bitwise_air, create_branch_air, create_commit_air, create_cpu_air, @@ -544,6 +550,8 @@ pub fn count_elements(elf_bytes: &[u8], private_inputs: &[u8]) -> Result<(u64, u &result.logs, &MaxRowsConfig::default(), private_inputs, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, )?; Ok(( traces.total_field_elements(), @@ -593,9 +601,25 @@ pub fn prove_with_options_and_inputs( #[cfg(feature = "instruments")] let phase_start = std::time::Instant::now(); - // Generate all traces from ELF and execution logs. - // Page tables are derived from the prover's MemoryState (all accessed pages). 
- let mut traces = Traces::from_elf_and_logs(&program, &result.logs, max_rows, private_inputs)?; + #[cfg(feature = "disk-spill")] + let storage_mode = { + let lengths = count_table_lengths(&program, &result.logs, max_rows, private_inputs)?; + auto_storage::decide(&lengths, proof_options.blowup_factor) + }; + + let mut traces = Traces::from_elf_and_logs( + &program, + &result.logs, + max_rows, + private_inputs, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + debug_assert_eq!( + traces.public_output_bytes, result.return_values.memory_values, + "public output diverged between executor view and trace reconstruction" + ); + drop(result); #[cfg(feature = "instruments")] let trace_build_elapsed = phase_start.elapsed(); @@ -626,6 +650,8 @@ pub fn prove_with_options_and_inputs( let proof = Prover::multi_prove( airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), + #[cfg(feature = "disk-spill")] + storage_mode, ) .map_err(|e| Error::Prover(format!("{e:?}")))?; @@ -651,11 +677,6 @@ pub fn prove_with_options_and_inputs( .filter(|c| c.is_private_input) .count(); - debug_assert_eq!( - traces.public_output_bytes, result.return_values.memory_values, - "public output diverged between executor view and trace reconstruction" - ); - Ok(VmProof { proof, runtime_page_ranges, diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index bd5a3f2d2..f83763280 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -26,11 +26,15 @@ //! 
``` use std::collections::HashMap; +#[cfg(feature = "disk-spill")] +use std::collections::HashSet; use executor::elf::Elf; use executor::vm::instruction::decoding::Instruction; use executor::vm::logs::Log; use executor::vm::memory::U64HashMap; +#[cfg(feature = "disk-spill")] +use stark::storage_mode::StorageMode; use stark::trace::TraceTable; use super::bitwise::{self, BitwiseOperation, BitwiseOperationType}; @@ -100,6 +104,18 @@ impl MemoryState { Self { cells } } + /// Number of distinct pages that contain at least one cell. + #[cfg(feature = "disk-spill")] + fn unique_page_count(&self, page_size: u64) -> u64 { + debug_assert!( + page_size.is_power_of_two(), + "page_size must be a power of two for the bitmask to work" + ); + let mask = !(page_size - 1); + let pages: HashSet = self.cells.keys().map(|&a| a & mask).collect(); + pages.len() as u64 + } + /// Pre-populate the private input memory region at `PRIVATE_INPUT_START_INDEX`. fn add_private_input(&mut self, private_input: &[u8]) { if private_input.is_empty() { @@ -2058,17 +2074,33 @@ struct CollectedOps { keccak_ops: Vec, } -/// Chunk raw ops and generate one trace table per chunk. +/// Chunk raw ops and generate one trace table per chunk. When `storage_mode` +/// is `Disk`, each chunk's main table is spilled to mmap before the next chunk +/// is built so peak heap usage stays bounded. 
fn chunk_and_generate( ops: &[T], max_rows: usize, generate: impl Fn(&[T]) -> TraceTable, -) -> Vec> { - if ops.is_empty() { - vec![generate(&[])] + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, +) -> Result>, Error> { + let op_chunks: Vec<&[T]> = if ops.is_empty() { + vec![&[][..]] } else { - ops.chunks(max_rows).map(generate).collect() - } + ops.chunks(max_rows).collect() + }; + let mut tables = Vec::with_capacity(op_chunks.len()); + for chunk in op_chunks { + #[allow(unused_mut)] + let mut t = generate(chunk); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + t.main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill trace: {e}")))?; + } + tables.push(t); + } + Ok(tables) } /// Phase 2: Collect and route all operations from CPU ops. @@ -2192,6 +2224,7 @@ fn build_traces( decode_pc_to_row: HashMap, register_state: RegisterState, max_rows: &super::MaxRowsConfig, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, private_input: &[u8], ) -> Result { let CollectedOps { @@ -2262,24 +2295,76 @@ fn build_traces( .ok_or(Error::MissingHaltEcall)?; let halt_timestamp = halt_op.timestamp; - let cpus = chunk_and_generate(&cpu_ops, max_rows.cpu, cpu::generate_cpu_trace); - let memws = chunk_and_generate(&memw_ops, max_rows.memw, memw::generate_memw_trace); + let cpus = chunk_and_generate( + &cpu_ops, + max_rows.cpu, + cpu::generate_cpu_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let memws = chunk_and_generate( + &memw_ops, + max_rows.memw, + memw::generate_memw_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; let memw_aligneds = chunk_and_generate( &memw_aligned_ops, max_rows.memw_aligned, memw_aligned::generate_memw_aligned_trace, - ); + #[cfg(feature = "disk-spill")] + storage_mode, + )?; let memw_registers = chunk_and_generate( &memw_register_ops, max_rows.memw_register, memw_register::generate_memw_register_trace, - ); - let loads = chunk_and_generate(&load_ops, 
max_rows.load, load::generate_load_trace); - let lts = chunk_and_generate(&lt_ops, max_rows.lt, lt::generate_lt_trace); - let shifts = chunk_and_generate(&shift_ops, max_rows.shift, shift::generate_shift_trace); - let muls = chunk_and_generate(&mul_ops, max_rows.mul, mul::generate_mul_trace); - let dvrms = chunk_and_generate(&dvrm_ops, max_rows.dvrm, dvrm::generate_dvrm_trace); - let branches = chunk_and_generate(&branch_ops, max_rows.branch, branch::generate_branch_trace); + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let loads = chunk_and_generate( + &load_ops, + max_rows.load, + load::generate_load_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let lts = chunk_and_generate( + &lt_ops, + max_rows.lt, + lt::generate_lt_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let shifts = chunk_and_generate( + &shift_ops, + max_rows.shift, + shift::generate_shift_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let muls = chunk_and_generate( + &mul_ops, + max_rows.mul, + mul::generate_mul_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let dvrms = chunk_and_generate( + &dvrm_ops, + max_rows.dvrm, + dvrm::generate_dvrm_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let branches = chunk_and_generate( + &branch_ops, + max_rows.branch, + branch::generate_branch_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; let mut bitwise = bitwise::generate_bitwise_trace(); bitwise::update_multiplicities(&mut bitwise, &bitwise_ops); @@ -2298,7 +2383,8 @@ fn build_traces( // Generate remaining traces in parallel (page, register, halt, commit). // chunk_and_generate already handled cpu, lt, memw, load, mul, dvrm, branch above. 
- let commit_trace = commit::generate_commit_trace(&commit_ops); + #[allow(unused_mut)] + let mut commit_trace = commit::generate_commit_trace(&commit_ops); // Generate keccak traces (core table + per-round table + preprocessed RC) let keccak_rnd_ops: Vec = keccak_ops @@ -2314,7 +2400,8 @@ fn build_traces( let mut keccak_rc_trace = keccak_rc::generate_keccak_rc_trace(); keccak_rc::update_multiplicities(&mut keccak_rc_trace, keccak_ops.len()); - let (pages, page_configs, register_trace, halt_trace); + #[allow(unused_mut)] + let (mut pages, page_configs, mut register_trace, mut halt_trace); #[cfg(feature = "parallel")] { let ((pages_val, register_val), halt_val) = rayon::join( @@ -2352,6 +2439,37 @@ fn build_traces( halt_trace = halt::generate_halt_trace(halt_timestamp); } + // Fixed-size and per-page tables aren't built through `chunk_and_generate`, + // so spill them here before returning. + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + bitwise + .main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill bitwise: {e}")))?; + decode + .main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill decode: {e}")))?; + commit_trace + .main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill commit: {e}")))?; + register_trace + .main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill register: {e}")))?; + halt_trace + .main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill halt: {e}")))?; + for page in &mut pages { + page.main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill page: {e}")))?; + } + } + Ok(Traces { cpus, bitwise, @@ -2377,6 +2495,231 @@ fn build_traces( }) } +/// Padded row count after chunking. +#[cfg(feature = "disk-spill")] +fn padded_chunked_rows(ops_count: usize, max_rows: usize) -> u64 { + // `max_rows <= 0` would loop forever. Called internally with const values > 0. 
+ assert!(max_rows > 0, "max_rows must be positive"); + if ops_count == 0 { + return 4; // empty-chunk tables still allocate one 4-row padded chunk + } + let mut total: u64 = 0; + let mut remaining = ops_count; + while remaining > 0 { + let chunk_size = remaining.min(max_rows); + total += chunk_size.next_power_of_two().max(4) as u64; + remaining -= chunk_size; + } + total +} + +/// Per-table padded row counts plus auxiliary metrics for peak-heap estimation. +#[cfg(feature = "disk-spill")] +#[derive(Debug, Default, Clone)] +pub struct TableLengths { + pub cpu_padded_rows: u64, + pub memw_padded_rows: u64, + pub memw_aligned_padded_rows: u64, + pub memw_register_padded_rows: u64, + pub load_padded_rows: u64, + pub lt_padded_rows: u64, + pub shift_padded_rows: u64, + pub mul_padded_rows: u64, + pub dvrm_padded_rows: u64, + pub branch_padded_rows: u64, + pub commit_padded_rows: u64, + pub decode_rows: u64, + pub unique_page_count: u64, + pub cycle_count: u64, + pub unique_byte_count: u64, +} + +/// Per-table row counts from `logs`, without building op vectors. +/// Exact for tables that don't dedup; upper bound for LT, MUL, DVRM, BRANCH. +/// Must stay in sync with `Traces::from_elf_and_logs`. +#[cfg(feature = "disk-spill")] +pub fn count_table_lengths( + elf: &Elf, + logs: &[Log], + max_rows: &super::MaxRowsConfig, + private_input: &[u8], +) -> Result { + // Phase 0: ELF → instructions + DECODE row count. + let instructions = decode::instructions_from_elf(elf) + .map_err(|e| Error::Execution(format!("Failed to parse instructions: {e}")))?; + // Mirrors the padding inside `generate_decode_trace`. + let decode_rows = (instructions.len() as u64 + 1).next_power_of_two().max(2); + + // Memory + register state for partition predicates that need timestamps. + let mut memory_state = MemoryState::from_elf(elf); + memory_state.add_private_input(private_input); + let mut register_state = RegisterState::new(elf.entry_point); + + // Raw counts (pre-chunking + pre-padding). 
+ let mut cpu_count = 0usize; + // Wide-MEMW counts bucketed by width, used by the LT-from-MEMW derivation. + let mut memw_by_width: [usize; 4] = [0; 4]; + let mut memw_aligned_count = 0usize; + let mut memw_register_count = 0usize; + let mut load_count = 0usize; + let mut lt_count = 0usize; + let mut shift_count = 0usize; + let mut mul_count = 0usize; + let mut dvrm_count = 0usize; + let mut branch_count = 0usize; + let mut commit_count = 0usize; + let mut current_commit_index = 0u32; + + let partition_memw = |op: &MemwOperation, + by_width: &mut [usize; 4], + aligned: &mut usize, + register: &mut usize| { + if is_register_op(op) { + *register += 1; + } else if is_aligned_op(op) { + *aligned += 1; + } else { + let idx = match op.width { + 1 => 0, + 2 => 1, + 4 => 2, + 8 => 3, + _ => return, + }; + by_width[idx] += 1; + } + }; + + for (i, log) in logs.iter().enumerate() { + let timestamp = (i as u64) * 4 + 4; + let instruction = instructions + .get(&log.current_pc) + .copied() + .ok_or(Error::MissingInstruction(log.current_pc))?; + let cpu_op = CpuOperation::from_log_and_instruction(log, timestamp, instruction); + cpu_count += 1; + + // Memory ops from load/store + if cpu_op.decode.op_load { + let (memw_op, _load_op, _bitwise) = + collect_load_op_from_cpu(&cpu_op, &mut memory_state); + partition_memw( + &memw_op, + &mut memw_by_width, + &mut memw_aligned_count, + &mut memw_register_count, + ); + load_count += 1; + } else if cpu_op.decode.op_store { + let memw_op = collect_store_op_from_cpu(&cpu_op, &mut memory_state); + partition_memw( + &memw_op, + &mut memw_by_width, + &mut memw_aligned_count, + &mut memw_register_count, + ); + } + + // Register accesses. 
+ let reg_memw_ops = collect_register_ops_from_cpu(&cpu_op, &mut register_state); + for memw_op in &reg_memw_ops { + partition_memw( + memw_op, + &mut memw_by_width, + &mut memw_aligned_count, + &mut memw_register_count, + ); + } + + // ECALL Commit + if cpu_op.ecall_commit { + // Match `expand_commit_operations_for_ecall`'s `0..=count` loop + // without building the op vector. + commit_count += (cpu_op.commit_count as usize) + .checked_add(1) + .ok_or_else(|| Error::Execution("commit_count overflows usize".into()))?; + let reg_commit_ops = + collect_commit_memw_ops(&cpu_op, &mut register_state, &mut memory_state); + for memw_op in &reg_commit_ops { + partition_memw( + memw_op, + &mut memw_by_width, + &mut memw_aligned_count, + &mut memw_register_count, + ); + } + let count = u32::try_from(cpu_op.commit_count) + .map_err(|_| Error::Execution("commit_count exceeds u32 range".into()))?; + current_commit_index = current_commit_index + .checked_add(count) + .ok_or_else(|| Error::Execution("commit index exceeds u32 range".into()))?; + } + + // CPU-side per-instruction-kind counters + if cpu_op.decode.op_slt || cpu_op.decode.op_blt { + lt_count += 1; + } + if cpu_op.decode.op_shift { + shift_count += 1; + } + if cpu_op.decode.op_mul { + mul_count += 1; + } + if cpu_op.decode.op_divrem { + dvrm_count += 1; + } + if cpu_op.branch_cond { + branch_count += 1; + } + } + + // HALT finalization. Halt ops fall through to wide MEMW. + let halt_memw_ops = collect_halt_ops(&mut register_state); + for memw_op in &halt_memw_ops { + partition_memw( + memw_op, + &mut memw_by_width, + &mut memw_aligned_count, + &mut memw_register_count, + ); + } + + // LT ops derived from wide-MEMW and memw_aligned ops. + let memw_count = memw_by_width.iter().sum::<usize>(); + let lt_from_memw = + memw_by_width[0] + 2 * memw_by_width[1] + 4 * memw_by_width[2] + 8 * memw_by_width[3]; + lt_count += lt_from_memw + memw_aligned_count; + + // DVRM derives mul and lt ops. 
+ mul_count += 2 * dvrm_count; + lt_count += dvrm_count; + + let unique_page_count = memory_state.unique_page_count(page::DEFAULT_PAGE_SIZE as u64); + let unique_byte_count = memory_state.cells.len() as u64; + let cycle_count = logs.len() as u64; + + Ok(TableLengths { + cpu_padded_rows: padded_chunked_rows(cpu_count, max_rows.cpu), + memw_padded_rows: padded_chunked_rows(memw_count, max_rows.memw), + memw_aligned_padded_rows: padded_chunked_rows(memw_aligned_count, max_rows.memw_aligned), + memw_register_padded_rows: padded_chunked_rows(memw_register_count, max_rows.memw_register), + load_padded_rows: padded_chunked_rows(load_count, max_rows.load), + lt_padded_rows: padded_chunked_rows(lt_count, max_rows.lt), + shift_padded_rows: padded_chunked_rows(shift_count, max_rows.shift), + mul_padded_rows: padded_chunked_rows(mul_count, max_rows.mul), + dvrm_padded_rows: padded_chunked_rows(dvrm_count, max_rows.dvrm), + branch_padded_rows: padded_chunked_rows(branch_count, max_rows.branch), + commit_padded_rows: commit_count + .checked_next_power_of_two() + .unwrap_or(usize::MAX) + .max(4) as u64, + decode_rows, + unique_page_count, + cycle_count, + unique_byte_count, + }) +} + impl Traces { /// Returns the total number of main-trace field elements across all tables. 
/// @@ -2722,6 +3065,7 @@ impl Traces { logs: &[Log], max_rows: &super::MaxRowsConfig, private_input: &[u8], + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result { // Phase 0: ELF → DECODE + instructions // IMPORTANT: Use generate_decode_trace (same as compute_precomputed_commitment) @@ -2762,6 +3106,8 @@ impl Traces { decode_pc_to_row, register_state, max_rows, + #[cfg(feature = "disk-spill")] + storage_mode, private_input, ) } @@ -2812,6 +3158,8 @@ impl Traces { decode_pc_to_row, register_state, max_rows, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, &[], ) } @@ -2875,7 +3223,14 @@ impl Traces { max_rows: &super::MaxRowsConfig, private_input: &[u8], ) -> Result { - let mut traces = Self::from_elf_and_logs(elf, logs, max_rows, private_input)?; + let mut traces = Self::from_elf_and_logs( + elf, + logs, + max_rows, + private_input, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, + )?; traces.bitwise = bitwise::trim_zero_rows(traces.bitwise); Ok(traces) } diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index 1dcb768b2..83386a417 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -12,6 +12,7 @@ use std::path::PathBuf; +use crypto::fiat_shamir::is_transcript::IsStarkTranscript; use executor::elf::Elf; use executor::vm::execution::Executor; use executor::vm::instruction::decoding::Instruction; @@ -21,7 +22,12 @@ use math::field::element::FieldElement; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::lookup::{AirWithBuses, AuxiliaryTraceBuildData, NullBoundaryConstraintBuilder}; use stark::proof::options::ProofOptions; +use stark::proof::stark::MultiProof; +use stark::prover::{IsStarkProver, Prover, ProvingError}; +#[cfg(feature = "disk-spill")] +use stark::storage_mode::StorageMode; use stark::trace::TraceTable; +use stark::traits::AIR; use crate::constraints::cpu::create_all_cpu_constraints; use crate::tables::bitwise::{ @@ -81,6 +87,27 @@ pub type 
FE = FieldElement; pub type VmAir = AirWithBuses; +type GoldilocksPair<'a, PI> = ( + &'a dyn AIR, + &'a mut TraceTable, + &'a PI, +); + +pub fn multi_prove_ram( + air_trace_pairs: Vec>, + transcript: &mut (impl IsStarkTranscript + Clone + Send), +) -> Result, ProvingError> +where + PI: Send + Sync + Clone, +{ + Prover::::multi_prove( + air_trace_pairs, + transcript, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, + ) +} + // ============================================================================= // ELF Execution Helpers // ============================================================================= diff --git a/prover/src/tests/bitwise_bus_tests.rs b/prover/src/tests/bitwise_bus_tests.rs index 317b22362..2a5fd31dd 100644 --- a/prover/src/tests/bitwise_bus_tests.rs +++ b/prover/src/tests/bitwise_bus_tests.rs @@ -15,12 +15,12 @@ use stark::lookup::{ NullBoundaryConstraintBuilder, Packing, }; use stark::proof::options::ProofOptions; -use stark::prover::{IsStarkProver, Prover}; use stark::trace::TraceTable; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; use crate::tables::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::test_utils::multi_prove_ram; type F = GoldilocksField; type E = GoldilocksExtension; @@ -197,7 +197,7 @@ fn prove_and_verify(sender_lookups: &[(u8, u8, u8)]) -> bool { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; @@ -307,7 +307,7 @@ fn prove_and_verify_custom( ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; diff --git a/prover/src/tests/bitwise_tests.rs b/prover/src/tests/bitwise_tests.rs index 
2848edef4..8337f8bf7 100644 --- a/prover/src/tests/bitwise_tests.rs +++ b/prover/src/tests/bitwise_tests.rs @@ -5,6 +5,7 @@ use crate::tables::bitwise::{ generate_bitwise_trace, is_preprocessed, preprocessed_commitment, row_index, }; use crate::tables::types::FE; +use crate::test_utils::multi_prove_ram; use math::field::element::FieldElement; use stark::proof::options::ProofOptions; @@ -590,7 +591,7 @@ mod soundness_tests { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; @@ -638,7 +639,7 @@ mod soundness_tests { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; @@ -708,7 +709,7 @@ mod soundness_tests { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); // Verifier uses DIFFERENT AIR with honest commitment let verifier_airs: Vec<&dyn AIR> = diff --git a/prover/src/tests/branch_bus_tests.rs b/prover/src/tests/branch_bus_tests.rs index 1b3ae5071..c19a580ad 100644 --- a/prover/src/tests/branch_bus_tests.rs +++ b/prover/src/tests/branch_bus_tests.rs @@ -17,13 +17,13 @@ use stark::lookup::{ NullBoundaryConstraintBuilder, Packing, }; use stark::proof::options::ProofOptions; -use stark::prover::{IsStarkProver, Prover}; use stark::trace::TraceTable; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; use crate::tables::branch::{BranchOperation, cols, generate_branch_trace}; use crate::tables::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::test_utils::multi_prove_ram; type F = GoldilocksField; type E = 
GoldilocksExtension; @@ -340,7 +340,7 @@ fn prove_and_verify(ops: &[BranchOperation]) -> bool { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; @@ -430,7 +430,7 @@ fn prove_and_verify_custom(ops: &[BranchOperation], receiver_rows: &[CustomBranc ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; diff --git a/prover/src/tests/count_table_lengths_drift_tests.rs b/prover/src/tests/count_table_lengths_drift_tests.rs new file mode 100644 index 000000000..6855fcb5b --- /dev/null +++ b/prover/src/tests/count_table_lengths_drift_tests.rs @@ -0,0 +1,93 @@ +//! Asserts `count_table_lengths` matches `Traces::from_elf_and_logs` row counts. + +use crate::tables::MaxRowsConfig; +use crate::tables::trace_builder::{Traces, count_table_lengths}; +use crate::test_utils::run_asm_elf; + +#[test] +fn count_table_lengths_matches_traces() { + let (elf, logs, _) = run_asm_elf("fib_iterative_372k"); + let max_rows = MaxRowsConfig::default(); + + let predicted = + count_table_lengths(&elf, &logs, &max_rows, &[]).expect("count_table_lengths succeeds"); + let traces = Traces::from_elf_and_logs_minimal(&elf, &logs, &max_rows, &[]) + .expect("trace build succeeds"); + + let sum_heights = |tables: &[stark::trace::TraceTable<_, _>]| -> u64 { + tables.iter().map(|t| t.main_table.height as u64).sum() + }; + + // Exact-match tables: predicted row count equals built trace. 
+ assert_eq!(predicted.cpu_padded_rows, sum_heights(&traces.cpus), "cpu"); + assert_eq!( + predicted.memw_padded_rows, + sum_heights(&traces.memws), + "memw" + ); + assert_eq!( + predicted.memw_aligned_padded_rows, + sum_heights(&traces.memw_aligneds), + "memw_aligned" + ); + assert_eq!( + predicted.memw_register_padded_rows, + sum_heights(&traces.memw_registers), + "memw_register" + ); + assert_eq!( + predicted.load_padded_rows, + sum_heights(&traces.loads), + "load" + ); + assert_eq!( + predicted.shift_padded_rows, + sum_heights(&traces.shifts), + "shift" + ); + assert_eq!( + predicted.commit_padded_rows, traces.commit.main_table.height as u64, + "commit" + ); + assert_eq!( + predicted.decode_rows, traces.decode.main_table.height as u64, + "decode" + ); + + // Upper-bound tables: predicted is `>=` actual (LT/MUL/DVRM/BRANCH dedup ops). + assert!( + predicted.lt_padded_rows >= sum_heights(&traces.lts), + "lt: predicted={} actual={}", + predicted.lt_padded_rows, + sum_heights(&traces.lts) + ); + assert!( + predicted.mul_padded_rows >= sum_heights(&traces.muls), + "mul: predicted={} actual={}", + predicted.mul_padded_rows, + sum_heights(&traces.muls) + ); + assert!( + predicted.dvrm_padded_rows >= sum_heights(&traces.dvrms), + "dvrm: predicted={} actual={}", + predicted.dvrm_padded_rows, + sum_heights(&traces.dvrms) + ); + assert!( + predicted.branch_padded_rows >= sum_heights(&traces.branches), + "branch: predicted={} actual={}", + predicted.branch_padded_rows, + sum_heights(&traces.branches) + ); + + // Auxiliary scalars. + assert_eq!(predicted.cycle_count, logs.len() as u64, "cycle_count"); + assert_eq!( + predicted.unique_page_count, + traces.pages.len() as u64, + "unique_page_count" + ); + + // Mirrors hardcoded `halt_rows = 1` in `auto_storage::table_specs`. 
+ assert_eq!(traces.halt.main_table.height, 1, "halt_rows"); +} diff --git a/prover/src/tests/decode_tests.rs b/prover/src/tests/decode_tests.rs index 852c2ccd6..f1f60e5ba 100644 --- a/prover/src/tests/decode_tests.rs +++ b/prover/src/tests/decode_tests.rs @@ -9,7 +9,9 @@ use crate::tables::decode::{ DecodeEntry, bus_interactions, cols, generate_decode_trace, instructions_from_elf, update_multiplicities, }; +use crate::tables::trace_builder::Traces; use crate::tables::types::{FE, packed_decode as bits}; +use crate::test_utils::multi_prove_ram; use crate::test_utils::run_asm_elf; // ========================================================================= @@ -867,7 +869,6 @@ fn test_instructions_from_elf_includes_all_executable() { fn test_decode_soundness_different_elf_rejected() { use crypto::fiat_shamir::default_transcript::DefaultTranscript; use stark::proof::options::ProofOptions; - use stark::prover::{IsStarkProver, Prover}; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; @@ -948,7 +949,7 @@ fn test_decode_soundness_different_elf_rejected() { (&prover_decode_air, &mut traces.decode, &()), ]; - let proof = Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])) + let proof = multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])) .expect("Prover failed to generate proof"); // ========================================================================= @@ -999,11 +1000,9 @@ fn test_decode_soundness_different_elf_rejected() { fn test_decode_soundness_same_elf_accepted() { use crypto::fiat_shamir::default_transcript::DefaultTranscript; use stark::proof::options::ProofOptions; - use stark::prover::{IsStarkProver, Prover}; use stark::verifier::{IsStarkVerifier, Verifier}; use crate::VmAirs; - use crate::tables::trace_builder::Traces; use crate::tables::types::GoldilocksExtension; type E = GoldilocksExtension; @@ -1031,8 +1030,15 @@ fn test_decode_soundness_same_elf_accepted() { .expect("Failed to create 
executor"); let result = executor.run().expect("Failed to run program"); - let mut traces = - Traces::from_elf_and_logs(&prover_elf, &result.logs, &Default::default(), &[]).unwrap(); + let mut traces = Traces::from_elf_and_logs( + &prover_elf, + &result.logs, + &Default::default(), + &[], + #[cfg(feature = "disk-spill")] + stark::storage_mode::StorageMode::Ram, + ) + .unwrap(); let table_counts = traces.table_counts(); let prover_airs = VmAirs::new( &prover_elf, @@ -1042,7 +1048,7 @@ fn test_decode_soundness_same_elf_accepted() { &table_counts, ); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( prover_airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) diff --git a/prover/src/tests/disk_spill_tests.rs b/prover/src/tests/disk_spill_tests.rs new file mode 100644 index 000000000..e019fa456 --- /dev/null +++ b/prover/src/tests/disk_spill_tests.rs @@ -0,0 +1,58 @@ +//! End-to-end tests forcing `StorageMode::Disk` via the `FORCE_DISK_SPILL` env var. +//! +//! Run with `FORCE_DISK_SPILL=1` set in the environment, e.g. +//! `FORCE_DISK_SPILL=1 cargo test --features disk-spill disk_spill`. Tests +//! fail fast if the var is unset to avoid silent loss of coverage. 
+ +use crate::VmProof; +use crate::tables::MaxRowsConfig; +use crate::test_utils::asm_elf_bytes; +use stark::proof::options::GoldilocksCubicProofOptions; + +fn require_force_disk_spill() { + assert_eq!( + std::env::var("FORCE_DISK_SPILL").as_deref(), + Ok("1"), + "set FORCE_DISK_SPILL=1 before running disk-spill tests", + ); +} + +#[test] +fn test_disk_spill_prove_verify_and_roundtrip_small() { + require_force_disk_spill(); + let elf_bytes = asm_elf_bytes("sub"); + let opts = GoldilocksCubicProofOptions::with_blowup(2).expect("blowup=2 is always valid"); + let proof = crate::prove_with_options(&elf_bytes, &opts, &MaxRowsConfig::default()) + .expect("prove failed"); + assert!( + crate::verify_with_options(&proof, &elf_bytes, &opts).expect("verify failed"), + "verification returned false" + ); + + let bytes = bincode::serialize(&proof).expect("serialize failed"); + let proof2: VmProof = bincode::deserialize(&bytes).expect("deserialize failed"); + assert!( + crate::verify_with_options(&proof2, &elf_bytes, &opts).expect("verify failed"), + "verification failed after serialization roundtrip" + ); +} + +#[test] +fn test_disk_spill_prove_verify_and_roundtrip_chunked() { + require_force_disk_spill(); + let elf_bytes = asm_elf_bytes("all_instructions_64"); + let opts = GoldilocksCubicProofOptions::with_blowup(2).expect("blowup=2 is always valid"); + let proof = crate::prove_with_options(&elf_bytes, &opts, &MaxRowsConfig::small()) + .expect("prove failed"); + assert!( + crate::verify_with_options(&proof, &elf_bytes, &opts).expect("verify failed"), + "verification returned false" + ); + + let bytes = bincode::serialize(&proof).expect("serialize failed"); + let proof2: VmProof = bincode::deserialize(&bytes).expect("deserialize failed"); + assert!( + crate::verify_with_options(&proof2, &elf_bytes, &opts).expect("verify failed"), + "verification failed after serialization roundtrip (chunked)" + ); +} diff --git a/prover/src/tests/lt_bus_tests.rs 
b/prover/src/tests/lt_bus_tests.rs index d794995b7..dcc555780 100644 --- a/prover/src/tests/lt_bus_tests.rs +++ b/prover/src/tests/lt_bus_tests.rs @@ -17,13 +17,13 @@ use stark::lookup::{ NullBoundaryConstraintBuilder, Packing, }; use stark::proof::options::ProofOptions; -use stark::prover::{IsStarkProver, Prover}; use stark::trace::TraceTable; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; use crate::tables::lt::{LtOperation, cols, generate_lt_trace}; use crate::tables::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::test_utils::multi_prove_ram; type F = GoldilocksField; type E = GoldilocksExtension; @@ -293,7 +293,7 @@ fn prove_and_verify(ops: &[LtOperation]) -> bool { ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; @@ -377,7 +377,7 @@ fn prove_and_verify_custom(ops: &[LtOperation], receiver_rows: &[CustomLtRow]) - ]; let multi_proof = - Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); let airs: Vec<&dyn AIR> = vec![&sender_air, &receiver_air]; diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 957845c95..dc5f3fe22 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -10,10 +10,14 @@ pub mod branch_constraints_tests; pub mod commit_tests; #[cfg(test)] pub mod constraints_tests; +#[cfg(all(test, feature = "disk-spill"))] +pub mod count_table_lengths_drift_tests; #[cfg(test)] pub mod cpu_tests; #[cfg(test)] pub mod decode_tests; +#[cfg(all(test, feature = "disk-spill"))] +pub mod disk_spill_tests; #[cfg(test)] pub mod dvrm_tests; #[cfg(test)] diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs index 736fcd78e..53149a943 100644 --- 
a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -18,20 +18,17 @@ use math::field::element::FieldElement; use stark::constraints::transition::TransitionConstraintEvaluator; use stark::lookup::{AirWithBuses, AuxiliaryTraceBuildData}; use stark::proof::options::ProofOptions; -use stark::prover::{IsStarkProver, Prover}; use stark::traits::AIR; use stark::verifier::{IsStarkVerifier, Verifier}; -use crate::VmProof; -use crate::tables::MaxRowsConfig; use crate::tables::trace_builder::Traces; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; use executor::elf::Elf; -use executor::vm::execution::Executor; // Import shared utilities use crate::VmAirs; +use crate::test_utils::multi_prove_ram; use crate::test_utils::run_asm_elf; type F = GoldilocksField; @@ -63,11 +60,11 @@ fn prove_and_verify_vm_minimal(elf: &Elf, traces: &mut Traces) -> bool { // Build air_trace_pairs for all tables let air_trace_pairs = airs.air_trace_pairs(traces); - let multi_proof = - match Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])) { - Ok(proof) => proof, - Err(_) => return false, - }; + let multi_proof = match multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])) + { + Ok(proof) => proof, + Err(_) => return false, + }; // Compute the verifier-side expected COMMIT bus balance from public output bytes let expected_bus_balance = crate::compute_expected_commit_bus_balance( @@ -86,79 +83,6 @@ fn prove_and_verify_vm_minimal(elf: &Elf, traces: &mut Traces) -> bool { ) } -/// Like [`crate::prove_with_options_and_inputs`] but with trimmed bitwise (TEST ONLY). -/// -/// ~100x faster than the production path. Same unsoundness caveats as -/// [`Traces::from_elf_and_logs_minimal`]. The full preprocessed bitwise -/// path is covered by `test_prove_elfs_all_instructions_64_full`. 
-fn prove_vm_minimal(elf_bytes: &[u8], private_inputs: &[u8], max_rows: &MaxRowsConfig) -> VmProof { - let proof_options = ProofOptions::default_test_options(); - let elf = Elf::load(elf_bytes).expect("ELF load"); - let executor = Executor::new(&elf, private_inputs.to_vec()).expect("executor"); - let result = executor.run().expect("execution"); - let mut traces = - Traces::from_elf_and_logs_minimal(&elf, &result.logs, max_rows, private_inputs).unwrap(); - let table_counts = traces.table_counts(); - let airs = VmAirs::new( - &elf, - &proof_options, - true, - &traces.page_configs, - &table_counts, - ); - let runtime_page_ranges = traces.runtime_page_ranges(); - let proof = Prover::multi_prove( - airs.air_trace_pairs(&mut traces), - &mut DefaultTranscript::::new(&[]), - ) - .expect("prove"); - let num_private_input_pages = traces - .page_configs - .iter() - .filter(|c| c.is_private_input) - .count(); - VmProof { - proof, - runtime_page_ranges, - table_counts, - public_output: traces.public_output_bytes.clone(), - num_private_input_pages, - } -} - -/// Like [`crate::verify_with_options`] but matches the minimal bitwise AIR. -/// -/// Must be used to verify proofs from [`prove_vm_minimal`]. 
-fn verify_vm_minimal(vm_proof: &VmProof, elf_bytes: &[u8]) -> bool { - let proof_options = ProofOptions::default_test_options(); - let elf = Elf::load(elf_bytes).expect("ELF load"); - let page_configs = Traces::page_configs_from_elf_and_runtime( - &elf, - &vm_proof.runtime_page_ranges, - vm_proof.num_private_input_pages, - ); - let airs = VmAirs::new( - &elf, - &proof_options, - true, - &page_configs, - &vm_proof.table_counts, - ); - let air_refs = airs.air_refs(); - let expected_bus_balance = crate::compute_expected_commit_bus_balance( - &air_refs, - &vm_proof.proof, - &vm_proof.public_output, - ) - .expect("fingerprint collision in test"); - Verifier::multi_verify( - &air_refs, - &vm_proof.proof, - &mut DefaultTranscript::::new(&[]), - &expected_bus_balance, - ) -} - // ============================================================================= // Integration tests // ============================================================================= @@ -201,7 +125,7 @@ fn test_cpu_only_no_bus() { _, )> = vec![(&cpu_air, &mut cpu_trace, &())]; - let multi_proof = Prover::multi_prove(air_trace_pairs, &mut DefaultTranscript::::new(&[])) + let multi_proof = multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])) .expect("Prover failed"); let airs: Vec<&dyn AIR> = vec![&cpu_air]; @@ -233,7 +157,7 @@ fn test_cpu_only_no_bus() { fn test_prove_elfs_sub_fast() { let _ = env_logger::builder().is_test(true).try_init(); let (elf, logs, _instructions) = run_asm_elf("sub"); - // Use from_elf_and_logs_minimal to get PAGE and REGISTER tables for Memory bus + // Use from_elf_and_logs_minimal to get PAGE and REGISTER tables for Memory bus let mut traces = Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); @@ -805,7 +729,8 @@ fn test_prove_elfs_keccak() { let (elf, logs, _instructions) = run_asm_elf("test_keccak"); // Must use from_elf_and_logs (not from_logs_minimal) because keccak accesses // RAM (stack memory), which requires PAGE tables for 
Memory bus balance. - let mut traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let mut traces = + Traces::from_elf_and_logs_minimal(&elf, &logs, &Default::default(), &[]).unwrap(); assert!( prove_and_verify_vm_minimal(&elf, &mut traces), @@ -841,7 +766,7 @@ fn test_prove_elfs_keccak_multi_call() { ); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); assert_eq!( traces.public_output_bytes, result.return_values.memory_values @@ -874,7 +799,7 @@ fn test_prove_elfs_keccak_unaligned_state_addr() { executor::vm::execution::Executor::new(&elf, vec![]).expect("Failed to create executor"); let result = executor.run().expect("Failed to run program"); let mut traces = - Traces::from_elf_and_logs(&elf, &result.logs, &Default::default(), &[]).unwrap(); + Traces::from_elf_and_logs_minimal(&elf, &result.logs, &Default::default(), &[]).unwrap(); // Tamper the first real keccak row: replace addr(1) (a byte cell) with a // value outside [0, 256). 
The new IS_BYTE bus sender will emit this @@ -945,7 +870,7 @@ fn test_prove_elfs_test_commit_4_wrong_pages_rejected() { &traces.page_configs, &table_counts, ); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( prover_airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) @@ -1304,7 +1229,15 @@ fn test_debug_memory_tokens_sb_sh() { use std::collections::HashMap; let (elf, logs, _instructions) = run_asm_elf("test_sb_sh_8"); - let traces = Traces::from_elf_and_logs(&elf, &logs, &Default::default(), &[]).unwrap(); + let traces = Traces::from_elf_and_logs( + &elf, + &logs, + &Default::default(), + &[], + #[cfg(feature = "disk-spill")] + stark::storage_mode::StorageMode::Ram, + ) + .unwrap(); let memw = &traces.memws[0]; // Small test: single MEMW chunk println!("DEBUG: test_sb_sh_8 Memory bus tokens (FULL)"); @@ -1674,7 +1607,7 @@ fn test_deep_stack_runtime_pages_roundtrip() { &traces.page_configs, &table_counts, ); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( prover_airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) @@ -1729,7 +1662,7 @@ fn test_deep_stack_missing_pages_rejected() { &traces.page_configs, &table_counts, ); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( prover_airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) @@ -1819,7 +1752,7 @@ fn test_heap_alloc_runtime_pages_roundtrip() { &traces.page_configs, &table_counts, ); - let proof = Prover::multi_prove( + let proof = multi_prove_ram( prover_airs.air_trace_pairs(&mut traces), &mut DefaultTranscript::::new(&[]), ) @@ -1991,7 +1924,7 @@ fn test_crafted_zero_count_proof_must_not_verify() { (&airs.decode, &mut decode_trace, &()), ]; - let proof = Prover::multi_prove(pairs, &mut DefaultTranscript::::new(&[])) + let proof = multi_prove_ram(pairs, &mut DefaultTranscript::::new(&[])) .expect("Proof generation should succeed"); assert_eq!(proof.proofs.len(), 2); @@ -2010,9 +1943,11 @@ fn 
test_crafted_zero_count_proof_must_not_verify() { #[test] fn test_small_max_rows_splits_tables() { let elf_bytes = crate::test_utils::asm_elf_bytes("all_instructions_64"); + let proof_options = ProofOptions::default_test_options(); let max_rows = crate::tables::MaxRowsConfig::small(); - let vm_proof = prove_vm_minimal(&elf_bytes, &[], &max_rows); + let vm_proof = crate::prove_with_options(&elf_bytes, &proof_options, &max_rows) + .expect("Prover should succeed with small max_rows"); // With 2^5 max rows and 64+ instructions, tables should have multiple chunks. assert!( @@ -2021,10 +1956,9 @@ fn test_small_max_rows_splits_tables() { vm_proof.table_counts.cpu ); - assert!( - verify_vm_minimal(&vm_proof, &elf_bytes), - "Proof with small max_rows should verify" - ); + let verified = crate::verify_with_options(&vm_proof, &elf_bytes, &proof_options) + .expect("Verifier should not error"); + assert!(verified, "Proof with small max_rows should verify"); } // ============================================================================= @@ -2089,11 +2023,8 @@ fn test_verify_rejects_inflated_table_counts() { #[test] fn test_prove_wsuffix_64bit() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_wsuffix_64bit"); - let vm_proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); - assert!( - verify_vm_minimal(&vm_proof, &elf_bytes), - "W-suffix 64-bit register test should verify" - ); + let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); + assert!(result, "W-suffix 64-bit register test should verify"); } /// Proves a minimal Rust std program that uses `init_allocator()` and @@ -2110,9 +2041,9 @@ fn test_prove_allocator_minimal_reproducer() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/allocator.elf")) .expect("allocator.elf not found — run `make compile-programs-rust`"); - let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + let proof = crate::prove(&elf_bytes).expect("prove 
should succeed"); assert!( - verify_vm_minimal(&proof, &elf_bytes), + crate::verify(&proof, &elf_bytes).expect("verify should not error"), "allocator.elf should verify" ); assert_eq!(proof.public_output, b"Hello World"); @@ -2129,9 +2060,9 @@ fn test_pure_commit_rust() { let elf_bytes = std::fs::read(workspace_root.join("executor/program_artifacts/rust/pure_commit.elf")) .expect("pure_commit.elf not found — run `make compile-programs-rust`"); - let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); + let proof = crate::prove(&elf_bytes).expect("prove should succeed"); assert!( - verify_vm_minimal(&proof, &elf_bytes), + crate::verify(&proof, &elf_bytes).expect("verify should not error"), "pure_commit.elf should verify" ); assert_eq!(proof.public_output, vec![0xAA, 0xBB, 0xCC, 0xDD]); @@ -2154,8 +2085,12 @@ fn test_prove_with_input_empty() { fn test_prove_private_input_xpage() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_private_input_xpage"); let input: Vec = (0u8..16).collect(); - let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); - assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); + let proof = + crate::prove_with_inputs(&elf_bytes, &input).expect("prove_with_inputs should succeed"); + assert!( + crate::verify(&proof, &elf_bytes).expect("verify should not error"), + "proof should verify" + ); assert_eq!(proof.public_output, input[4..12].to_vec()); } @@ -2167,8 +2102,11 @@ fn test_prove_private_input_different_values() { 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, ]; - let proof = prove_vm_minimal(&elf_bytes, &input, &Default::default()); - assert!(verify_vm_minimal(&proof, &elf_bytes), "proof should verify"); + let proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove"); + assert!( + crate::verify(&proof, &elf_bytes).expect("verify"), + "proof should verify" + ); assert_eq!(proof.public_output, input[4..12].to_vec()); } @@ -2208,9 
+2146,9 @@ fn test_prove_commit_sum() { std::fs::read(workspace_root.join("executor/program_artifacts/rust/commit_sum.elf")) .expect("commit_sum.elf not found — run `make compile-programs-rust`"); let input = &[3u8, 5u8]; - let proof = prove_vm_minimal(&elf_bytes, input, &Default::default()); + let proof = crate::prove_with_inputs(&elf_bytes, input).expect("prove should succeed"); assert!( - verify_vm_minimal(&proof, &elf_bytes), + crate::verify(&proof, &elf_bytes).expect("verify should not error"), "commit_sum should verify" ); assert_eq!(proof.public_output, vec![8u8]); @@ -2326,7 +2264,7 @@ fn test_verify_rejects_private_input_with_tampered_public_output() { let vm_proof = crate::prove_with_inputs(&elf_bytes, &input).expect("prove should succeed"); assert!( - crate::verify(&vm_proof, &elf_bytes).expect("verify should not error"), + crate::verify(&vm_proof, &elf_bytes).expect("verify"), "Baseline must verify" ); @@ -2375,11 +2313,8 @@ fn test_proof_does_not_contain_private_input_field() { #[test] fn test_addiw_neg_immediate() { let elf_bytes = crate::test_utils::asm_elf_bytes("test_addiw_neg"); - let proof = prove_vm_minimal(&elf_bytes, &[], &Default::default()); - assert!( - verify_vm_minimal(&proof, &elf_bytes), - "addiw with negative immediate should verify" - ); + let result = crate::prove_and_verify(&elf_bytes).expect("prove_and_verify failed"); + assert!(result, "addiw with negative immediate should verify"); } /// Regression test: both main and aux field element counts must be nonzero for any real ELF.
diff --git a/prover/tests/calibration.rs b/prover/tests/calibration.rs
new file mode 100644
index 000000000..ff11bcf4b
--- /dev/null
+++ b/prover/tests/calibration.rs
@@ -0,0 +1,99 @@
+//! Asserts predicted `peak_bytes` does not underestimate jemalloc-measured
+//! heap during a proof. Lives in its own integration-test binary so that
+//! `#[global_allocator]` and `tikv_jemalloc_ctl::stats::allocated` reads are
+//! isolated from the rest of the prover test suite.
+
+#![cfg(feature = "disk-spill")]
+
+use lambda_vm_prover::auto_storage::{SAFETY_FRACTION_DEN, SAFETY_FRACTION_NUM, peak_bytes};
+use lambda_vm_prover::prove_with_options_and_inputs;
+use lambda_vm_prover::tables::MaxRowsConfig;
+use lambda_vm_prover::tables::trace_builder::count_table_lengths;
+use lambda_vm_prover::test_utils::{asm_elf_bytes, run_asm_elf};
+use stark::proof::options::GoldilocksCubicProofOptions;
+use stark::prover::table_parallelism;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::thread;
+use std::time::Duration;
+use tikv_jemalloc_ctl::{epoch, stats};
+
+#[global_allocator]
+static ALLOC: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
+
+/// Interval between heap samples taken while the proof is running.
+const SAMPLE_INTERVAL: Duration = Duration::from_millis(10);
+
+/// Reads jemalloc's `stats.allocated` after advancing the stats epoch.
+///
+/// Returns `None` when either jemalloc control call fails, so callers can
+/// tell "no reading" apart from a genuine reading of zero bytes.
+fn allocated_bytes() -> Option<usize> {
+    epoch::advance().ok()?;
+    stats::allocated::read().ok()
+}
+
+/// Proves `fib_iterative_372k` while a background thread samples the heap,
+/// then checks the `peak_bytes` prediction against the sampled peak.
+#[test]
+fn peak_bytes_does_not_underestimate_measured_heap() {
+    let (elf, logs, _) = run_asm_elf("fib_iterative_372k");
+    let elf_bytes = asm_elf_bytes("fib_iterative_372k");
+
+    let max_rows = MaxRowsConfig::default();
+    let lengths =
+        count_table_lengths(&elf, &logs, &max_rows, &[]).expect("count_table_lengths succeeds");
+
+    let opts = GoldilocksCubicProofOptions::with_blowup(2).expect("blowup=2 is valid");
+    let predicted = peak_bytes(&lengths, opts.blowup_factor, table_parallelism()) as usize;
+
+    // Drop the execution logs before taking the baseline reading.
+    drop(logs);
+
+    let baseline = allocated_bytes().expect("jemalloc stats are readable");
+    let peak = Arc::new(AtomicUsize::new(baseline));
+    let stop = Arc::new(AtomicBool::new(false));
+
+    let sampler = {
+        let peak = Arc::clone(&peak);
+        let stop = Arc::clone(&stop);
+        thread::spawn(move || {
+            while !stop.load(Ordering::Relaxed) {
+                // A failed read is skipped instead of being recorded as zero.
+                if let Some(now) = allocated_bytes() {
+                    peak.fetch_max(now, Ordering::Relaxed);
+                }
+                thread::sleep(SAMPLE_INTERVAL);
+            }
+        })
+    };
+
+    let _proof =
+        prove_with_options_and_inputs(&elf_bytes, &[], &opts, &max_rows).expect("proof succeeds");
+
+    stop.store(true, Ordering::Relaxed);
+    sampler.join().expect("sampler joins");
+
+    let measured = peak.load(Ordering::Relaxed).saturating_sub(baseline);
+
+    // Without this guard a sampler that never observed growth would make the
+    // comparison below pass vacuously (anything >= 0).
+    assert!(
+        measured > 0,
+        "sampler never observed heap growth above the baseline ({baseline} bytes)"
+    );
+
+    eprintln!(
+        "peak_bytes calibration: predicted={predicted} bytes, measured_heap={measured} bytes, ratio={:.2}",
+        predicted as f64 / measured as f64
+    );
+
+    // Cross-multiplied form of
+    // `predicted / measured >= SAFETY_FRACTION_NUM / SAFETY_FRACTION_DEN`.
+    let safety_num = SAFETY_FRACTION_NUM as usize;
+    let safety_den = SAFETY_FRACTION_DEN as usize;
+    assert!(
+        predicted.saturating_mul(safety_den) >= measured.saturating_mul(safety_num),
+        "peak_bytes underestimates measured heap: predicted={predicted}, measured={measured}"
+    );
+}
diff --git a/scripts/calibrate_threshold.sh b/scripts/calibrate_threshold.sh
new file mode 100755
index 000000000..795eeb777
--- /dev/null
+++ b/scripts/calibrate_threshold.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+# Calibrate the auto-disk-spill threshold: actual RSS / estimated_peak_bytes.
+#
+# Usage: calibrate_threshold.sh elf1.elf [elf2.elf ...]
+#
+# Builds CLI with jemalloc-stats, runs each ELF under `/usr/bin/time -v`,
+# and prints predicted vs measured peak. Use the rss/pred ratio to adjust
+# the safety margin in `auto_storage.rs`.
+
+set -euo pipefail
+
+if [[ $# -eq 0 ]]; then
+    echo "Usage: $(basename "$0") elf1.elf [elf2.elf ...]" >&2
+    exit 1
+fi
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
+OUT="/tmp/calibrate_threshold"
+
+mkdir -p "$OUT"
+rm -f "$OUT"/*.txt
+
+echo "Building CLI with jemalloc-stats and disk-spill..."
+cargo build --release -p cli --features jemalloc-stats,disk-spill --manifest-path "$ROOT_DIR/Cargo.toml" 2>&1 | tail -1
+
+CLI="$ROOT_DIR/target/release/cli"
+
+printf "\n%-55s %10s %10s %10s %10s %10s\n" \
+    "ELF" "pred(MB)" "heap(MB)" "rss(MB)" "rss/pred" "heap/pred"
+printf '%.0s-' {1..110}
+printf '\n'
+
+for elf in "$@"; do
+    name=$(basename "$elf")
+    RUST_LOG=info /usr/bin/time -v "$CLI" prove "$elf" -o "$OUT/proof.bin" \
+        > "$OUT/out.txt" 2> "$OUT/err.txt" || {
+        echo "FAIL: $name"
+        tail -5 "$OUT/err.txt"
+        continue
+    }
+
+    # `|| true`: under `set -eo pipefail` a grep with no match would otherwise
+    # abort the whole script; missing values are reported per ELF below.
+    pred=$(grep -o 'estimated_peak_bytes: [0-9]*' "$OUT/err.txt" | awk 'NR == 1 {print $2}' || true)
+    heap_mb=$(grep -o 'Peak heap: [0-9]*' "$OUT/out.txt" | awk 'NR == 1 {print $3}' || true)
+    rss_kb=$(grep "Maximum resident set size" "$OUT/err.txt" | awk 'NR == 1 {print $NF}' || true)
+
+    # An empty or all-zero prediction would make the awk ratios divide by zero.
+    if [[ "$pred" =~ ^0*$ || -z "$heap_mb" || -z "$rss_kb" ]]; then
+        echo "SKIP: $name (missing metrics: pred='$pred' heap_mb='$heap_mb' rss_kb='$rss_kb')"
+        rm -f "$OUT/proof.bin"
+        continue
+    fi
+
+    awk -v name="$name" -v p="$pred" -v h="$heap_mb" -v r="$rss_kb" 'BEGIN {
+        pred_mb = p / 1024 / 1024
+        rss_mb = r / 1024
+        printf "%-55s %10.0f %10.0f %10.0f %10.2f %10.2f\n",
+            name, pred_mb, h, rss_mb, rss_mb/pred_mb, h/pred_mb
+    }'
+
+    rm -f "$OUT/proof.bin"
+done
+
+echo ""
+echo "Use the rss/pred ratio to adjust the safety margin in auto_storage.rs."