Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,286 changes: 949 additions & 337 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions bin/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ tikv-jemallocator = "0.6"
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true }

[features]
default = ["disk-spill"]
jemalloc-stats = ["dep:tikv-jemalloc-ctl"]
disk-spill = ["prover/disk-spill"]
instruments = ["prover/instruments"]
43 changes: 39 additions & 4 deletions bin/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ enum Commands {
/// Print timing breakdown
#[arg(long)]
time: bool,

/// Maximum rows per table chunk (power of 2). Smaller = less memory, more chunks.
#[arg(long)]
max_rows: Option<usize>,
},

/// Verify a proof bundle
Expand Down Expand Up @@ -155,7 +159,8 @@ fn main() -> ExitCode {
output,
blowup,
time,
} => cmd_prove(elf, output, blowup, time),
max_rows,
} => cmd_prove(elf, output, blowup, time, max_rows),
Commands::Verify {
proof,
elf,
Expand Down Expand Up @@ -249,7 +254,13 @@ fn cmd_execute(elf_path: PathBuf, flamegraph_path: Option<PathBuf>) -> ExitCode
ExitCode::SUCCESS
}

fn cmd_prove(elf_path: PathBuf, output_path: PathBuf, blowup: Option<u8>, time: bool) -> ExitCode {
fn cmd_prove(
elf_path: PathBuf,
output_path: PathBuf,
blowup: Option<u8>,
time: bool,
max_rows: Option<usize>,
) -> ExitCode {
eprintln!("Reading ELF file...");
let elf_data = match std::fs::read(&elf_path) {
Ok(data) => data,
Expand All @@ -262,6 +273,28 @@ fn cmd_prove(elf_path: PathBuf, output_path: PathBuf, blowup: Option<u8>, time:
#[cfg(feature = "jemalloc-stats")]
let tracker = heap_tracker::HeapTracker::start();

if cfg!(feature = "disk-spill") {
eprintln!("Disk-spill: enabled");
}

let max_rows_config = match max_rows {
Some(mr) => {
eprintln!("Max rows per chunk: {mr}");
prover::MaxRowsConfig {
cpu: mr,
memw: mr,
memw_aligned: mr,
dvrm: mr,
mul: mr,
lt: mr,
shift: mr,
load: mr,
branch: mr,
}
}
None => prover::MaxRowsConfig::default(),
};

let start = Instant::now();
let proof = match blowup {
Some(b) => {
Expand All @@ -276,11 +309,13 @@ fn cmd_prove(elf_path: PathBuf, output_path: PathBuf, blowup: Option<u8>, time:
"Generating proof (blowup={b}, queries={})...",
opts.fri_number_of_queries
);
prover::prove_with_options(&elf_data, &opts, &Default::default())
prover::prove_with_options(&elf_data, &opts, &max_rows_config)
}
None => {
let opts =
GoldilocksCubicProofOptions::with_blowup(2).expect("blowup=2 is always valid");
eprintln!("Generating proof...");
prover::prove(&elf_data)
prover::prove_with_options(&elf_data, &opts, &max_rows_config)
}
};
let prove_elapsed = start.elapsed();
Expand Down
3 changes: 3 additions & 0 deletions crypto/crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ serde = { version = "1.0", default-features = false, features = [
rayon = { version = "1.8.0", optional = true }
rand = { version = "0.8.5", default-features = false }
rand_chacha = { version = "0.3.1", default-features = false }
memmap2 = { version = "0.9", optional = true }
tempfile = { version = "3", optional = true }

[dev-dependencies]
math = { path = "../math", features = ["test-utils"] }
Expand All @@ -31,4 +33,5 @@ asm = ["sha3/asm"]
std = ["math/std", "sha3/std", "serde?/std"]
serde = ["dep:serde"]
parallel = ["dep:rayon"]
disk-spill = ["std", "dep:memmap2", "dep:tempfile"]
alloc = []
137 changes: 129 additions & 8 deletions crypto/crypto/src/merkle_tree/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ impl Display for Error {
#[cfg(feature = "std")]
impl std::error::Error for Error {}

/// File-backed mmap storage for Merkle tree nodes.
///
/// After `spill_nodes_to_disk()`, the heap `Vec<B::Node>` is freed and all
/// node access goes through this mmap. The OS manages page eviction under
/// memory pressure — file-backed pages are evictable without swap.
///
/// The mapping holds `node_count` nodes of `node_size` bytes each, laid out
/// back to back exactly as they were in the heap `Vec` at spill time.
#[cfg(feature = "disk-spill")]
pub(crate) struct MmapNodeBacking {
/// Mapping of the temp file that holds the raw node bytes.
mmap: memmap2::Mmap,
/// Handle to the temp file behind `mmap`. It is stored next to the mapping
/// and not read directly (hence the leading underscore).
_file: std::fs::File,
/// Number of nodes in the mapping (the length of the node `Vec` at spill time).
node_count: usize,
/// Size of one node in bytes (`size_of::<B::Node>()` at spill time).
node_size: usize,
}

/// The struct for the Merkle tree, consisting of the root and the nodes.
/// A typical tree would look like this
/// root
Expand All @@ -31,11 +44,29 @@ impl std::error::Error for Error {}
/// leaf 1 leaf 2 leaf 3 leaf 4
/// The bottom leafs correspond to the hashes of the elements, while each upper
/// layer contains the hash of the concatenation of the daughter nodes.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MerkleTree<B: IsMerkleTreeBackend> {
pub root: B::Node,
nodes: Vec<B::Node>,
#[cfg(feature = "disk-spill")]
#[cfg_attr(feature = "serde", serde(skip))]
mmap_backing: Option<MmapNodeBacking>,
}

impl<B: IsMerkleTreeBackend> Clone for MerkleTree<B> {
    /// Produces an independent copy of the tree by cloning the root and the
    /// heap-resident node vector. The copy never carries an mmap backing.
    ///
    /// # Panics
    ///
    /// With the `disk-spill` feature enabled, panics if this tree currently
    /// has an mmap backing: in that state the heap nodes are gone, so there
    /// is nothing to copy. Share such a tree through `Arc` instead.
    fn clone(&self) -> Self {
        #[cfg(feature = "disk-spill")]
        {
            let spilled = self.mmap_backing.is_some();
            assert!(
                !spilled,
                "cannot clone a spilled MerkleTree — nodes have been freed; use Arc instead"
            );
        }

        let root = self.root.clone();
        let nodes = self.nodes.clone();
        Self {
            root,
            nodes,
            #[cfg(feature = "disk-spill")]
            mmap_backing: None,
        }
    }
}

const ROOT: usize = 0;
Expand Down Expand Up @@ -78,14 +109,46 @@ where
Some(MerkleTree {
root: nodes[ROOT].clone(),
nodes,
#[cfg(feature = "disk-spill")]
mmap_backing: None,
})
}

/// Number of nodes held by the tree, inner nodes and leaves together.
///
/// When the `disk-spill` feature is on and the tree has an mmap backing,
/// the count recorded in that backing is reported; otherwise the length of
/// the heap `Vec` is used.
#[inline]
fn node_count(&self) -> usize {
    #[cfg(feature = "disk-spill")]
    {
        if let Some(backing) = self.mmap_backing.as_ref() {
            return backing.node_count;
        }
    }
    self.nodes.len()
}

/// Borrows the node stored at `idx` from whichever storage currently backs
/// the tree.
///
/// When the `disk-spill` feature is on and the tree has an mmap backing,
/// the node is read directly out of the mapping; otherwise the heap `Vec`
/// is consulted.
///
/// Returns `None` if `idx` is out of bounds.
#[inline]
fn node_get(&self, idx: usize) -> Option<&B::Node> {
    #[cfg(feature = "disk-spill")]
    if let Some(backing) = self.mmap_backing.as_ref() {
        // Guard first so the pointer arithmetic below stays inside the mapping.
        if idx >= backing.node_count {
            return None;
        }
        let byte_offset = idx * backing.node_size;
        // SAFETY: `idx < node_count`, and `spill_nodes_to_disk` sized the file to
        // `node_count * node_size` bytes, so `byte_offset` lies within the mapping.
        // Those bytes were copied verbatim from live `B::Node` values (`B::Node: Copy`
        // is required by `spill_nodes_to_disk`'s where clause) on this same machine.
        // `byte_offset` is a multiple of `size_of::<B::Node>()`, hence of its
        // alignment; the mmap base is page-aligned.
        // NOTE(review): assumes the page size is at least `align_of::<B::Node>()`,
        // which holds for the byte-array node types ([u8; 32/64]) — confirm for others.
        let node_ptr = unsafe { backing.mmap.as_ptr().add(byte_offset).cast::<B::Node>() };
        // SAFETY: see above — the pointer is in bounds, aligned, and refers to
        // initialized node bytes that live as long as `self` keeps the backing.
        return Some(unsafe { &*node_ptr });
    }
    self.nodes.get(idx)
}

/// Returns a Merkle proof for the element/s at position pos
/// For example, give me an inclusion proof for the 3rd element in the
/// Merkle tree
pub fn get_proof_by_pos(&self, pos: usize) -> Option<Proof<B::Node>> {
let pos = pos + self.nodes.len() / 2;
let pos = pos + self.node_count() / 2;
let Ok(merkle_path) = self.build_merkle_path(pos) else {
return None;
};
Expand All @@ -101,12 +164,12 @@ where
/// Returns the Merkle path for the element/s for the leaf at position pos
fn build_merkle_path(&self, pos: usize) -> Result<Vec<B::Node>, Error> {
// Pre-allocate based on tree depth (log2 of tree size)
let tree_depth = (self.nodes.len() + 1).ilog2() as usize;
let tree_depth = (self.node_count() + 1).ilog2() as usize;
let mut merkle_path = Vec::with_capacity(tree_depth);
let mut pos = pos;

while pos != ROOT {
let Some(node) = self.nodes.get(sibling_index(pos)) else {
let Some(node) = self.node_get(sibling_index(pos)) else {
// out of bounds, exit returning the current merkle_path
return Err(Error::OutOfBounds);
};
Expand Down Expand Up @@ -141,7 +204,7 @@ where
return Err(Error::EmptyPositionList);
}

let num_leaves = (self.nodes.len() + 1).div_ceil(2);
let num_leaves = (self.node_count() + 1).div_ceil(2);

// Validate all positions are within bounds
for &pos in pos_list {
Expand All @@ -154,15 +217,19 @@ where
// of the leaves.
let leaf_positions = pos_list
.iter()
.map(|pos| pos + self.nodes.len() / 2)
.map(|pos| pos + self.node_count() / 2)
.collect::<Vec<usize>>();
// We get the positions of the nodes for the batch proof.
let batch_auth_path_positions = self.get_batch_auth_path_positions(&leaf_positions);

// We get the nodes for the batch proof.
let batch_auth_path_nodes = batch_auth_path_positions
.iter()
.map(|pos| self.nodes[*pos].clone())
.map(|pos| {
self.node_get(*pos)
.expect("batch auth path position in bounds")
.clone()
})
.collect();

Ok(BatchProof {
Expand All @@ -188,7 +255,7 @@ where
let mut obtainable: BTreeSet<usize> = leaf_positions.iter().cloned().collect();

// Number of levels in tree
let num_levels = (self.nodes.len() + 1).ilog2();
let num_levels = (self.node_count() + 1).ilog2();

// Iterate level-by-level from leaves to root.
for _ in 0..num_levels - 1 {
Expand Down Expand Up @@ -217,4 +284,58 @@ where
// This makes the proof ordered from bottom (nodes closer to leaves) to top (nodes closer to root).
auth_path_set.into_iter().rev().collect()
}

/// Moves the tree's nodes out of the heap and into a read-only, file-backed mmap.
///
/// The node `Vec` is dumped byte-for-byte into an anonymous temp file, the file
/// is mapped, and the `Vec` is then released. From that point on the node
/// accessors read from the mapping transparently, and because the pages are
/// file-backed the OS may evict them under memory pressure.
///
/// A tree whose node `Vec` is empty is left untouched and `Ok(())` is returned.
///
/// Requires `B::Node: Copy` so that nodes have a trivial byte representation
/// suitable for raw serialization and for casting back out of the mapping.
///
/// Note: the concrete `Node` type is `[u8; 32]` (Keccak hash), which has no
/// padding bytes, so the raw byte round-trip is well-defined.
/// NOTE(review): `Copy` alone does not rule out padding — confirm before using
/// this with a node type that is not a plain byte array.
///
/// # Errors
///
/// Propagates any I/O error from creating, sizing, writing or mapping the
/// temp file. On error the heap `Vec` has not been released.
#[cfg(feature = "disk-spill")]
pub fn spill_nodes_to_disk(&mut self) -> std::io::Result<()>
where
    B::Node: Copy,
{
    use std::io::Write;

    let node_count = self.nodes.len();
    if node_count == 0 {
        return Ok(());
    }

    let node_size = core::mem::size_of::<B::Node>();
    let total_bytes = node_count * node_size;

    let file = tempfile::tempfile()?;
    file.set_len(total_bytes as u64)?;

    // SAFETY: the Vec stores `node_count` contiguous `B::Node` values, i.e.
    // exactly `total_bytes` bytes starting at `as_ptr()`. `B::Node` is Copy,
    // so it is treated here as plain bytes (see the padding note in the docs).
    let raw_nodes: &[u8] =
        unsafe { core::slice::from_raw_parts(self.nodes.as_ptr().cast::<u8>(), total_bytes) };

    let mut sink = std::io::BufWriter::new(&file);
    sink.write_all(raw_nodes)?;
    sink.flush()?;
    // Release the writer's borrow of `file` before the file is mapped and moved.
    drop(sink);

    // SAFETY: We own the file exclusively; it won't be modified externally.
    let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };

    // Only now, with the mapping in place, give the heap allocation back.
    self.nodes = Vec::new();
    self.mmap_backing = Some(MmapNodeBacking {
        mmap,
        _file: file,
        node_count,
        node_size,
    });

    Ok(())
}
}
1 change: 1 addition & 0 deletions crypto/math/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use super::traits::{IsPrimeField, IsSubFieldOf, LegendreSymbol};

/// A field element with operations algorithms defined in `F`
#[allow(clippy::derived_hash_with_manual_eq)]
#[repr(transparent)]
#[derive(Debug, Clone, Hash, Copy)]
pub struct FieldElement<F: IsField> {
value: F::BaseType,
Expand Down
5 changes: 5 additions & 0 deletions crypto/stark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ itertools = "0.11.0"
# Parallelization crates
rayon = { version = "1.8.0", optional = true }

# Disk-spill: mmap LDE data to reduce heap memory during proving
memmap2 = { version = "0.9", optional = true }
tempfile = { version = "3", optional = true }

# wasm
wasm-bindgen = { version = "0.2", optional = true }
serde-wasm-bindgen = { version = "0.5", optional = true }
Expand All @@ -40,6 +44,7 @@ instruments = [] # This enab
debug-checks = [] # Enables validate_trace + bus balance report in prover
parallel = ["dep:rayon", "crypto/parallel"]
wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"]
disk-spill = ["dep:memmap2", "dep:tempfile", "crypto/disk-spill"]


[package.metadata.wasm-pack.profile.dev]
Expand Down
Loading
Loading