diff --git a/.github/scripts/publish_bench_vs.sh b/.github/scripts/publish_bench_vs.sh index de79ce64b..d1585cd9f 100644 --- a/.github/scripts/publish_bench_vs.sh +++ b/.github/scripts/publish_bench_vs.sh @@ -13,7 +13,7 @@ METRICS_FILE="bench_vs_artifacts/metrics.txt" if [ ! -f "$METRICS_FILE" ]; then curl -X POST "$WEBHOOK_URL" \ -H 'Content-Type: application/json; charset=utf-8' \ - --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs SP1 v6 - Nightly Benchmark"}},{"type":"section","text":{"type":"mrkdwn","text":":x: Benchmark failed - no metrics found. Check the workflow logs."}}]}' + --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM Nightly Benchmark"}},{"type":"section","text":{"type":"mrkdwn","text":":x: Benchmark failed - no metrics found. Check the workflow logs."}}]}' exit 0 fi @@ -79,6 +79,84 @@ if [ -n "$LAMBDA_PROJECTED_H" ] || [ -n "$SP1_PROJECTED_H" ]; then PROJ_SECTION=',{"type":"divider"},{"type":"header","text":{"type":"plain_text","text":"Linear Projection"}},{"type":"section","text":{"type":"mrkdwn","text":"'"$PROJ_MRKDWN"'"}}' fi +# --- Plonky3 section (optional) -------------------------------------------- +# Built when `bench_vs_artifacts/p3/headline/metrics.txt` exists. + +p3_parse() { + local file=$1 + local key=$2 + { grep "^${key}=" "$file" 2>/dev/null || true; } | cut -d= -f2- +} + +p3_fmt_seconds() { + LC_NUMERIC=C awk -v s="$1" 'BEGIN { + if (s == "") { print "n/a"; exit } + if (s + 0 < 1) printf "%.0fms", s * 1000 + else printf "%.3fs", s + }' +} + +p3_fmt_mb() { + LC_NUMERIC=C awk -v b="$1" 'BEGIN { + if (b == "") { print "n/a"; exit } + printf "%.1f MB", b / (1024 * 1024) + }' +} + +p3_fmt_gb() { + LC_NUMERIC=C awk -v kb="$1" 'BEGIN { + if (kb == "") { print "n/a"; exit } + printf "%.2f GB", kb / (1024 * 1024) + }' +} + +p3_fmt_ratio_pair() { + LC_NUMERIC=C awk -v a="$1" -v b="$2" 'BEGIN { + if (a == "" || b == "" || b + 0 == 0) { print "n/a"; exit } + printf "%.2fx", a / b + }' +} + +P3_SECTION="" +P3_FILE="bench_vs_artifacts/p3/metrics.txt" +if [ -f "$P3_FILE" ]; then + H_LOG_ROWS=$(p3_parse "$P3_FILE" "log_rows_series") + H_COLS=$(p3_parse "$P3_FILE" "columns") + H_BLOWUP=$(p3_parse "$P3_FILE" "blowup") + H_QUERIES=$(p3_parse "$P3_FILE" "fri_queries") + H_ROWS=$(p3_parse "$P3_FILE" "rows_series") + H_LAMBDA_PROVE=$(p3_parse "$P3_FILE" "lambda_prove_medians") + H_P3_PROVE=$(p3_parse "$P3_FILE" "p3_prove_medians") + H_LAMBDA_VERIFY=$(p3_parse "$P3_FILE" "lambda_verify_medians") + H_P3_VERIFY=$(p3_parse "$P3_FILE" "p3_verify_medians") + H_LAMBDA_PROOF=$(p3_parse "$P3_FILE" "lambda_proof_size_medians") + H_P3_PROOF=$(p3_parse "$P3_FILE" "p3_proof_size_medians") + H_LAMBDA_RSS=$(p3_parse "$P3_FILE" "lambda_peak_rss_medians") + H_P3_RSS=$(p3_parse "$P3_FILE" "p3_peak_rss_medians") + H_RATIO=$(p3_parse "$P3_FILE" "ratios_lambda_over_p3") + + H_ROWS_FMT=$(LC_NUMERIC=C awk -v r="$H_ROWS" 'BEGIN { + if (r == "") { print "n/a"; exit } + if (r + 0 >= 1000000) printf "%.1fM", r / 1000000 + else if (r + 0 >= 1000) printf "%.0fK", r / 1000 + else printf "%d", r + }') + + PROOF_RATIO=$(p3_fmt_ratio_pair "$H_LAMBDA_PROOF" "$H_P3_PROOF") + RSS_RATIO=$(p3_fmt_ratio_pair "$H_LAMBDA_RSS" "$H_P3_RSS") + PROVE_RATIO_FMT=$(LC_NUMERIC=C awk -v r="$H_RATIO" 'BEGIN { + if (r == "" || r == "n/a") { print "n/a"; exit } + printf "%.2fx", r + }') + + P3_MRKDWN="*log_rows=${H_LOG_ROWS} (${H_ROWS_FMT} rows · ${H_COLS} cols · blowup=${H_BLOWUP} · ${H_QUERIES} queries)*" + P3_MRKDWN="${P3_MRKDWN}\\n*Lambda:* $(p3_fmt_seconds "$H_LAMBDA_PROVE") prove · $(p3_fmt_seconds "$H_LAMBDA_VERIFY") verify · $(p3_fmt_mb "$H_LAMBDA_PROOF") proof · $(p3_fmt_gb "$H_LAMBDA_RSS") RSS" + P3_MRKDWN="${P3_MRKDWN}\\n*Plonky3:* $(p3_fmt_seconds "$H_P3_PROVE") prove · $(p3_fmt_seconds "$H_P3_VERIFY") verify · $(p3_fmt_mb "$H_P3_PROOF") proof · $(p3_fmt_gb "$H_P3_RSS") RSS" + P3_MRKDWN="${P3_MRKDWN}\\n*Ratio L/P3:* ${PROVE_RATIO_FMT} prove · ${PROOF_RATIO} proof · ${RSS_RATIO} RSS" + + P3_SECTION=',{"type":"divider"},{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs Plonky3"}},{"type":"section","text":{"type":"mrkdwn","text":"'"$P3_MRKDWN"'"}}' +fi + ETHREX_METRICS_FILE="bench_vs_artifacts/ethrex_metrics.txt" ETHREX_SECTION="" if [ -f "$ETHREX_METRICS_FILE" ]; then @@ -95,4 +173,4 @@ fi curl -X POST "$WEBHOOK_URL" \ -H 'Content-Type: application/json; charset=utf-8' \ - --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs SP1 v6 - Nightly Benchmark"}},{"type":"context","elements":[{"type":"mrkdwn","text":"*Program:* Fibonacci · *Device:* CPU"}]},{"type":"divider"},{"type":"section","text":{"type":"mrkdwn","text":"'"$RESULTS_MRKDWN"'"}}'"$PROJ_SECTION$ETHREX_SECTION"']}' + --data '{"blocks":[{"type":"header","text":{"type":"plain_text","text":"Lambda VM Nightly Benchmark"}},{"type":"context","elements":[{"type":"mrkdwn","text":"*Program:* Fibonacci · *Device:* CPU"}]},{"type":"divider"},{"type":"header","text":{"type":"plain_text","text":"Lambda VM vs SP1 v6"}},{"type":"section","text":{"type":"mrkdwn","text":"'"$RESULTS_MRKDWN"'"}}'"$PROJ_SECTION$ETHREX_SECTION$P3_SECTION"']}' diff --git a/.github/workflows/bench-vs-nightly.yml b/.github/workflows/bench-vs-nightly.yml index c1fdd7c86..4d21a0a31 100644 --- a/.github/workflows/bench-vs-nightly.yml +++ b/.github/workflows/bench-vs-nightly.yml @@ -62,6 +62,20 @@ jobs: --report-dir bench_vs_artifacts \ --no-color + - name: Refresh Plonky3 to latest main + run: | + cargo update --manifest-path bench_vs_plonky3/Cargo.toml \ + -p p3-air -p p3-field -p p3-goldilocks -p p3-matrix \ + -p p3-commit -p p3-challenger -p p3-symmetric \ + -p p3-merkle-tree -p p3-keccak -p p3-fri \ + -p p3-uni-stark -p p3-dft + + - name: Run Plonky3 nightly benchmark + run: | + bash ./bench_vs_plonky3/run.sh \ + --log-rows 21 --num-sequences 16 --runs 10 --scalar \ + --report-dir bench_vs_artifacts/p3 --no-color + - name: Upload nightly benchmark artifact uses: actions/upload-artifact@v4 with: diff --git a/Cargo.lock b/Cargo.lock index 70b4071e8..e3bb9f4cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -293,6 +293,33 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bench-vs-plonky3" +version = "0.1.0" +dependencies = [ + "criterion 0.4.0", + "crypto", + "libc", + "math", + "p3-air", + "p3-challenger", + "p3-commit", + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-fri", + "p3-goldilocks", + "p3-keccak", + "p3-matrix 0.5.1", + "p3-merkle-tree", + "p3-symmetric 0.5.1", + "p3-uni-stark", + "serde", + "serde_cbor", + "stark", + "tracing", + "tracing-subscriber", +] + [[package]] name = "bincode" version = "1.3.3" @@ -1945,7 +1972,7 @@ dependencies = [ "serde_arrays", "sha2", "sp1_bls12_381", - "spin", + "spin 0.9.8", ] [[package]] @@ -2049,6 +2076,15 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -2283,6 +2319,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "p3-air" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "tracing", +] + [[package]] name = "p3-baby-bear" version = "0.2.3-succinct" @@ -2290,24 +2336,66 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7521838ecab2ddf4f7bc4ceebad06ec02414729598485c1ada516c39900820e8" dependencies = [ "num-bigint 0.4.6", - "p3-field", - "p3-mds", - "p3-poseidon2", - "p3-symmetric", + "p3-field 0.2.3-succinct", + "p3-mds 0.2.3-succinct", + "p3-poseidon2 0.2.3-succinct", + "p3-symmetric 0.2.3-succinct", "rand 0.8.5", "serde", ] +[[package]] +name = "p3-challenger" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-field 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-monty-31", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "tracing", +] + +[[package]] +name = "p3-commit" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-challenger", + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-multilinear-util", + "p3-util 0.5.1", + "serde", +] + [[package]] name = "p3-dft" version = "0.2.3-succinct" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46414daedd796f1eefcdc1811c0484e4bced5729486b6eaba9521c572c76761a" dependencies = [ - "p3-field", - "p3-matrix", - "p3-maybe-rayon", - "p3-util", + "p3-field 0.2.3-succinct", + "p3-matrix 0.2.3-succinct", + "p3-maybe-rayon 0.2.3-succinct", + "p3-util 0.2.3-succinct", + "tracing", +] + +[[package]] +name = "p3-dft" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "spin 0.11.0", "tracing", ] @@ -2320,11 +2408,75 @@ dependencies = [ "itertools 0.12.1", "num-bigint 0.4.6", "num-traits", - "p3-util", + "p3-util 0.2.3-succinct", "rand 0.8.5", "serde", ] +[[package]] +name = "p3-field" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "num-bigint 0.4.6", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "paste", + "rand 0.10.1", + "serde", + "tracing", +] + +[[package]] +name = "p3-fri" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-challenger", + "p3-commit", + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", + "serde", + "spin 0.11.0", + "thiserror 2.0.17", + "tracing", +] + +[[package]] +name = "p3-goldilocks" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "num-bigint 0.4.6", + "p3-challenger", + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-mds 0.5.1", + "p3-poseidon1", + "p3-poseidon2 0.5.1", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "paste", + "rand 0.10.1", + "serde", +] + +[[package]] +name = "p3-keccak" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "tiny-keccak", +] + [[package]] name = "p3-matrix" version = "0.2.3-succinct" @@ -2332,20 +2484,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e4de3f373589477cb735ea58e125898ed20935e03664b4614c7fac258b3c42f" dependencies = [ "itertools 0.12.1", - "p3-field", - "p3-maybe-rayon", - "p3-util", + "p3-field 0.2.3-succinct", + "p3-maybe-rayon 0.2.3-succinct", + "p3-util 0.2.3-succinct", "rand 0.8.5", "serde", "tracing", ] +[[package]] +name = "p3-matrix" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", + "serde", + "tracing", +] + [[package]] name = "p3-maybe-rayon" version = "0.2.3-succinct" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3968ad1160310296eb04f91a5f4edfa38fe1d6b2b8cd6b5c64e6f9b7370979e" +[[package]] +name = "p3-maybe-rayon" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "rayon", +] + [[package]] name = "p3-mds" version = "0.2.3-succinct" @@ -2353,14 +2527,93 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2356b1ed0add6d5dfbf7a338ce534a6fde827374394a52cec16a0840af6e97c9" dependencies = [ "itertools 0.12.1", - "p3-dft", - "p3-field", - "p3-matrix", - "p3-symmetric", - "p3-util", + "p3-dft 0.2.3-succinct", + "p3-field 0.2.3-succinct", + "p3-matrix 0.2.3-succinct", + "p3-symmetric 0.2.3-succinct", + "p3-util 0.2.3-succinct", "rand 0.8.5", ] +[[package]] +name = "p3-mds" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", +] + +[[package]] +name = "p3-merkle-tree" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-commit", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", + "serde", + "spin 0.11.0", + "thiserror 2.0.17", + "tracing", +] + +[[package]] +name = "p3-monty-31" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "num-bigint 0.4.6", + "p3-dft 0.5.1", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-mds 0.5.1", + "p3-poseidon1", + "p3-poseidon2 0.5.1", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "paste", + "rand 0.10.1", + "serde", + "spin 0.11.0", + "tracing", +] + +[[package]] +name = "p3-multilinear-util" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.5.1", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", + "serde", + "tracing", +] + +[[package]] +name = "p3-poseidon1" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-field 0.5.1", + "p3-symmetric 0.5.1", + "rand 0.10.1", +] + [[package]] name = "p3-poseidon2" version = "0.2.3-succinct" @@ -2368,13 +2621,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da1eec7e1b6900581bedd95e76e1ef4975608dd55be9872c9d257a8a9651c3a" dependencies = [ "gcd", - "p3-field", - "p3-mds", - "p3-symmetric", + "p3-field 0.2.3-succinct", + "p3-mds 0.2.3-succinct", + "p3-symmetric 0.2.3-succinct", "rand 0.8.5", "serde", ] +[[package]] +name = "p3-poseidon2" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "p3-field 0.5.1", + "p3-mds 0.5.1", + "p3-symmetric 0.5.1", + "p3-util 0.5.1", + "rand 0.10.1", +] + [[package]] name = "p3-symmetric" version = "0.2.3-succinct" @@ -2382,10 +2647,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb439bea1d822623b41ff4b51e3309e80d13cadf8b86d16ffd5e6efb9fdc360" dependencies = [ "itertools 0.12.1", - "p3-field", + "p3-field 0.2.3-succinct", + "serde", +] + +[[package]] +name = "p3-symmetric" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "p3-field 0.5.1", + "p3-util 0.5.1", "serde", ] +[[package]] +name = "p3-uni-stark" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "itertools 0.14.0", + "libm", + "p3-air", + "p3-challenger", + "p3-commit", + "p3-field 0.5.1", + "p3-fri", + "p3-matrix 0.5.1", + "p3-maybe-rayon 0.5.1", + "p3-util 0.5.1", + "serde", + "thiserror 2.0.17", + "tracing", +] + [[package]] name = "p3-util" version = "0.2.3-succinct" @@ -2395,6 +2691,15 @@ dependencies = [ "serde", ] +[[package]] +name = "p3-util" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#de83ef4367b66d5908623c6503946ffcfdc3b6ae" +dependencies = [ + "serde", + "transpose", +] + [[package]] name = "pairing" version = "0.23.0" @@ -2687,6 +2992,15 @@ dependencies = [ "rand_core 0.9.3", ] +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "rand_core 0.10.1", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -2725,6 +3039,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "rand_xorshift" version = "0.4.0" @@ -2968,6 +3288,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "sec1" version = "0.7.3" @@ -3180,9 +3506,9 @@ dependencies = [ "lazy_static", "num-bigint 0.4.6", "p3-baby-bear", - "p3-field", - "p3-poseidon2", - "p3-symmetric", + "p3-field 0.2.3-succinct", + "p3-poseidon2 0.2.3-succinct", + "p3-symmetric 0.2.3-succinct", "serde", "sha2", ] @@ -3208,6 +3534,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spin" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "783f3f6f6b01e295a669edfc402133a5f2553d1f0e81284b3ba4594e80bdd4a2" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3256,6 +3591,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" @@ -3652,6 +3993,16 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "typenum" version = "1.19.0" diff --git a/Cargo.toml b/Cargo.toml index e43dc7f0d..067b28dc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "crypto/math", "crypto/math-cuda", "bin/cli", + "bench_vs_plonky3", ] resolver = "2" diff --git a/bench_vs_plonky3/Cargo.toml b/bench_vs_plonky3/Cargo.toml new file mode 100644 index 000000000..8fef10667 --- /dev/null +++ b/bench_vs_plonky3/Cargo.toml @@ -0,0 +1,54 @@ +[package] +name = "bench-vs-plonky3" +version = "0.1.0" +edition = "2024" + +[dependencies] +# Lambda STARK +stark = { path = "../crypto/stark", features = ["test-utils"] } +crypto = { path = "../crypto/crypto", features = ["std", "serde"] } +math = { path = "../crypto/math", features = [ + "std", + "lambdaworks-serde-binary", +] } + +p3-air = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-commit = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-keccak = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3.git", features = ["parallel"] } +p3-dft = { git = "https://github.com/Plonky3/Plonky3.git", features = ["parallel"] } + +# Tracing for P3 span-based profiling +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } +libc = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_cbor = "0.11" + +[dev-dependencies] +criterion = { version = "0.4", default-features = false } + +[features] +# Both provers run multi-threaded by default: Plonky3's `Radix2DitParallel` DFT +# uses rayon unconditionally, so Lambda must also enable `parallel` for a fair +# apples-to-apples comparison. Disable with `--no-default-features` to compare +# single-threaded. The cubic extension is `x^3 - 2` (binomial) on Lambda and +# `x^3 - x - 1` (trinomial) on upstream Plonky3 — same degree, same soundness. +default = ["parallel"] +parallel = ["stark/parallel"] +instruments = ["stark/instruments"] + +[[bin]] +name = "prove_bench" +path = "src/bin/prove_bench.rs" + +[[bench]] +name = "stark_comparison" +harness = false diff --git a/bench_vs_plonky3/README.md b/bench_vs_plonky3/README.md new file mode 100644 index 000000000..60d44792a --- /dev/null +++ b/bench_vs_plonky3/README.md @@ -0,0 +1,233 @@ +# Lambda STARK vs Plonky3 Benchmark + +Compares **single-shot end-to-end proving time** for an identical multi-sequence +Fibonacci AIR. Complements `bench_vs/` (which compares Lambda VM vs SP1 on a +full guest program) by isolating the STARK prover — no VM execution, no trace +builder, just one AIR and two provers. + +## What is measured + +Both provers prove the same AIR: + +- **Columns** = `2 × num_sequences` (default 16 sequences → 32 columns). +- **Rows** = `2 ^ log_rows` (default `19` → 524 288 rows). +- **Blowup** = 2 (matches Lambda production `GoldilocksCubicProofOptions::with_blowup(2)`). +- **FRI queries** = 219, grinding = 0. + +The timing window on both sides is **`Instant::now()` around `prove`, no +verification, no proof serialization**: + +| Phase | Lambda STARK | Plonky3 | +|--------------------------------------|:------------:|:-------:| +| Build AIR + trace | ❌ (outside) | ❌ (outside) | +| Build public inputs | ❌ (outside) | ❌ (outside) | +| Prove (Round 1 → Round 4) | ✅ | ✅ (`p3_uni_stark::prove`) | +| Proof serialize / disk write | ❌ | ❌ | +| Verify | ❌ | ❌ | + +Lambda's trace, public inputs, and AIR are constructed via +`lambda_fibonacci_pair::{compute_trace, create_public_inputs, FibonacciPairMultiColAIR}`. +Plonky3's counterpart uses `plonky3_fibonacci::{P3FibonacciAir, generate_fibonacci_trace, public_values}` +with `plonky3_config::matched_params_config`. Both AIRs are **cell-by-cell +equivalent** — this is asserted by the `lambda_pair_trace_matches_plonky3_trace` +test. + +## Usage + +```bash +# Default: log-rows=19, num-sequences=16, runs=10, cubic extension, no scalar +./bench_vs_plonky3/run.sh + +# Size sweep +./bench_vs_plonky3/run.sh --log-rows 17 18 19 20 + +# Single prover +./bench_vs_plonky3/run.sh --lambda-only +./bench_vs_plonky3/run.sh --p3-only + +# Scalar mode on both sides (x86_64 only — disables AVX2/AVX-512) +./bench_vs_plonky3/run.sh --scalar + +# Write machine-readable artifacts +./bench_vs_plonky3/run.sh --report-dir /tmp/p3_report --no-color +``` + +### Flags + +| Flag | Default | Effect | +|---|---|---| +| `--log-rows K [K ...]` | `19` | One or more power-of-2 row counts. | +| `--num-sequences N` | `16` | Number of Fibonacci sequences (columns = `2 × N`). | +| `--runs N` | `10` | Runs per `(size, prover)`; median + CV are reported. | +| `--lambda-only` / `--p3-only` | both | Restrict to a single prover. | +| `--report-dir DIR` | — | Write TSV + metrics + raw stdouts + raw audits. | +| `--scalar` | off | Pin `RUSTFLAGS="-C target-feature=-avx2,-avx512f"` so Goldilocks field arithmetic runs scalar on both sides. x86_64 only; on other archs the flag is ignored with a warning. The MMCS is already scalar regardless of this flag (see [P3 config: scalar MMCS](#p3-config-scalar-mmcs)). | +| `--no-color` | off | Disable ANSI colors. | +| `-h` / `--help` | — | Print usage. | + +## Output + +Stdout (without `--report-dir`): + +``` +=== STARK prove benchmark: Lambda vs Plonky3 === + log-rows: 19 + num-sequences: 16 (columns = 32) + runs/size: 10 (median + CV reported) + p3 extension: upstream CubicTrinomialExtensionField (x^3 - x - 1) + p3 mmcs: scalar Keccak256 (val_packing_width=1, hash_lanes=1) + proof params: blowup=2, queries=219, grinding=0 + scalar mode: on (arch=x86_64, RUSTFLAGS="-C target-feature=-avx2,-avx512f") + +[build] prove_bench +--- log-rows=19 (rows = 524288) --- + [lambda] prove median 0.574s (CV 3.07%), verify 0.024s, proof 4116000 B, rss 805000 KB + [p3] prove median 0.324s (CV 2.85%), verify 0.019s, proof 1987000 B, rss 627000 KB + +=== Summary === + log-rows rows Lambda (s) L CV% P3 (s) P3 CV% L/P3 + -------- ---- ---------- ----- ------ ------ ---- + 19 524288 0.574s 3.07% 0.324s 2.85% 1.770x (P3 faster) + +Timing window: prove only for the ratio. Verify, proof size, RSS and throughput are reported separately. +``` + +With `--report-dir DIR` the script writes: + +- `results.tsv` — tab-separated, one row per `log_rows` size with 14 columns: + `log_rows, rows, lambda_prove_median_s, lambda_prove_cv_pct, + lambda_verify_median_s, lambda_proof_size_bytes_median, + lambda_peak_rss_kb_median, p3_prove_median_s, p3_prove_cv_pct, + p3_verify_median_s, p3_proof_size_bytes_median, p3_peak_rss_kb_median, + ratio_lambda_over_p3, runs`. +- `raw_metrics.tsv` — one row per `(prover, log_rows, run)` with all + `METRICS` fields parsed out. +- `raw_audits.tsv` — one row per `(prover, log_rows, run)` with the AUDIT + line emitted by `prove_bench` before each prove call. Lets you confirm in + retrospect that `val_packing_width=1`, `hash_lanes=1`, + `base_transition_constraints=2×num_sequences`, etc. Don't trust a number + without skimming this file. +- `metrics.txt` — key=value pairs with the config used (arch, scalar flag, + extension, mmcs choice, blowup, queries, runs, rustflags) and the + per-series values slash-joined (so post-processing scripts can split easily). +- `raw/` — per-invocation stdouts (`{prover}_log{K}_run{i}.stdout`). + +No markdown file is generated — the TSV is the single source of truth for +downstream tooling. + +## Nightly + +The Lambda-vs-Plonky3 bench is part of the shared +`.github/workflows/bench-vs-nightly.yml` workflow, which runs daily at +06:00 UTC (03:00 Buenos Aires) on the self-hosted `bench` runner. The P3 +step executes after the Lambda-vs-SP1 and ethrex empty-block steps: + +```bash +bash ./bench_vs_plonky3/run.sh \ + --log-rows 21 \ + --num-sequences 16 \ + --runs 10 \ + --scalar \ + --report-dir bench_vs_artifacts/p3 \ + --no-color +``` + +A `cargo update -p p3-*` runs before this step so the bench tracks the +latest upstream Plonky3 `main`. The full `bench_vs_artifacts/` directory +(SP1 + ethrex + P3 outputs) is uploaded as one artifact named +`bench-vs-nightly--` with 90-day retention. A "Lambda +VM vs Plonky3" section is appended to the same Slack post that publishes +the SP1 and ethrex results. + +## Breakdown (per-phase timing) for manual analysis + +The nightly only reports wall-clock totals. When you need to see *where* the +time goes (constraint eval vs FFT vs FRI vs Merkle vs queries on the Lambda +side, and the per-span breakdown on the Plonky3 side), run the +`instruments_breakdown` test: + +```bash +# x86_64 (server), Goldilocks scalar: +RUSTFLAGS="-C target-feature=-avx2,-avx512f" \ +cargo test -p bench-vs-plonky3 --features instruments --release -- \ + instruments_breakdown --nocapture +``` + +- `--features instruments` activates `stark/instruments` — without it, the + per-phase timers are no-ops and the Lambda breakdown prints zeros. +- `--release` is mandatory (debug numbers are meaningless). +- `--nocapture` is required to see the output (`cargo test` swallows stdout + otherwise). +- The test hardcodes `num_sequences = 16`, `rows = 1 << 19` (524 288), same + shape as the nightly, so the breakdown maps onto the nightly numbers. +- Output is split in two sections: + - **Lambda**: explicit per-phase totals (Pre-pass / R1 Main commits / R1 Aux + build+commit / Rounds 2-4) plus sub-ops (Main LDE, Main Merkle, constraint + eval, decompose+extend, composition Merkle, OOD, deep comp, deep extend, + FRI commit, queries+open). + - **Plonky3**: every `tracing` span emitted at DEBUG during + `p3_uni_stark::prove`, sorted by wall-clock descending, filtered ≥ 0.1 ms. + Spans nest (e.g. `prove ⊃ compute_quotient_values`), so Σspans > total is + expected and not a bug. `(unaccounted)` can be negative from nesting. + +The nightly does **not** activate this path — it would add ~1 % overhead and +pollute the historical wall-clock numbers. + +## P3 config: scalar MMCS + +`plonky3_config.rs` sets up the P3 stark config with a deliberately +**non-production** MMCS: + +```rust +type ByteHash = Keccak256Hash; // tiny_keccak scalar +type FieldHash = SerializingHasher; +type MyCompress = CompressionFunctionFromHasher; +pub type ValMmcs = MerkleTreeMmcs; +``` + +The Plonky3 default for Goldilocks MMCS uses `PaddingFreeSponge` with leaves `[Val; VECTOR_LEN]` and digests `[u64; VECTOR_LEN]`, +where `VECTOR_LEN` is set at compile-time per arch: NEON=2, AVX-512=8, +AVX2=4, SSE2=2, fallback=1. That gives Plonky3 a free `N×` Keccak speedup +on every Merkle node — which Lambda's `sha3::Keccak256` cannot exploit +because the Lambda MMCS hashes a single input at a time. + +The scalar config here makes both sides hash one input per Keccak call. +Both still use the **same Keccak-f[1600] permutation** (capacity 512, rate +1088, 256-bit output, Keccak-original 0x01 padding); the only thing +removed is data-parallel lanes on the P3 side. Consequence: the ratio +published by this bench is **apples-to-apples scalar**, not "Plonky3 as +shipped in production." If you want the production-realistic P3 number, +swap the MMCS back to the vector-lane variant from upstream's examples. + +On aarch64 with `feature="asm"` enabled in `crypto/crypto`, Lambda's +`sha3::Keccak256` uses ARMv8 SHA3 intrinsics, which speeds up *one* Keccak +call (no data parallelism). `tiny_keccak`'s `Keccak256Hash` on P3 is pure +Rust and gets no such acceleration. On x86_64 server, neither side has +that path, so the comparison is cleanest there. + +## Notes on fairness + +- **Extension field**: Plonky3 runs upstream `CubicTrinomialExtensionField` + over Goldilocks (`x^3 - x - 1`); Lambda runs `Degree3GoldilocksExtensionField` + (`x^3 - 2`). Both are degree-3 irreducible extensions of `GF(p)` with the + same field size and the same soundness. Cell-by-cell trace equivalence is + asserted by `lambda_pair_trace_matches_plonky3_trace`. +- **Parallelism**: both provers are multi-threaded by default. Lambda pulls + rayon via `stark/parallel`; Plonky3 pulls rayon via `p3-uni-stark` / + `p3-dft` (hardcoded `features = ["parallel"]`, always on). +- **SIMD**: the MMCS Keccak is scalar on both sides (see above). For + Goldilocks field arithmetic, without `--scalar` each side uses whatever + target-features the compiler decides from the host CPU. `--scalar` + (x86_64 only) disables AVX2 / AVX-512. +- **AIR base-field path**: the Lambda AIR overrides + `num_base_transition_constraints` and implements `evaluate_prover` so its + Fibonacci transition constraints are evaluated in the base field (F×E, + ≈3 muls/term) instead of the default extension path (E×E, ≈9 muls/term). + This matches what the production Lambda STARK does for all + domain-constraint AIRs. +- **Queries / grinding**: same `blowup=2`, `queries=219`, `grinding=0` on both + sides. Security models differ (Lambda: Johnson-bound, ~108 bits proven; + P3: conjectured, 219 queries × 1 bit = 219 bits, capped at 192 by the + cubic extension field) — the compute work is equivalent, the claimed + soundness is not. diff --git a/bench_vs_plonky3/benches/stark_comparison.rs b/bench_vs_plonky3/benches/stark_comparison.rs new file mode 100644 index 000000000..577664892 --- /dev/null +++ b/bench_vs_plonky3/benches/stark_comparison.rs @@ -0,0 +1,178 @@ +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use crypto::fiat_shamir::default_transcript::DefaultTranscript; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use p3_uni_stark::{prove as p3_prove, verify as p3_verify}; +use stark::proof::options::ProofOptions; +use stark::prover::{IsStarkProver, Prover}; +use stark::verifier::{IsStarkVerifier, Verifier}; + +use bench_vs_plonky3::lambda_fibonacci_pair; +use bench_vs_plonky3::plonky3_config; +use bench_vs_plonky3::plonky3_fibonacci; + +type F = GoldilocksField; +type E = Degree3GoldilocksExtensionField; +type FE = FieldElement; + +/// Number of independent Fibonacci sequences. +const NUM_SEQUENCES: usize = 16; + +/// Rows (same for both Lambda and Plonky3 — identical AIR shape). +/// +/// 2^18 rows × 2 Fibonacci steps packed per row = 2^19 effective Fibonacci +/// steps per sequence, matching Lambda's original `FibonacciMultiColumnAIR` +/// at 2^19 rows × 1 step/row. +const ROWS: usize = 1 << 18; +const TRACE_LABEL: &str = "fib_pair_16seq_2^18"; + +/// Production proof options: blowup=2, 219 queries (from +/// `GoldilocksCubicProofOptions::with_blowup(2)`), grinding=0 (excluded +/// from benchmark — identical PoW work on both sides, not informative). +fn benchmark_proof_options() -> ProofOptions { + ProofOptions { + blowup_factor: 2, + fri_number_of_queries: 219, + coset_offset: 3, + grinding_factor: 0, + } +} + +fn lambda_initial_values() -> Vec<(FE, FE)> { + (0..NUM_SEQUENCES) + .map(|i| (FE::from((i + 1) as u64), FE::from((i + 2) as u64))) + .collect() +} + +fn bench_lambda_prove(c: &mut Criterion) { + let mut group = c.benchmark_group("lambda_stark/prove"); + group.throughput(Throughput::Elements((ROWS * 2 * NUM_SEQUENCES) as u64)); + let proof_options = benchmark_proof_options(); + + group.bench_with_input( + BenchmarkId::new("fibonacci", TRACE_LABEL), + &ROWS, + |b, &rows| { + b.iter_with_setup( + || { + let initial_values = lambda_initial_values(); + let trace = lambda_fibonacci_pair::compute_trace::(&initial_values, rows); + let pub_inputs = lambda_fibonacci_pair::create_public_inputs(initial_values); + let air = + lambda_fibonacci_pair::FibonacciPairMultiColAIR::::with_num_sequences( + &proof_options, + NUM_SEQUENCES, + ); + (trace, pub_inputs, air) + }, + |(mut trace, pub_inputs, air)| { + Prover::::prove( + &air, + &mut trace, + &pub_inputs, + &mut DefaultTranscript::::new(&[]), + ) + .unwrap() + }, + ); + }, + ); + group.finish(); +} + +fn bench_plonky3_prove(c: &mut Criterion) { + let mut group = c.benchmark_group("plonky3_stark/prove"); + group.throughput(Throughput::Elements((ROWS * 2 * NUM_SEQUENCES) as u64)); + + group.bench_with_input( + BenchmarkId::new("fibonacci", TRACE_LABEL), + &ROWS, + |b, &rows| { + b.iter_with_setup( + || { + let config = plonky3_config::matched_params_config(); + let air = plonky3_fibonacci::P3FibonacciAir { + num_sequences: NUM_SEQUENCES, + }; + let trace = plonky3_fibonacci::generate_fibonacci_trace(NUM_SEQUENCES, rows); + let pis = plonky3_fibonacci::public_values(NUM_SEQUENCES); + (config, air, trace, pis) + }, + |(config, air, trace, pis)| p3_prove(&config, &air, trace, &pis), + ); + }, + ); + group.finish(); +} + +fn bench_lambda_verify(c: &mut Criterion) { + let mut group = c.benchmark_group("lambda_stark/verify"); + group.throughput(Throughput::Elements((ROWS * 2 * NUM_SEQUENCES) as u64)); + let proof_options = benchmark_proof_options(); + + let initial_values = lambda_initial_values(); + let mut trace = lambda_fibonacci_pair::compute_trace::(&initial_values, ROWS); + let pub_inputs = lambda_fibonacci_pair::create_public_inputs(initial_values); + let air = lambda_fibonacci_pair::FibonacciPairMultiColAIR::::with_num_sequences( + &proof_options, + NUM_SEQUENCES, + ); + let proof = Prover::::prove( + &air, + &mut trace, + &pub_inputs, + &mut DefaultTranscript::::new(&[]), + ) + .unwrap(); + + group.bench_with_input(BenchmarkId::new("fibonacci", TRACE_LABEL), &ROWS, |b, _| { + b.iter(|| { + assert!(Verifier::::verify( + &proof, + &air, + &mut DefaultTranscript::::new(&[]), + )) + }); + }); + group.finish(); +} + +fn bench_plonky3_verify(c: &mut Criterion) { + let mut group = c.benchmark_group("plonky3_stark/verify"); + group.throughput(Throughput::Elements((ROWS * 2 * NUM_SEQUENCES) as u64)); + + let air = plonky3_fibonacci::P3FibonacciAir { + num_sequences: NUM_SEQUENCES, + }; + let trace = plonky3_fibonacci::generate_fibonacci_trace(NUM_SEQUENCES, ROWS); + let pis = plonky3_fibonacci::public_values(NUM_SEQUENCES); + let config = plonky3_config::matched_params_config(); + let proof = p3_prove(&config, &air, trace, &pis); + + group.bench_with_input(BenchmarkId::new("fibonacci", TRACE_LABEL), &ROWS, |b, _| { + b.iter(|| { + let config = plonky3_config::matched_params_config(); + p3_verify(&config, &air, &proof, &pis).unwrap(); + }); + }); + group.finish(); +} + +criterion_group! { + name = prove_comparison; + config = Criterion::default() + .sample_size(10) + .measurement_time(std::time::Duration::from_secs(120)); + targets = bench_lambda_prove, bench_plonky3_prove +} + +criterion_group! { + name = verify_comparison; + config = Criterion::default() + .sample_size(10) + .measurement_time(std::time::Duration::from_secs(30)); + targets = bench_lambda_verify, bench_plonky3_verify +} + +criterion_main!(prove_comparison, verify_comparison); diff --git a/bench_vs_plonky3/run.sh b/bench_vs_plonky3/run.sh new file mode 100755 index 000000000..a4ce67fc2 --- /dev/null +++ b/bench_vs_plonky3/run.sh @@ -0,0 +1,619 @@ +#!/bin/bash +# Benchmark: Lambda STARK vs Plonky3 — single-shot prove time on the shared +# Fibonacci AIR (columns = 2 * num_sequences, blowup = 2, fri_queries = 219). +# +# Usage: +# ./bench_vs_plonky3/run.sh [--log-rows K ...] [--num-sequences N] [--runs N] +# [--lambda-only | --p3-only] [--report-dir DIR] +# [--scalar] [--breakdown] [--no-color] +# +# Defaults: --log-rows 19, --num-sequences 16, --runs 10. +# With multiple --log-rows values, prints one stats row per size. +# +# --scalar: on x86_64 drops AVX2 / AVX-512 so Goldilocks runs scalar. The MMCS +# itself is already scalar (single-input tiny_keccak via Keccak256Hash) regardless +# of this flag — its SIMD lanes were removed in the config. Triggers a rebuild +# when toggling; subsequent runs with the same RUSTFLAGS are cached. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +TMP_DIR="$(mktemp -d -t bench_p3.XXXXXX)" +trap 'rm -rf "$TMP_DIR"' EXIT +REPORT_DIR="" +NO_COLOR=false +SCALAR=false +BREAKDOWN=false + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BOLD='\033[1m' +NC='\033[0m' + +LOG_ROWS=() +NUM_SEQUENCES=16 +RUNS=10 +BLOWUP=2 +FRI_QUERIES=219 +GRINDING=0 +RUN_LAMBDA=true +RUN_P3=true + +# --- Parse args ------------------------------------------------------------- +while [[ $# -gt 0 ]]; do + case $1 in + --log-rows) + shift + while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do + LOG_ROWS+=("$1") + shift + done + ;; + --num-sequences) + if [[ $# -lt 2 ]]; then echo "--num-sequences requires an argument"; exit 1; fi + NUM_SEQUENCES=$2 + shift 2 + ;; + --runs) + if [[ $# -lt 2 ]]; then echo "--runs requires an argument"; exit 1; fi + RUNS=$2 + shift 2 + ;; + --lambda-only) + RUN_P3=false + shift + ;; + --p3-only) + RUN_LAMBDA=false + shift + ;; + --report-dir) + if [[ $# -lt 2 ]]; then echo "--report-dir requires an argument"; exit 1; fi + REPORT_DIR=$2 + shift 2 + ;; + --scalar) + SCALAR=true + shift + ;; + --breakdown) + BREAKDOWN=true + shift + ;; + --no-color) + NO_COLOR=true + shift + ;; + -h|--help) + sed -n '2,11p' "$0" | sed 's/^# //' + exit 0 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +if [ ${#LOG_ROWS[@]} -eq 0 ]; then + LOG_ROWS=(19) +fi + +if ! $RUN_LAMBDA && ! $RUN_P3; then + echo "At least one prover must be enabled" + exit 1 +fi + +if [ "$RUNS" -lt 1 ]; then + echo "--runs must be >= 1" + exit 1 +fi + +if $NO_COLOR; then + RED='' + GREEN='' + YELLOW='' + BOLD='' + NC='' +fi + +if [ -n "$REPORT_DIR" ]; then + mkdir -p "$REPORT_DIR/raw" +fi + +# --- Scalar (no SIMD) toggle ------------------------------------------------ +# When --scalar is on, disable AVX2/AVX-512 so Goldilocks field arithmetic runs +# scalar for an apples-to-apples comparison against Lambda STARK. The MMCS Keccak +# is already scalar regardless of this flag (see plonky3_config.rs). +# Cargo caches per-RUSTFLAGS, so toggling scalar vs vector triggers a rebuild +# on first use but is cached afterwards. +SCALAR_RUSTFLAGS="" +SCALAR_ACTIVE=false +if $SCALAR; then + case "$(uname -m)" in + x86_64|amd64) + SCALAR_RUSTFLAGS="-C target-feature=-avx2,-avx512f" + SCALAR_ACTIVE=true + ;; + *) + echo "warning: --scalar: only supported on x86_64; host is $(uname -m), not pinning RUSTFLAGS" >&2 + ;; + esac + if [ -n "$SCALAR_RUSTFLAGS" ]; then + if [ -n "${RUSTFLAGS:-}" ]; then + export RUSTFLAGS="${RUSTFLAGS} ${SCALAR_RUSTFLAGS}" + else + export RUSTFLAGS="$SCALAR_RUSTFLAGS" + fi + fi +fi + +# --- Build ------------------------------------------------------------------ +echo -e "${BOLD}=== STARK prove benchmark: Lambda vs Plonky3 ===${NC}" +echo -e " log-rows: ${YELLOW}${LOG_ROWS[*]}${NC}" +echo -e " num-sequences: ${YELLOW}${NUM_SEQUENCES}${NC} (columns = $((2 * NUM_SEQUENCES)))" +echo -e " runs/size: ${YELLOW}${RUNS}${NC} (median + CV reported)" +echo -e " p3 extension: ${YELLOW}upstream CubicTrinomialExtensionField (x^3 - x - 1)${NC}" +echo -e " p3 mmcs: ${YELLOW}scalar Keccak256 (val_packing_width=1, hash_lanes=1)${NC}" +echo -e " proof params: ${YELLOW}blowup=${BLOWUP}, queries=${FRI_QUERIES}, grinding=${GRINDING}${NC}" +if $BREAKDOWN; then + echo -e " breakdown: ${YELLOW}on${NC} (Lambda instruments + P3 tracing spans)" +else + echo -e " breakdown: ${YELLOW}off${NC}" +fi +if $SCALAR_ACTIVE; then + echo -e " scalar mode: ${YELLOW}on${NC} (arch=$(uname -m), RUSTFLAGS=\"${RUSTFLAGS:-}\")" +elif $SCALAR; then + echo -e " scalar mode: ${YELLOW}requested (unsupported on $(uname -m))${NC} (SIMD enabled, compiler default)" +else + echo -e " scalar mode: ${YELLOW}off${NC} (SIMD enabled, compiler default)" +fi +echo "" + +echo -e "${GREEN}[build]${NC} prove_bench" +BUILD_ARGS=(build --release -p bench-vs-plonky3 --bin prove_bench --manifest-path "$ROOT_DIR/Cargo.toml") +if $BREAKDOWN; then + BUILD_ARGS+=(--features instruments) +fi +cargo "${BUILD_ARGS[@]}" 2>&1 | tail -5 + +# Resolve the actual target directory via cargo metadata so we find the binary +# whether cargo used ./target/ (default) or a custom CARGO_TARGET_DIR. +TARGET_DIR=$(cargo metadata --manifest-path "$ROOT_DIR/Cargo.toml" \ + --format-version 1 --no-deps 2>/dev/null \ + | python3 -c 'import json, sys; print(json.load(sys.stdin)["target_directory"])' \ + 2>/dev/null || echo "$ROOT_DIR/target") +BIN="$TARGET_DIR/release/prove_bench" +if [ ! -x "$BIN" ]; then + echo -e "${RED}[build] prove_bench not produced at $BIN${NC}" + exit 1 +fi + +# --- Helpers ---------------------------------------------------------------- +extract_proving_time() { + sed -nE '/Proving time: [0-9.]+s/ { + s/.*Proving time: ([0-9.]+)s.*/\1/ + p + q + }' +} + +extract_metrics_line() { + sed -n '/^METRICS / { + p + q + }' +} + +extract_audit_line() { + sed -n '/^AUDIT / { + p + q + }' +} + +metric_value() { + local line=$1 + local key=$2 + printf '%s\n' "$line" | tr '\t' '\n' | LC_ALL=C awk -F= -v key="$key" '$1 == key { print $2; exit }' +} + +median_of() { + # prints median of the given numeric arguments. + # Uses shell `sort -g` for portability (macOS awk lacks gawk's asort). + printf '%s\n' "$@" | LC_ALL=C sort -g | LC_NUMERIC=C awk ' + { a[NR] = $0 + 0 } + END { + if (NR == 0) { print "n/a"; exit } + if (NR % 2 == 1) { + printf "%.6f\n", a[(NR + 1) / 2] + } else { + printf "%.6f\n", (a[NR / 2] + a[NR / 2 + 1]) / 2 + } + }' +} + +ratio_fmt() { + LC_NUMERIC=C awk -v num="$1" -v den="$2" 'BEGIN { + if (den + 0 == 0) { print "n/a"; exit } + printf "%.3f\n", num / den + }' +} + +median_file() { + LC_ALL=C sort -g "$1" | LC_NUMERIC=C awk ' + { a[NR] = $0 + 0 } + END { + if (NR == 0) { print "n/a"; exit } + if (NR % 2 == 1) printf "%.6f\n", a[(NR + 1) / 2] + else printf "%.6f\n", (a[NR / 2] + a[NR / 2 + 1]) / 2 + }' +} + +cv_pct_file() { + LC_NUMERIC=C awk ' + { s += $1; ss += $1 * $1; n++ } + END { + if (n == 0) { print "n/a"; exit } + m = s / n + v = (ss / n) - (m * m) + if (v < 0) v = 0 + sd = sqrt(v) + if (m == 0) print "n/a" + else printf "%.2f\n", sd * 100 / m + }' "$1" +} + +fmt0() { + LC_NUMERIC=C awk -v v="$1" 'BEGIN { if (v == "n/a") print v; else printf "%.0f\n", v }' +} + +metric_file_for() { + local metrics_file=$1 + local key=$2 + local out_file=$3 + : > "$out_file" + while IFS= read -r line; do + local value + value=$(metric_value "$line" "$key") + if [ -n "$value" ] && [ "$value" != "n/a" ]; then + printf '%s\n' "$value" >> "$out_file" + fi + done < "$metrics_file" +} + +median_metric() { + local prover=$1 + local log_rows=$2 + local key=$3 + local file="$TMP_DIR/${prover}_${log_rows}_${key}.values" + metric_file_for "$TMP_DIR/${prover}_${log_rows}.metrics" "$key" "$file" + if [ ! -s "$file" ]; then + printf "n/a\n" + else + median_file "$file" + fi +} + +# --- Run benchmark ---------------------------------------------------------- + +RESULT_LOG_ROWS=() +RESULT_ROWS=() +RESULT_LAMBDA=() +RESULT_P3=() +RESULT_RATIO=() +RESULT_LAMBDA_CV=() +RESULT_P3_CV=() +RESULT_LAMBDA_VERIFY=() +RESULT_P3_VERIFY=() +RESULT_LAMBDA_PROOF_SIZE=() +RESULT_P3_PROOF_SIZE=() +RESULT_LAMBDA_RSS=() +RESULT_P3_RSS=() + +run_prover() { + local prover=$1 # lambda | p3 + local log_rows=$2 + local times=() + local metrics_file="$TMP_DIR/${prover}_${log_rows}.metrics" + local audit_file="$TMP_DIR/${prover}_${log_rows}.audits" + local breakdown_file="$TMP_DIR/${prover}_${log_rows}.breakdown" + : > "$metrics_file" + : > "$audit_file" + : > "$breakdown_file" + for run_i in $(seq 1 "$RUNS"); do + local out_file="$TMP_DIR/${prover}_${log_rows}_${run_i}.stdout" + local run_args=(--prover "$prover" --log-rows "$log_rows" --num-sequences "$NUM_SEQUENCES" --blowup "$BLOWUP" --queries "$FRI_QUERIES" --grinding "$GRINDING") + if $BREAKDOWN; then + run_args+=(--breakdown) + fi + if ! "$BIN" "${run_args[@]}" > "$out_file" 2>&1; then + echo -e " ${RED}[${prover}] FAILED on log-rows=${log_rows} run ${run_i}${NC}" >&2 + cat "$out_file" >&2 + exit 1 + fi + local audit_line + audit_line=$(extract_audit_line < "$out_file") + if [ -n "$audit_line" ]; then + printf 'run=%s\t%s\n' "$run_i" "$audit_line" >> "$audit_file" + fi + local metrics_line + metrics_line=$(extract_metrics_line < "$out_file") + if [ -z "$metrics_line" ]; then + echo -e " ${RED}[${prover}] could not parse metrics (log-rows=${log_rows}, run ${run_i})${NC}" >&2 + cat "$out_file" >&2 + exit 1 + fi + printf '%s\n' "$metrics_line" >> "$metrics_file" + if $BREAKDOWN; then + sed -n "s/^BREAKDOWN /BREAKDOWN run=${run_i} /p" "$out_file" >> "$breakdown_file" + fi + + local t + t=$(metric_value "$metrics_line" prove_s) + if [ -z "$t" ]; then + t=$(extract_proving_time < "$out_file") + fi + times+=("$t") + if [ -n "$REPORT_DIR" ]; then + cp "$out_file" "$REPORT_DIR/raw/${prover}_log${log_rows}_run${run_i}.stdout" + fi + done + printf '%s\n' "${times[@]}" > "$TMP_DIR/${prover}_${log_rows}.times" + median_of "${times[@]}" +} + +for lr in "${LOG_ROWS[@]}"; do + rows=$((1 << lr)) + echo -e "${BOLD}--- log-rows=${lr} (rows = ${rows}) ---${NC}" + + lambda_median="n/a" + p3_median="n/a" + lambda_cv="n/a" + p3_cv="n/a" + lambda_verify="n/a" + p3_verify="n/a" + lambda_proof_size="n/a" + p3_proof_size="n/a" + lambda_rss="n/a" + p3_rss="n/a" + + if $RUN_LAMBDA; then + echo -ne " ${GREEN}[lambda]${NC} " + lambda_median=$(run_prover lambda "$lr") + lambda_cv=$(cv_pct_file "$TMP_DIR/lambda_${lr}.times") + lambda_verify=$(median_metric lambda "$lr" verify_s) + lambda_proof_size=$(median_metric lambda "$lr" proof_size_bytes) + lambda_rss=$(median_metric lambda "$lr" peak_rss_kb) + echo -e "prove median ${BOLD}${lambda_median}s${NC} (CV ${lambda_cv}%), verify ${lambda_verify}s, proof $(fmt0 "$lambda_proof_size") B, rss $(fmt0 "$lambda_rss") KB" + fi + + if $RUN_P3; then + echo -ne " ${GREEN}[p3]${NC} " + p3_median=$(run_prover p3 "$lr") + p3_cv=$(cv_pct_file "$TMP_DIR/p3_${lr}.times") + p3_verify=$(median_metric p3 "$lr" verify_s) + p3_proof_size=$(median_metric p3 "$lr" proof_size_bytes) + p3_rss=$(median_metric p3 "$lr" peak_rss_kb) + echo -e "prove median ${BOLD}${p3_median}s${NC} (CV ${p3_cv}%), verify ${p3_verify}s, proof $(fmt0 "$p3_proof_size") B, rss $(fmt0 "$p3_rss") KB" + fi + + local_ratio="n/a" + if $RUN_LAMBDA && $RUN_P3; then + local_ratio=$(ratio_fmt "$lambda_median" "$p3_median") + fi + + RESULT_LOG_ROWS+=("$lr") + RESULT_ROWS+=("$rows") + RESULT_LAMBDA+=("$lambda_median") + RESULT_P3+=("$p3_median") + RESULT_RATIO+=("$local_ratio") + RESULT_LAMBDA_CV+=("$lambda_cv") + RESULT_P3_CV+=("$p3_cv") + RESULT_LAMBDA_VERIFY+=("$lambda_verify") + RESULT_P3_VERIFY+=("$p3_verify") + RESULT_LAMBDA_PROOF_SIZE+=("$lambda_proof_size") + RESULT_P3_PROOF_SIZE+=("$p3_proof_size") + RESULT_LAMBDA_RSS+=("$lambda_rss") + RESULT_P3_RSS+=("$p3_rss") +done + +# --- Summary table ---------------------------------------------------------- + +echo "" +echo -e "${BOLD}=== Summary ===${NC}" +if $RUN_LAMBDA && $RUN_P3; then + printf " %-9s %-12s %14s %9s %14s %9s %10s\n" "log-rows" "rows" "Lambda (s)" "L CV%" "P3 (s)" "P3 CV%" "L/P3" + printf " %-9s %-12s %14s %9s %14s %9s %10s\n" "--------" "----" "----------" "-----" "------" "------" "----" +else + printf " %-9s %-12s %14s %9s\n" "log-rows" "rows" "Time (s)" "CV%" + printf " %-9s %-12s %14s %9s\n" "--------" "----" "--------" "---" +fi + +for i in "${!RESULT_LOG_ROWS[@]}"; do + lr="${RESULT_LOG_ROWS[$i]}" + rows="${RESULT_ROWS[$i]}" + lt="${RESULT_LAMBDA[$i]}" + pt="${RESULT_P3[$i]}" + rt="${RESULT_RATIO[$i]}" + lcv="${RESULT_LAMBDA_CV[$i]}" + pcv="${RESULT_P3_CV[$i]}" + if $RUN_LAMBDA && $RUN_P3; then + color=$GREEN + verdict="Lambda faster" + if awk -v l="$lt" -v p="$pt" 'BEGIN{ exit !(l+0 > p+0) }'; then + color=$RED + verdict="P3 faster" + fi + printf " %-9s %-12s %13ss %8s%% %13ss %8s%% ${color}%9sx${NC} (${color}%s${NC})\n" \ + "$lr" "$rows" "$lt" "$lcv" "$pt" "$pcv" "$rt" "$verdict" + elif $RUN_LAMBDA; then + printf " %-9s %-12s %13ss %8s%%\n" "$lr" "$rows" "$lt" "$lcv" + else + printf " %-9s %-12s %13ss %8s%%\n" "$lr" "$rows" "$pt" "$pcv" + fi +done + +echo "" +if $RUN_LAMBDA && $RUN_P3; then + echo -e "Timing window: prove only for the ratio. Verify, proof size, RSS and throughput are reported separately." +fi + +# --- Machine-readable report ------------------------------------------------ + +if [ -n "$REPORT_DIR" ]; then + # Slash-joined helpers for metrics.txt (mirrors the format used by + # bench_vs/run.sh). + join_slash() { + local joined="" + for value in "$@"; do + joined="${joined:+$joined/}$value" + done + printf "%s\n" "$joined" + } + + { + printf "log_rows\trows\tlambda_prove_median_s\tlambda_prove_cv_pct\tlambda_verify_median_s\tlambda_proof_size_bytes_median\tlambda_peak_rss_kb_median\tp3_prove_median_s\tp3_prove_cv_pct\tp3_verify_median_s\tp3_proof_size_bytes_median\tp3_peak_rss_kb_median\tratio_lambda_over_p3\truns\n" + for i in "${!RESULT_LOG_ROWS[@]}"; do + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${RESULT_LOG_ROWS[$i]}" \ + "${RESULT_ROWS[$i]}" \ + "${RESULT_LAMBDA[$i]}" \ + "${RESULT_LAMBDA_CV[$i]}" \ + "${RESULT_LAMBDA_VERIFY[$i]}" \ + "${RESULT_LAMBDA_PROOF_SIZE[$i]}" \ + "${RESULT_LAMBDA_RSS[$i]}" \ + "${RESULT_P3[$i]}" \ + "${RESULT_P3_CV[$i]}" \ + "${RESULT_P3_VERIFY[$i]}" \ + "${RESULT_P3_PROOF_SIZE[$i]}" \ + "${RESULT_P3_RSS[$i]}" \ + "${RESULT_RATIO[$i]}" \ + "$RUNS" + done + } > "$REPORT_DIR/results.tsv" + + { + printf "workload\tprover\tlog_rows\trows\tnum_sequences\tmain_cols\taux_cols\ttables\tlogup\tblowup\tfri_queries\tgrinding\tprove_s\tverify_s\tproof_size_bytes\tpeak_rss_kb\trows_per_sec\tcells_per_sec\n" + for lr in "${RESULT_LOG_ROWS[@]}"; do + for prover in lambda p3; do + metrics_file="$TMP_DIR/${prover}_${lr}.metrics" + if [ ! -f "$metrics_file" ]; then + continue + fi + while IFS= read -r line; do + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "$(metric_value "$line" workload)" \ + "$(metric_value "$line" prover)" \ + "$(metric_value "$line" log_rows)" \ + "$(metric_value "$line" rows)" \ + "$(metric_value "$line" num_sequences)" \ + "$(metric_value "$line" main_cols)" \ + "$(metric_value "$line" aux_cols)" \ + "$(metric_value "$line" tables)" \ + "$(metric_value "$line" logup)" \ + "$(metric_value "$line" blowup)" \ + "$(metric_value "$line" fri_queries)" \ + "$(metric_value "$line" grinding)" \ + "$(metric_value "$line" prove_s)" \ + "$(metric_value "$line" verify_s)" \ + "$(metric_value "$line" proof_size_bytes)" \ + "$(metric_value "$line" peak_rss_kb)" \ + "$(metric_value "$line" rows_per_sec)" \ + "$(metric_value "$line" cells_per_sec)" + done < "$metrics_file" + done + done + } > "$REPORT_DIR/raw_metrics.tsv" + + # Raw AUDIT lines per run, one row per prover×log_rows×run. Lets the reader + # confirm in retrospect that val_packing_width=1, hash_lanes=1, etc. + { + printf "run\taudit_line\n" + for lr in "${RESULT_LOG_ROWS[@]}"; do + for prover in lambda p3; do + audit_file="$TMP_DIR/${prover}_${lr}.audits" + if [ -f "$audit_file" ]; then + cat "$audit_file" + fi + done + done + } > "$REPORT_DIR/raw_audits.tsv" + + if $BREAKDOWN; then + { + printf "run\tworkload\tprover\tlog_rows\trows\tphase\tms\ttable\ttable_rows\tspan\n" + for lr in "${RESULT_LOG_ROWS[@]}"; do + for prover in lambda p3; do + breakdown_file="$TMP_DIR/${prover}_${lr}.breakdown" + if [ ! -f "$breakdown_file" ]; then + continue + fi + while IFS= read -r line; do + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "$(metric_value "$line" run)" \ + "$(metric_value "$line" workload)" \ + "$(metric_value "$line" prover)" \ + "$(metric_value "$line" log_rows)" \ + "$(metric_value "$line" rows)" \ + "$(metric_value "$line" phase)" \ + "$(metric_value "$line" ms)" \ + "$(metric_value "$line" table)" \ + "$(metric_value "$line" table_rows)" \ + "$(metric_value "$line" span)" + done < "$breakdown_file" + done + done + } > "$REPORT_DIR/breakdown.tsv" + fi + + # Capture commit + timestamp so the artifact is self-describing. + git_sha="$(git -C "$ROOT_DIR" rev-parse HEAD 2>/dev/null || echo unknown)" + git_dirty="clean" + if ! git -C "$ROOT_DIR" diff --quiet HEAD -- 2>/dev/null; then + git_dirty="dirty" + fi + timestamp_utc="$(date -u +%Y-%m-%dT%H:%M:%SZ)" + + { + echo "timestamp_utc=$timestamp_utc" + echo "git_sha=$git_sha" + echo "git_tree=$git_dirty" + echo "arch=$(uname -m)" + echo "num_sequences=$NUM_SEQUENCES" + echo "columns=$((2 * NUM_SEQUENCES))" + echo "blowup=$BLOWUP" + echo "fri_queries=$FRI_QUERIES" + echo "grinding=$GRINDING" + echo "runs_per_size=$RUNS" + if $BREAKDOWN; then + echo "breakdown=on" + else + echo "breakdown=off" + fi + echo "p3_extension=upstream_cubic_trinomial" + echo "p3_mmcs=scalar_keccak256" + if $SCALAR_ACTIVE; then + echo "scalar=on" + echo "rustflags=$SCALAR_RUSTFLAGS" + elif $SCALAR; then + echo "scalar=requested_unsupported" + else + echo "scalar=off" + fi + echo "timing_window=prove_only_ratio_verify_size_rss_reported_separately" + echo "log_rows_series=$(join_slash "${RESULT_LOG_ROWS[@]}")" + echo "rows_series=$(join_slash "${RESULT_ROWS[@]}")" + echo "lambda_prove_medians=$(join_slash "${RESULT_LAMBDA[@]}")" + echo "p3_prove_medians=$(join_slash "${RESULT_P3[@]}")" + echo "lambda_verify_medians=$(join_slash "${RESULT_LAMBDA_VERIFY[@]}")" + echo "p3_verify_medians=$(join_slash "${RESULT_P3_VERIFY[@]}")" + echo "lambda_proof_size_medians=$(join_slash "${RESULT_LAMBDA_PROOF_SIZE[@]}")" + echo "p3_proof_size_medians=$(join_slash "${RESULT_P3_PROOF_SIZE[@]}")" + echo "lambda_peak_rss_medians=$(join_slash "${RESULT_LAMBDA_RSS[@]}")" + echo "p3_peak_rss_medians=$(join_slash "${RESULT_P3_RSS[@]}")" + echo "ratios_lambda_over_p3=$(join_slash "${RESULT_RATIO[@]}")" + } > "$REPORT_DIR/metrics.txt" +fi diff --git a/bench_vs_plonky3/src/bin/prove_bench.rs b/bench_vs_plonky3/src/bin/prove_bench.rs new file mode 100644 index 000000000..c132f57a5 --- /dev/null +++ b/bench_vs_plonky3/src/bin/prove_bench.rs @@ -0,0 +1,569 @@ +//! Minimal wall-clock benchmark harness for Lambda STARK vs Plonky3. +//! +//! Builds the same Fibonacci AIR as `instruments_breakdown` (but without any +//! instrumentation) and prints human-readable timings plus one tab-separated +//! `METRICS` line, suitable for parsing by `bench_vs_plonky3/run.sh`. +//! +//! Usage: +//! prove_bench --prover {lambda|p3} [--log-rows K] [--num-sequences N] +//! [--blowup B] [--queries Q] [--grinding G] [--breakdown] +//! +//! Defaults match production (`GoldilocksCubicProofOptions::with_blowup(2)`): +//! log-rows=19, num-sequences=16, blowup=2, queries=219, grinding=0. + +use std::process::ExitCode; +use std::time::Instant; + +use bench_vs_plonky3::span_timing::{P3TimingLayer, SpanResults as P3SpanResults}; +use bench_vs_plonky3::{lambda_fibonacci_pair, plonky3_config, plonky3_fibonacci}; +use crypto::fiat_shamir::default_transcript::DefaultTranscript; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use stark::proof::options::ProofOptions; +use stark::prover::{IsStarkProver, Prover}; +use stark::verifier::{IsStarkVerifier, Verifier}; +use tracing_subscriber::layer::SubscriberExt; + +type F = GoldilocksField; +type E = Degree3GoldilocksExtensionField; +type FE = FieldElement; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ProverKind { + Lambda, + P3, +} + +struct Args { + prover: ProverKind, + log_rows: u32, + num_sequences: usize, + blowup: u8, + queries: usize, + grinding: u8, + breakdown: bool, +} + +struct BenchMetrics { + prove_s: f64, + verify_s: f64, + proof_size_bytes: usize, + peak_rss_kb: Option, +} + +impl Default for Args { + fn default() -> Self { + Self { + prover: ProverKind::Lambda, + log_rows: 19, + num_sequences: 16, + blowup: 2, + queries: 219, + grinding: 0, + breakdown: false, + } + } +} + +fn print_usage() { + eprintln!( + "usage: prove_bench --prover {{lambda|p3}} \ + [--log-rows K] [--num-sequences N] \ + [--blowup B] [--queries Q] [--grinding G] [--breakdown]" + ); +} + +fn parse_args() -> Result { + let mut args = Args::default(); + let mut prover_set = false; + let mut iter = std::env::args().skip(1); + while let Some(a) = iter.next() { + match a.as_str() { + "--prover" => { + let v = iter.next().ok_or("--prover needs a value")?; + args.prover = match v.as_str() { + "lambda" => ProverKind::Lambda, + "p3" => ProverKind::P3, + other => return Err(format!("unknown prover: {other}")), + }; + prover_set = true; + } + "--log-rows" => { + let v = iter.next().ok_or("--log-rows needs a value")?; + args.log_rows = v.parse().map_err(|_| "--log-rows: invalid u32")?; + } + "--num-sequences" => { + let v = iter.next().ok_or("--num-sequences needs a value")?; + args.num_sequences = v.parse().map_err(|_| "--num-sequences: invalid usize")?; + } + "--blowup" => { + let v = iter.next().ok_or("--blowup needs a value")?; + args.blowup = v.parse().map_err(|_| "--blowup: invalid u8")?; + } + "--queries" => { + let v = iter.next().ok_or("--queries needs a value")?; + args.queries = v.parse().map_err(|_| "--queries: invalid usize")?; + } + "--grinding" => { + let v = iter.next().ok_or("--grinding needs a value")?; + args.grinding = v.parse().map_err(|_| "--grinding: invalid u8")?; + } + "--breakdown" => { + args.breakdown = true; + } + "-h" | "--help" => { + print_usage(); + std::process::exit(0); + } + other => return Err(format!("unknown arg: {other}")), + } + } + if !prover_set { + return Err("--prover is required".into()); + } + if args.log_rows < 2 || args.log_rows > 30 { + return Err("--log-rows must be in [2, 30]".into()); + } + if args.num_sequences == 0 { + return Err("--num-sequences must be > 0".into()); + } + if !args.blowup.is_power_of_two() { + return Err("--blowup must be a power of two".into()); + } + if args.queries == 0 { + return Err("--queries must be > 0".into()); + } + Ok(args) +} + +fn proof_options(args: &Args) -> ProofOptions { + ProofOptions { + blowup_factor: args.blowup, + fri_number_of_queries: args.queries, + coset_offset: 3, + grinding_factor: args.grinding, + } +} + +fn ms(seconds: f64) -> f64 { + seconds * 1000.0 +} + +fn print_breakdown( + prover: &str, + log_rows: u32, + rows: usize, + phase: &str, + elapsed_ms: f64, + extra: &str, +) { + println!( + "BREAKDOWN\tworkload=fib_pair\tprover={prover}\tlog_rows={log_rows}\trows={rows}\tphase={phase}\tms={elapsed_ms:.3}{extra}" + ); +} + +#[cfg(feature = "instruments")] +fn emit_lambda_breakdown(args: &Args, rows: usize, total_ms: f64) { + print_breakdown("lambda", args.log_rows, rows, "prove_total", total_ms, ""); + + if let Some(timing) = stark::instruments::take() { + print_breakdown( + "lambda", + args.log_rows, + rows, + "prepass", + ms(timing.prepass.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "main_commits", + ms(timing.main_commits.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "aux_build", + ms(timing.aux_build.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "aux_commit", + ms(timing.aux_commit.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "rounds_2_4", + ms(timing.rounds_2_4.as_secs_f64()), + "", + ); + + let r1 = timing.round1_sub; + print_breakdown( + "lambda", + args.log_rows, + rows, + "r1_main_lde", + ms(r1.main_lde.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r1_main_merkle", + ms(r1.main_merkle.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r1_aux_lde", + ms(r1.aux_lde.as_secs_f64()), + "", + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r1_aux_merkle", + ms(r1.aux_merkle.as_secs_f64()), + "", + ); + + for (name, table_rows, dur, sub) in timing.table_timings { + let extra = format!("\ttable={name}\ttable_rows={table_rows}"); + print_breakdown( + "lambda", + args.log_rows, + rows, + "table_total", + ms(dur.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r2_constraints", + ms(sub.constraints.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r2_comp_decompose", + ms(sub.comp_decompose.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r2_comp_commit", + ms(sub.comp_commit.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r3_ood", + ms(sub.ood.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r4_deep_comp", + ms(sub.deep_comp.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r4_deep_extend", + ms(sub.deep_extend.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r4_fri_commit", + ms(sub.fri_commit.as_secs_f64()), + &extra, + ); + print_breakdown( + "lambda", + args.log_rows, + rows, + "r4_queries", + ms(sub.queries.as_secs_f64()), + &extra, + ); + } + } +} + +#[cfg(not(feature = "instruments"))] +fn emit_lambda_breakdown(args: &Args, rows: usize, total_ms: f64) { + print_breakdown("lambda", args.log_rows, rows, "prove_total", total_ms, ""); + eprintln!("warning: Lambda phase breakdown requires building with --features instruments"); +} + +fn p3_span_subscriber() -> (impl tracing::Subscriber + Send + Sync, P3SpanResults) { + let (layer, results) = P3TimingLayer::new(); + let filter = tracing_subscriber::filter::LevelFilter::DEBUG; + ( + tracing_subscriber::registry().with(filter).with(layer), + results, + ) +} + +fn peak_rss_kb() -> Option { + let mut usage = std::mem::MaybeUninit::::uninit(); + // SAFETY: getrusage initializes `usage` when it returns 0. + let rc = unsafe { libc::getrusage(libc::RUSAGE_SELF, usage.as_mut_ptr()) }; + if rc != 0 { + return None; + } + + let maxrss = unsafe { usage.assume_init().ru_maxrss }; + if maxrss < 0 { + return None; + } + let maxrss = maxrss as u64; + #[cfg(target_os = "macos")] + { + Some(maxrss.div_ceil(1024)) + } + #[cfg(not(target_os = "macos"))] + { + Some(maxrss) + } +} + +fn run_lambda(args: &Args) -> BenchMetrics { + let rows = 1usize << args.log_rows; + let options = proof_options(args); + + let initial_values: Vec<(FE, FE)> = (0..args.num_sequences) + .map(|i| (FE::from((i + 1) as u64), FE::from((i + 2) as u64))) + .collect(); + + let mut trace = lambda_fibonacci_pair::compute_trace::(&initial_values, rows); + let pub_inputs = lambda_fibonacci_pair::create_public_inputs(initial_values); + let air = lambda_fibonacci_pair::FibonacciPairMultiColAIR::::with_num_sequences( + &options, + args.num_sequences, + ); + + let start = Instant::now(); + let _proof = Prover::::prove( + &air, + &mut trace, + &pub_inputs, + &mut DefaultTranscript::::new(&[]), + ) + .expect("lambda prove failed"); + let prove_s = start.elapsed().as_secs_f64(); + if args.breakdown { + emit_lambda_breakdown(args, rows, ms(prove_s)); + } + + let proof_size_bytes = serde_cbor::to_vec(&_proof) + .expect("lambda proof serialization failed") + .len(); + + let start = Instant::now(); + let verified = + Verifier::::verify(&_proof, &air, &mut DefaultTranscript::::new(&[])); + let verify_s = start.elapsed().as_secs_f64(); + assert!(verified, "lambda verify failed"); + + BenchMetrics { + prove_s, + verify_s, + proof_size_bytes, + peak_rss_kb: peak_rss_kb(), + } +} + +fn run_p3(args: &Args) -> BenchMetrics { + let rows = 1usize << args.log_rows; + let config = plonky3_config::params_config(args.blowup, args.queries, args.grinding); + let air = plonky3_fibonacci::P3FibonacciAir { + num_sequences: args.num_sequences, + }; + let trace = plonky3_fibonacci::generate_fibonacci_trace(args.num_sequences, rows); + let pis = plonky3_fibonacci::public_values(args.num_sequences); + + let (prove_s, _proof, span_results) = if args.breakdown { + let (subscriber, results) = p3_span_subscriber(); + let start = Instant::now(); + let proof = { + let _guard = tracing::subscriber::set_default(subscriber); + p3_uni_stark::prove(&config, &air, trace, &pis) + }; + (start.elapsed().as_secs_f64(), proof, Some(results)) + } else { + let start = Instant::now(); + let proof = p3_uni_stark::prove(&config, &air, trace, &pis); + (start.elapsed().as_secs_f64(), proof, None) + }; + + if args.breakdown { + print_breakdown("p3", args.log_rows, rows, "prove_total", ms(prove_s), ""); + if let Some(results) = span_results { + let mut span_data = results.lock().unwrap().clone(); + span_data.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + for (name, elapsed_ms) in span_data { + if elapsed_ms >= 0.1 { + let extra = format!("\tspan={name}"); + print_breakdown("p3", args.log_rows, rows, "span", elapsed_ms, &extra); + } + } + } + } + + let proof_size_bytes = serde_cbor::to_vec(&_proof) + .expect("p3 proof serialization failed") + .len(); + + let start = Instant::now(); + p3_uni_stark::verify(&config, &air, &_proof, &pis).expect("p3 verify failed"); + let verify_s = start.elapsed().as_secs_f64(); + + BenchMetrics { + prove_s, + verify_s, + proof_size_bytes, + peak_rss_kb: peak_rss_kb(), + } +} + +fn print_audit(args: &Args) { + let prover_name = match args.prover { + ProverKind::Lambda => "lambda", + ProverKind::P3 => "p3", + }; + let rows = 1usize << args.log_rows; + let main_cols = 2 * args.num_sequences; + let trace_cells = rows * main_cols; + let public_values = 2 * args.num_sequences; + let transition_constraints = 2 * args.num_sequences; + + // Common prefix. + let common = format!( + "AUDIT\tprover={prover_name}\tworkload=fib_pair\tlog_rows={}\trows={rows}\t\ + main_cols={main_cols}\taux_cols=0\ttrace_cells={trace_cells}\t\ + public_values={public_values}", + args.log_rows, + ); + + // Per-prover audit fields. + let prover_specific = match args.prover { + ProverKind::Lambda => format!( + "transition_constraints={transition_constraints}\t\ + base_transition_constraints={transition_constraints}\t\ + boundary_constraints={transition_constraints}\t\ + composition_chunks=1" + ), + ProverKind::P3 => { + // P3 counts 2*num_sequences first-row constraints (boundary equivalent, + // encoded inside the AIR via `when_first_row`) + 2*num_sequences + // transition constraints, total 4*num_sequences. + let air_constraints = 4 * args.num_sequences; + let first_row_constraints = 2 * args.num_sequences; + format!( + "air_constraints={air_constraints}\t\ + first_row_constraints={first_row_constraints}\t\ + transition_constraints={transition_constraints}\t\ + boundary_constraints=0\tquotient_chunks=1\t\ + val_packing_width={}\thash_lanes={}", + plonky3_config::VAL_PACKING_WIDTH, + plonky3_config::HASH_LANES, + ) + } + }; + + let tail = format!( + "blowup={}\tqueries={}\tgrinding={}\t\ + trace_generation_timed=false\tverify_in_ratio=false", + args.blowup, args.queries, args.grinding, + ); + + println!("{common}\t{prover_specific}\t{tail}"); +} + +fn main() -> ExitCode { + let args = match parse_args() { + Ok(a) => a, + Err(e) => { + eprintln!("error: {e}"); + print_usage(); + return ExitCode::from(2); + } + }; + + print_audit(&args); + + let metrics = match args.prover { + ProverKind::Lambda => run_lambda(&args), + ProverKind::P3 => run_p3(&args), + }; + + let prover_name = match args.prover { + ProverKind::Lambda => "lambda", + ProverKind::P3 => "p3", + }; + let rows = 1usize << args.log_rows; + let main_cols = 2 * args.num_sequences; + let aux_cols = 0usize; + let cells = rows * main_cols; + let rows_per_sec = rows as f64 / metrics.prove_s; + let cells_per_sec = cells as f64 / metrics.prove_s; + let peak_rss_kb = metrics + .peak_rss_kb + .map(|v| v.to_string()) + .unwrap_or_else(|| "n/a".to_string()); + + println!("Proving time: {:.6}s", metrics.prove_s); + println!("Verification time: {:.6}s", metrics.verify_s); + println!("Proof size: {} bytes", metrics.proof_size_bytes); + println!("Peak RSS: {peak_rss_kb} KB"); + println!( + "METRICS\tworkload=fib_pair\tprover={prover_name}\tlog_rows={}\trows={rows}\t\ + num_sequences={}\tmain_cols={main_cols}\taux_cols={aux_cols}\ttables=1\t\ + logup=false\tblowup={}\tfri_queries={}\tgrinding={}\tprove_s={:.6}\t\ + verify_s={:.6}\tproof_size_bytes={}\tpeak_rss_kb={peak_rss_kb}\t\ + rows_per_sec={:.3}\tcells_per_sec={:.3}", + args.log_rows, + args.num_sequences, + args.blowup, + args.queries, + args.grinding, + metrics.prove_s, + metrics.verify_s, + metrics.proof_size_bytes, + rows_per_sec, + cells_per_sec, + ); + ExitCode::SUCCESS +} diff --git a/bench_vs_plonky3/src/lambda_fibonacci_pair.rs b/bench_vs_plonky3/src/lambda_fibonacci_pair.rs new file mode 100644 index 000000000..bae1235dc --- /dev/null +++ b/bench_vs_plonky3/src/lambda_fibonacci_pair.rs @@ -0,0 +1,376 @@ +//! Lambda AIR matching Plonky3's `P3FibonacciAir` exactly in shape. +//! +//! Each sequence uses 2 columns (`left`, `right`) with a 2-row transition +//! window, packing 2 Fibonacci steps per row: +//! +//! `local.left = x_{2i}` +//! `local.right = x_{2i+1}` +//! `next.left = x_{2i+2} = local.left + local.right` +//! `next.right = x_{2i+3} = local.right + next.left` +//! +//! For `num_sequences` sequences: +//! - columns = `2 * num_sequences` +//! - transition constraints = `2 * num_sequences` +//! - boundary constraints = `2 * num_sequences` (pin `(a, b)` at row 0) +//! +//! This matches `P3FibonacciAir` cell-by-cell; only the prover internals +//! (multi_prove vs uni-stark, degree-3 vs degree-2 extension) differ. + +use std::marker::PhantomData; + +use math::field::{ + element::FieldElement, + traits::{IsFFTField, IsField, IsSubFieldOf}, +}; +use stark::{ + constraints::{ + boundary::{BoundaryConstraint, BoundaryConstraints}, + transition::TransitionConstraintEvaluator, + }, + context::AirContext, + proof::options::ProofOptions, + trace::TraceTable, + traits::{AIR, TransitionEvaluationContext}, +}; + +/// `next.left = local.left + local.right` (advances 2 Fibonacci steps) +#[derive(Clone)] +pub struct FibPairShiftConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + seq_idx: usize, + constraint_idx: usize, + phantom_f: PhantomData, + phantom_e: PhantomData, +} + +impl FibPairShiftConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + pub fn new(seq_idx: usize, constraint_idx: usize) -> Self { + Self { + seq_idx, + constraint_idx, + phantom_f: PhantomData, + phantom_e: PhantomData, + } + } +} + +impl TransitionConstraintEvaluator for FibPairShiftConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + fn degree(&self) -> usize { + 1 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn end_exemptions(&self) -> usize { + 1 + } + + fn evaluate_verifier( + &self, + eval_ctx: &TransitionEvaluationContext, + out: &mut [FieldElement], + ) { + match eval_ctx { + TransitionEvaluationContext::Prover { frame, .. } => { + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_left = s0.get_main_evaluation_element(0, 2 * self.seq_idx); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + let res = next_left - local_left - local_right; + out[self.constraint_idx] = res.to_extension(); + } + TransitionEvaluationContext::Verifier { frame, .. } => { + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_left = s0.get_main_evaluation_element(0, 2 * self.seq_idx); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + let res = next_left - local_left - local_right; + out[self.constraint_idx] = res; + } + } + } + + fn evaluate_prover( + &self, + eval_ctx: &TransitionEvaluationContext, + base_evals: &mut [FieldElement], + _ext_evals: &mut [FieldElement], + ) { + let TransitionEvaluationContext::Prover { frame, .. } = eval_ctx else { + unreachable!("evaluate_prover called with non-Prover context"); + }; + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_left = s0.get_main_evaluation_element(0, 2 * self.seq_idx); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + base_evals[self.constraint_idx] = next_left - local_left - local_right; + } +} + +/// `next.right = local.right + next.left` +#[derive(Clone)] +pub struct FibPairSumConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + seq_idx: usize, + constraint_idx: usize, + phantom_f: PhantomData, + phantom_e: PhantomData, +} + +impl FibPairSumConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + pub fn new(seq_idx: usize, constraint_idx: usize) -> Self { + Self { + seq_idx, + constraint_idx, + phantom_f: PhantomData, + phantom_e: PhantomData, + } + } +} + +impl TransitionConstraintEvaluator for FibPairSumConstraint +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + fn degree(&self) -> usize { + 1 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn end_exemptions(&self) -> usize { + 1 + } + + fn evaluate_verifier( + &self, + eval_ctx: &TransitionEvaluationContext, + out: &mut [FieldElement], + ) { + match eval_ctx { + TransitionEvaluationContext::Prover { frame, .. } => { + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + let next_right = s1.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let res = next_right - local_right - next_left; + out[self.constraint_idx] = res.to_extension(); + } + TransitionEvaluationContext::Verifier { frame, .. } => { + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + let next_right = s1.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let res = next_right - local_right - next_left; + out[self.constraint_idx] = res; + } + } + } + + fn evaluate_prover( + &self, + eval_ctx: &TransitionEvaluationContext, + base_evals: &mut [FieldElement], + _ext_evals: &mut [FieldElement], + ) { + let TransitionEvaluationContext::Prover { frame, .. } = eval_ctx else { + unreachable!("evaluate_prover called with non-Prover context"); + }; + let s0 = frame.get_evaluation_step(0); + let s1 = frame.get_evaluation_step(1); + let local_right = s0.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + let next_left = s1.get_main_evaluation_element(0, 2 * self.seq_idx); + let next_right = s1.get_main_evaluation_element(0, 2 * self.seq_idx + 1); + base_evals[self.constraint_idx] = next_right - local_right - next_left; + } +} + +/// Public inputs: initial `(a, b) = (left, right)` pair for each sequence. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound = "")] +pub struct FibonacciPairPublicInputs { + pub initial_values: Vec<(FieldElement, FieldElement)>, +} + +/// Multi-sequence Fibonacci AIR with 2-row window, matching Plonky3's `P3FibonacciAir`. +pub struct FibonacciPairMultiColAIR +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + context: AirContext, + constraints: Vec>>, + num_sequences: usize, +} + +impl AIR for FibonacciPairMultiColAIR +where + F: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + E: IsField + Send + Sync + 'static, +{ + type Field = F; + type FieldExtension = E; + type PublicInputs = FibonacciPairPublicInputs; + + fn step_size(&self) -> usize { + 1 + } + + fn name(&self) -> &str { + "fib_pair" + } + + fn new(proof_options: &ProofOptions) -> Self { + Self::with_num_sequences(proof_options, 2) + } + + fn composition_poly_degree_bound(&self, trace_length: usize) -> usize { + trace_length + } + + fn transition_constraints(&self) -> &Vec>> { + &self.constraints + } + + fn num_base_transition_constraints(&self) -> usize { + 2 * self.num_sequences + } + + fn boundary_constraints( + &self, + pub_inputs: &Self::PublicInputs, + _rap_challenges: &[FieldElement], + _bus_public_inputs: Option<&stark::lookup::BusPublicInputs>, + _trace_length: usize, + ) -> BoundaryConstraints { + assert_eq!( + pub_inputs.initial_values.len(), + self.num_sequences, + "AIR built for {} sequences, public inputs carry {}", + self.num_sequences, + pub_inputs.initial_values.len(), + ); + let mut constraints = Vec::with_capacity(2 * pub_inputs.initial_values.len()); + for (seq_idx, (a, b)) in pub_inputs.initial_values.iter().enumerate() { + constraints.push(BoundaryConstraint::new_main( + 2 * seq_idx, + 0, + a.clone().to_extension(), + )); + constraints.push(BoundaryConstraint::new_main( + 2 * seq_idx + 1, + 0, + b.clone().to_extension(), + )); + } + BoundaryConstraints::from_constraints(constraints) + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn trace_layout(&self) -> (usize, usize) { + (2 * self.num_sequences, 0) + } +} + +impl FibonacciPairMultiColAIR +where + F: IsSubFieldOf + IsFFTField + Send + Sync + 'static, + E: IsField + Send + Sync + 'static, +{ + pub fn with_num_sequences(proof_options: &ProofOptions, num_sequences: usize) -> Self { + let mut constraints: Vec>> = + Vec::with_capacity(2 * num_sequences); + for seq in 0..num_sequences { + constraints.push(Box::new(FibPairShiftConstraint::new(seq, 2 * seq))); + constraints.push(Box::new(FibPairSumConstraint::new(seq, 2 * seq + 1))); + } + + let context = AirContext { + proof_options: proof_options.clone(), + trace_columns: 2 * num_sequences, + transition_offsets: vec![0, 1], + num_transition_constraints: 2 * num_sequences, + }; + + Self { + context, + constraints, + num_sequences, + } + } +} + +/// Computes the packed Fibonacci trace. +/// +/// Each row holds `(x_{2i}, x_{2i+1})` for each sequence. Identical values to +/// `plonky3_fibonacci::generate_fibonacci_trace` at the same coordinates. +pub fn compute_trace( + initial_values: &[(FieldElement, FieldElement)], + trace_length: usize, +) -> TraceTable +where + F: IsSubFieldOf + IsFFTField + Send + Sync, + E: IsField + Send + Sync, +{ + let num_sequences = initial_values.len(); + let mut columns: Vec>> = Vec::with_capacity(2 * num_sequences); + + for (a, b) in initial_values { + let mut left_col = Vec::with_capacity(trace_length); + let mut right_col = Vec::with_capacity(trace_length); + + let mut left = a.clone(); + let mut right = b.clone(); + + for _ in 0..trace_length { + left_col.push(left.clone()); + right_col.push(right.clone()); + let new_left = left.clone() + right.clone(); + let new_right = right.clone() + new_left.clone(); + left = new_left; + right = new_right; + } + + columns.push(left_col); + columns.push(right_col); + } + + TraceTable::from_columns_main(columns, 1) +} + +pub fn create_public_inputs( + initial_values: Vec<(FieldElement, FieldElement)>, +) -> FibonacciPairPublicInputs { + FibonacciPairPublicInputs { initial_values } +} diff --git a/bench_vs_plonky3/src/lib.rs b/bench_vs_plonky3/src/lib.rs new file mode 100644 index 000000000..dd5cbf675 --- /dev/null +++ b/bench_vs_plonky3/src/lib.rs @@ -0,0 +1,292 @@ +pub mod lambda_fibonacci_pair; +pub mod plonky3_config; +pub mod plonky3_fibonacci; +pub mod span_timing; + +#[cfg(test)] +mod tests { + use super::*; + + use crypto::fiat_shamir::default_transcript::DefaultTranscript; + use math::field::element::FieldElement; + use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; + use math::field::goldilocks::GoldilocksField; + use p3_field::PrimeField64; + use p3_uni_stark::{prove, verify}; + use stark::proof::options::ProofOptions; + use stark::prover::{IsStarkProver, Prover}; + use stark::verifier::{IsStarkVerifier, Verifier}; + + type F = GoldilocksField; + type E = Degree3GoldilocksExtensionField; + type FE = FieldElement; + + fn benchmark_proof_options() -> ProofOptions { + ProofOptions { + blowup_factor: 2, + fri_number_of_queries: 219, + coset_offset: 3, + grinding_factor: 0, + } + } + + #[test] + fn lambda_fibonacci_pair_prove_verify() { + let num_sequences = 2; + let trace_length = 128; // 2^7 + let proof_options = benchmark_proof_options(); + + let initial_values: Vec<(FE, FE)> = (0..num_sequences) + .map(|i| (FE::from((i + 1) as u64), FE::from((i + 2) as u64))) + .collect(); + + let mut trace = lambda_fibonacci_pair::compute_trace::(&initial_values, trace_length); + let pub_inputs = lambda_fibonacci_pair::create_public_inputs(initial_values); + let air = lambda_fibonacci_pair::FibonacciPairMultiColAIR::::with_num_sequences( + &proof_options, + num_sequences, + ); + + let proof = Prover::::prove( + &air, + &mut trace, + &pub_inputs, + &mut DefaultTranscript::::new(&[]), + ) + .unwrap(); + + assert!(Verifier::::verify( + &proof, + &air, + &mut DefaultTranscript::::new(&[]), + )); + } + + #[test] + fn plonky3_fibonacci_prove_verify() { + let num_sequences = 2; + let rows = 128; // 2^7 + + let config = plonky3_config::matched_params_config(); + let air = plonky3_fibonacci::P3FibonacciAir { num_sequences }; + let trace = plonky3_fibonacci::generate_fibonacci_trace(num_sequences, rows); + let pis = plonky3_fibonacci::public_values(num_sequences); + + let proof = prove(&config, &air, trace, &pis); + verify(&config, &air, &proof, &pis).expect("Plonky3 verification failed"); + } + + /// Lambda prove with instruments breakdown + P3 span-based breakdown. + /// Run: cargo test -p bench-vs-plonky3 --features instruments --release -- instruments_breakdown --ignored --nocapture + #[test] + #[ignore = "heavy: run with --release -- instruments_breakdown --ignored --nocapture"] + fn instruments_breakdown() { + let num_sequences = 16; + let rows = 1 << 19; + let proof_options = benchmark_proof_options(); + + let initial_values: Vec<(FE, FE)> = (0..num_sequences) + .map(|i| (FE::from((i + 1) as u64), FE::from((i + 2) as u64))) + .collect(); + + let mut trace = lambda_fibonacci_pair::compute_trace::(&initial_values, rows); + let pub_inputs = lambda_fibonacci_pair::create_public_inputs(initial_values); + let air = lambda_fibonacci_pair::FibonacciPairMultiColAIR::::with_num_sequences( + &proof_options, + num_sequences, + ); + + let start = std::time::Instant::now(); + let _proof = Prover::::prove( + &air, + &mut trace, + &pub_inputs, + &mut DefaultTranscript::::new(&[]), + ) + .unwrap(); + let total = start.elapsed(); + + println!("\n============================================================"); + println!( + "Lambda STARK Instruments (blowup={}, queries={})", + proof_options.blowup_factor, proof_options.fri_number_of_queries + ); + println!("Trace: {} rows x {} cols", rows, 2 * num_sequences); + println!("Total prove: {:.3}s", total.as_secs_f64()); + + #[cfg(feature = "instruments")] + if let Some(timing) = stark::instruments::take() { + println!("\n--- High-level phases ---"); + println!( + " Pre-pass: {:>8.1}ms", + timing.prepass.as_secs_f64() * 1000.0 + ); + println!( + " R1 Main commits: {:>8.1}ms", + timing.main_commits.as_secs_f64() * 1000.0 + ); + println!( + " R1 Aux build: {:>8.1}ms", + timing.aux_build.as_secs_f64() * 1000.0 + ); + println!( + " R1 Aux commit: {:>8.1}ms", + timing.aux_commit.as_secs_f64() * 1000.0 + ); + println!( + " Rounds 2-4: {:>8.1}ms", + timing.rounds_2_4.as_secs_f64() * 1000.0 + ); + + let r1 = &timing.round1_sub; + println!("\n--- Round 1 sub-ops ---"); + println!( + " Main LDE (FFT): {:>8.1}ms", + r1.main_lde.as_secs_f64() * 1000.0 + ); + println!( + " Main Merkle: {:>8.1}ms", + r1.main_merkle.as_secs_f64() * 1000.0 + ); + + for (name, tbl_rows, dur, sub) in &timing.table_timings { + println!( + "\n--- Rounds 2-4: {} ({} rows, {:.1}ms) ---", + name, + tbl_rows, + dur.as_secs_f64() * 1000.0 + ); + println!( + " R2 constraint eval:{:>8.1}ms ({:.0}%)", + sub.constraints.as_secs_f64() * 1000.0, + sub.constraints.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R2 decompose+ext: {:>8.1}ms ({:.0}%)", + sub.comp_decompose.as_secs_f64() * 1000.0, + sub.comp_decompose.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R2 comp Merkle: {:>8.1}ms ({:.0}%)", + sub.comp_commit.as_secs_f64() * 1000.0, + sub.comp_commit.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R3 OOD eval: {:>8.1}ms ({:.0}%)", + sub.ood.as_secs_f64() * 1000.0, + sub.ood.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R4 deep comp: {:>8.1}ms ({:.0}%)", + sub.deep_comp.as_secs_f64() * 1000.0, + sub.deep_comp.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R4 deep extend: {:>8.1}ms ({:.0}%)", + sub.deep_extend.as_secs_f64() * 1000.0, + sub.deep_extend.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R4 FRI commit: {:>8.1}ms ({:.0}%)", + sub.fri_commit.as_secs_f64() * 1000.0, + sub.fri_commit.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + println!( + " R4 queries+open: {:>8.1}ms ({:.0}%)", + sub.queries.as_secs_f64() * 1000.0, + sub.queries.as_secs_f64() / total.as_secs_f64() * 100.0 + ); + } + } + + #[cfg(not(feature = "instruments"))] + println!("(rebuild with --features instruments for breakdown)"); + + // --- Plonky3 breakdown via tracing spans --- + // Captures ALL spans (info + debug) so we see quotient_values, FRI commit, etc. + println!("\n============================================================"); + println!("Plonky3 STARK Span Breakdown"); + + use tracing_subscriber::layer::SubscriberExt; + + let (layer, results) = crate::span_timing::P3TimingLayer::new(); + let filter = tracing_subscriber::filter::LevelFilter::DEBUG; + let subscriber = tracing_subscriber::registry().with(filter).with(layer); + + let config = plonky3_config::matched_params_config(); + let p3_air = plonky3_fibonacci::P3FibonacciAir { num_sequences }; + let p3_trace = plonky3_fibonacci::generate_fibonacci_trace(num_sequences, rows); + let p3_pis = plonky3_fibonacci::public_values(num_sequences); + + let p3_prove_dur; + { + let _guard = tracing::subscriber::set_default(subscriber); + let p3_start = std::time::Instant::now(); + let _p3_proof = p3_uni_stark::prove(&config, &p3_air, p3_trace, &p3_pis); + p3_prove_dur = p3_start.elapsed(); + } + + let total_ms = p3_prove_dur.as_secs_f64() * 1000.0; + println!(" Prove total: {:.1}ms\n", total_ms); + + // Sort spans by duration descending and print + let mut span_data = results.lock().unwrap().clone(); + span_data.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + for (name, ms) in &span_data { + if *ms >= 0.1 { + println!( + " {:.<40} {:>8.1}ms ({:.0}%)", + name, + ms, + ms / total_ms * 100.0 + ); + } + } + let accounted: f64 = span_data.iter().map(|(_, ms)| ms).sum(); + let unaccounted = total_ms - accounted; + if unaccounted > 1.0 { + println!( + " {:.<40} {:>8.1}ms ({:.0}%)", + "(unaccounted)", + unaccounted, + unaccounted / total_ms * 100.0 + ); + } + println!("============================================================\n"); + } + + /// Verifies that the new Lambda pair AIR trace and the Plonky3 trace are + /// cell-by-cell identical at the same (row, col) coordinates. + #[test] + fn lambda_pair_trace_matches_plonky3_trace() { + let num_sequences = 3; + let rows = 16; + + let initial_values: Vec<(FE, FE)> = (0..num_sequences) + .map(|i| (FE::from((i + 1) as u64), FE::from((i + 2) as u64))) + .collect(); + + let lambda_trace = lambda_fibonacci_pair::compute_trace::(&initial_values, rows); + let p3_trace = plonky3_fibonacci::generate_fibonacci_trace(num_sequences, rows); + + assert_eq!(p3_trace.width, 2 * num_sequences); + for row in 0..rows { + for seq in 0..num_sequences { + let p3_left = p3_trace.values[row * p3_trace.width + 2 * seq].as_canonical_u64(); + let p3_right = + p3_trace.values[row * p3_trace.width + 2 * seq + 1].as_canonical_u64(); + + assert_eq!( + FE::from(p3_left), + lambda_trace.get_main(row, 2 * seq).clone(), + "left mismatch at row {row}, seq {seq}" + ); + assert_eq!( + FE::from(p3_right), + lambda_trace.get_main(row, 2 * seq + 1).clone(), + "right mismatch at row {row}, seq {seq}" + ); + } + } + } +} diff --git a/bench_vs_plonky3/src/plonky3_config.rs b/bench_vs_plonky3/src/plonky3_config.rs new file mode 100644 index 000000000..d0ead2657 --- /dev/null +++ b/bench_vs_plonky3/src/plonky3_config.rs @@ -0,0 +1,82 @@ +use p3_challenger::{HashChallenger, SerializingChallenger64}; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::CubicTrinomialExtensionField; +use p3_fri::{FriParameters, TwoAdicFriPcs}; +use p3_goldilocks::Goldilocks; +use p3_keccak::Keccak256Hash; +use p3_merkle_tree::MerkleTreeMmcs; +use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher}; +use p3_uni_stark::StarkConfig; + +pub type Val = Goldilocks; +pub type Challenge = CubicTrinomialExtensionField; + +// Scalar byte-oriented MMCS, deliberately not the Plonky3 production config. +// Leaves are individual field elements, digests are 32 raw bytes, and the +// underlying Keccak path is single-input tiny_keccak. This removes the +// `[Val; VECTOR_LEN]` / `[u64; VECTOR_LEN]` Keccak lanes that the +// vector-friendly upstream config uses (NEON=2, SSE2=2, AVX2=4, AVX-512=8), +// so the Merkle compression cost is one Keccak-f per call on both sides. +type ByteHash = Keccak256Hash; +type FieldHash = SerializingHasher; +type MyCompress = CompressionFunctionFromHasher; +pub type ValMmcs = MerkleTreeMmcs; +type ChallengeMmcs = ExtensionMmcs; +type Dft = Radix2DitParallel; +pub type Pcs = TwoAdicFriPcs; +pub type Challenger = SerializingChallenger64>; + +pub type P3Config = StarkConfig; + +/// Packing width of the MMCS leaves (`P` parameter of `MerkleTreeMmcs`). +/// `Val` directly = 1; `[Val; N]` would be `N`. Exposed for the AUDIT line. +pub const VAL_PACKING_WIDTH: usize = 1; + +/// Lanes of the underlying Keccak permutation as seen by the MMCS. +/// `Keccak256Hash` is single-input scalar; lane-vectorized `KeccakF` paths +/// would set this to 2/4/8 depending on arch. +pub const HASH_LANES: usize = 1; + +fn build_mmcs() -> (ValMmcs, ChallengeMmcs, ByteHash) { + let byte_hash = ByteHash {}; + let field_hash = FieldHash::new(byte_hash); + let compress = MyCompress::new(byte_hash); + let val_mmcs = ValMmcs::new(field_hash, compress, 3); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + (val_mmcs, challenge_mmcs, byte_hash) +} + +/// Creates a Plonky3 STARK config with parameters matched to Lambda's proof +/// options. `blowup` must be a power of two because Plonky3 stores it as +/// `log_blowup`. +pub fn params_config(blowup: u8, queries: usize, grinding: u8) -> P3Config { + assert!( + blowup.is_power_of_two(), + "blowup must be a power of two for Plonky3" + ); + + let (val_mmcs, challenge_mmcs, byte_hash) = build_mmcs(); + let dft = Dft::default(); + let challenger = Challenger::from_hasher(vec![], byte_hash); + + let fri_params = FriParameters { + log_blowup: blowup.trailing_zeros() as usize, + log_final_poly_len: 0, + max_log_arity: 1, + num_queries: queries, + commit_proof_of_work_bits: grinding as usize, + query_proof_of_work_bits: 0, + mmcs: challenge_mmcs, + }; + + let pcs = Pcs::new(dft, val_mmcs, fri_params); + P3Config::new(pcs, challenger) +} + +/// Creates a Plonky3 STARK config with parameters matched to Lambda's +/// production config `GoldilocksCubicProofOptions::with_blowup(2)`: +/// blowup=2, 219 FRI queries, grinding=0. +pub fn matched_params_config() -> P3Config { + params_config(2, 219, 0) +} diff --git a/bench_vs_plonky3/src/plonky3_fibonacci.rs b/bench_vs_plonky3/src/plonky3_fibonacci.rs new file mode 100644 index 000000000..b1f0816eb --- /dev/null +++ b/bench_vs_plonky3/src/plonky3_fibonacci.rs @@ -0,0 +1,149 @@ +use p3_air::{Air, AirBuilder, BaseAir, WindowAccess}; +use p3_field::PrimeCharacteristicRing; +use p3_goldilocks::Goldilocks; +use p3_matrix::dense::RowMajorMatrix; + +/// Multi-sequence Fibonacci AIR for Plonky3. +/// +/// Each sequence uses 2 columns (left, right) in a 2-row window, where each +/// Plonky3 row stores two consecutive Lambda rows: +/// local.left = x_{2i} +/// local.right = x_{2i+1} +/// next.left = x_{2i+2} = local.left + local.right +/// next.right = x_{2i+3} = local.right + next.left +/// +/// This packs two consecutive Lambda trace rows into one Plonky3 row. It is the +/// closest encoding of Lambda's `row + 2` Fibonacci transition available in +/// Plonky3's current/next-row AIR window while keeping the same committed cell +/// count. +/// +/// Boundary constraints at the first row pin each sequence's initial (a, b) +/// values against public inputs, matching Lambda's `FibonacciMultiColumnAIR`. +/// +/// Public values layout: `[a_0, b_0, a_1, b_1, ..., a_{N-1}, b_{N-1}]` +/// where `N = num_sequences`. +/// +/// For `num_sequences` sequences, the AIR has `2 * num_sequences` columns +/// and `2 * num_sequences` public values. +pub struct P3FibonacciAir { + pub num_sequences: usize, +} + +impl BaseAir for P3FibonacciAir { + fn width(&self) -> usize { + 2 * self.num_sequences + } + + fn num_public_values(&self) -> usize { + 2 * self.num_sequences + } +} + +/// One sequence's (local_left, local_right, next_left, next_right, a, b) +/// snapshot extracted from an `AirBuilder`. Factored out to keep the +/// `Air::eval` signature readable (clippy::type_complexity). +type FibPairRow = ( + ::Var, + ::Var, + ::Var, + ::Var, + ::PublicVar, + ::PublicVar, +); + +impl Air for P3FibonacciAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.current_slice(); + let next = main.next_slice(); + + // Collect (left, right, next_left, next_right, a, b) per sequence so that + // `pis`'s borrow on `builder` can end before we mutate `builder`. + let rows: Vec> = { + let pis = builder.public_values(); + (0..self.num_sequences) + .map(|seq| { + ( + local[2 * seq], + local[2 * seq + 1], + next[2 * seq], + next[2 * seq + 1], + pis[2 * seq], + pis[2 * seq + 1], + ) + }) + .collect() + }; + drop(main); + + for (left, right, next_left, next_right, a, b) in rows { + // Boundary: first row pins (left, right) = (a, b) + let mut when_first_row = builder.when_first_row(); + when_first_row.assert_eq(left, a); + when_first_row.assert_eq(right, b); + + let mut when_transition = builder.when_transition(); + // Advance two Lambda rows per Plonky3 row. + when_transition.assert_eq(next_left, left + right); + when_transition.assert_eq(next_right, right + next_left); + } + } +} + +/// Generates a Fibonacci trace for Plonky3. +/// +/// For `num_sequences` sequences and `num_rows` rows (must be power of 2), +/// produces a `RowMajorMatrix` with `2 * num_sequences` columns. +/// Use `rows_for_lambda_trace(lambda_trace_length)` when comparing against +/// Lambda's one-column-per-sequence trace. +/// +/// Each sequence `s` starts with initial values matching Lambda's +/// `create_initial_values()`: `left = s + 1`, `right = s + 2`. +pub fn generate_fibonacci_trace( + num_sequences: usize, + num_rows: usize, +) -> RowMajorMatrix { + assert!(num_rows.is_power_of_two(), "num_rows must be a power of 2"); + let width = 2 * num_sequences; + let mut values = vec![Goldilocks::ZERO; width * num_rows]; + + for seq in 0..num_sequences { + let mut left = Goldilocks::from_u64((seq + 1) as u64); + let mut right = Goldilocks::from_u64((seq + 2) as u64); + + for row in 0..num_rows { + values[row * width + 2 * seq] = left; + values[row * width + 2 * seq + 1] = right; + let next_left = left + right; + let next_right = right + next_left; + left = next_left; + right = next_right; + } + } + + RowMajorMatrix::new(values, width) +} + +/// Returns the number of packed Plonky3 rows for a Lambda trace length. +pub fn rows_for_lambda_trace(lambda_trace_length: usize) -> usize { + assert!( + lambda_trace_length >= 2, + "lambda_trace_length must contain at least two rows" + ); + assert!( + lambda_trace_length.is_power_of_two(), + "lambda_trace_length must be a power of 2" + ); + lambda_trace_length / 2 +} + +/// Builds public values matching `generate_fibonacci_trace`'s initial values: +/// `[a_0, b_0, a_1, b_1, ...] = [1, 2, 2, 3, 3, 4, ...]` +pub fn public_values(num_sequences: usize) -> Vec { + let mut pis = Vec::with_capacity(2 * num_sequences); + for seq in 0..num_sequences { + pis.push(Goldilocks::from_u64((seq + 1) as u64)); + pis.push(Goldilocks::from_u64((seq + 2) as u64)); + } + pis +} diff --git a/bench_vs_plonky3/src/span_timing.rs b/bench_vs_plonky3/src/span_timing.rs new file mode 100644 index 000000000..4d37423fb --- /dev/null +++ b/bench_vs_plonky3/src/span_timing.rs @@ -0,0 +1,83 @@ +//! Tracing layer that accumulates per-span wall-clock durations from +//! Plonky3's `tracing` instrumentation. Used by `prove_bench --breakdown` +//! and by the `instruments_breakdown` test. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use tracing_subscriber::Layer; + +pub type SpanResults = Arc>>; + +struct SpanState { + name: String, + active_since: Option, + accumulated: Duration, +} + +pub struct P3TimingLayer { + spans: Mutex>, + results: SpanResults, +} + +impl P3TimingLayer { + pub fn new() -> (Self, SpanResults) { + let results: SpanResults = Arc::new(Mutex::new(Vec::new())); + let layer = Self { + spans: Mutex::new(HashMap::new()), + results: Arc::clone(&results), + }; + (layer, results) + } +} + +impl Layer for P3TimingLayer +where + S: tracing::Subscriber + for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>, +{ + fn on_new_span( + &self, + attrs: &tracing::span::Attributes<'_>, + id: &tracing::span::Id, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + self.spans.lock().unwrap().insert( + id.into_u64(), + SpanState { + name: attrs.metadata().name().to_string(), + active_since: None, + accumulated: Duration::ZERO, + }, + ); + } + + fn on_enter(&self, id: &tracing::span::Id, _ctx: tracing_subscriber::layer::Context<'_, S>) { + if let Some(entry) = self.spans.lock().unwrap().get_mut(&id.into_u64()) + && entry.active_since.is_none() + { + entry.active_since = Some(Instant::now()); + } + } + + fn on_exit(&self, id: &tracing::span::Id, _ctx: tracing_subscriber::layer::Context<'_, S>) { + if let Some(entry) = self.spans.lock().unwrap().get_mut(&id.into_u64()) + && let Some(start) = entry.active_since.take() + { + entry.accumulated += start.elapsed(); + } + } + + fn on_close(&self, id: tracing::span::Id, _ctx: tracing_subscriber::layer::Context<'_, S>) { + if let Some(entry) = self.spans.lock().unwrap().remove(&id.into_u64()) { + let mut total = entry.accumulated; + if let Some(start) = entry.active_since { + total += start.elapsed(); + } + self.results + .lock() + .unwrap() + .push((entry.name, total.as_secs_f64() * 1000.0)); + } + } +}