diff --git a/.github/workflows/end-to-end.yml b/.github/workflows/end-to-end.yml index 774edeec4..71eed6bdd 100644 --- a/.github/workflows/end-to-end.yml +++ b/.github/workflows/end-to-end.yml @@ -102,38 +102,37 @@ jobs: ./benchmark-inputs/t_attest-proof.np - # Disabled gnark, check https://github.com/worldfnd/provekit/issues/302 - # - name: Run Gnark verifier - # working-directory: recursive-verifier - # run: | - # go build -o gnark-verifier cmd/cli/main.go - - # # Set up cleanup trap - # cleanup() { - # if [ ! -z "$MONITOR_PID" ]; then - # kill $MONITOR_PID 2>/dev/null || true - # fi - # } - # trap cleanup EXIT - - # # Start monitoring in background - # ( - # while true; do - # echo "=== $(date) ===" - # echo "Memory:" - # free -h - # echo "Disk:" - # df -h - # echo "Processes:" - # ps aux --sort=-%mem | head -5 - # echo "==================" - # sleep 10 # Check every 10 seconds - # done - # ) & - # MONITOR_PID=$! - - # # Run the main process - # ./gnark-verifier --config "../noir-examples/noir-passport/merkle_age_check/params_for_recursive_verifier" --r1cs "../noir-examples/noir-passport/merkle_age_check/r1cs.json" - - # # Stop monitoring - # kill $MONITOR_PID \ No newline at end of file + - name: Run Gnark verifier + working-directory: recursive-verifier + run: | + go build -o gnark-verifier cmd/cli/main.go + + # Set up cleanup trap + cleanup() { + if [ ! -z "$MONITOR_PID" ]; then + kill $MONITOR_PID 2>/dev/null || true + fi + } + trap cleanup EXIT + + # Start monitoring in background + ( + while true; do + echo "=== $(date) ===" + echo "Memory:" + free -h + echo "Disk:" + df -h + echo "Processes:" + ps aux --sort=-%mem | head -5 + echo "==================" + sleep 10 # Check every 10 seconds + done + ) & + MONITOR_PID=$! + + # Run the main process + ./gnark-verifier --config "../noir-examples/noir-passport/merkle_age_check/params_for_recursive_verifier" --r1cs "../noir-examples/noir-passport/merkle_age_check/r1cs.json" + + # Stop monitoring + kill $MONITOR_PID \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 3a9bc6ef7..1270271e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,7 +49,7 @@ dependencies = [ "acvm_blackbox_solver", "brillig_vm", "fxhash", - "indexmap 2.13.0", + "indexmap 2.13.1", "serde", "thiserror 1.0.69", "tracing", @@ -117,9 +117,9 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "alloy-rlp" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e93e50f64a77ad9c5470bf2ad0ca02f228da70c792a8f06634801e202579f35e" +checksum = "dc90b1e703d3c03f4ff7f48e82dd0bc1c8211ab7d079cd836a06fcfeb06651cb" dependencies = [ "arrayvec", "bytes", @@ -136,9 +136,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.21" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", "anstyle-parse", @@ -151,15 +151,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] @@ -198,9 +198,9 @@ checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" [[package]] name = "argh" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32c2462e89541e6687e684d97310015d64a0627b61106fc472156a38f61cd1e" +checksum = "211818e820cda9ca6f167a64a5c808837366a6dfd807157c64c1304c486cd033" dependencies = [ "argh_derive", "argh_shared", @@ -208,9 +208,9 @@ dependencies = [ [[package]] name = "argh_derive" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccc2a031b364bd099fed016feb1ccfca2c3549d63c16f330cfc40b27b7692231" +checksum = "c442a9d18cef5dde467405d27d461d080d68972d6d0dfd0408265b6749ec427d" dependencies = [ "argh_shared", "proc-macro2", @@ -220,9 +220,9 @@ dependencies = [ [[package]] name = "argh_shared" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9abea17ef74821d1d3490aee9e0749d731445d965b7512308b2aa00c90079e" +checksum = "e5ade012bac4db278517a0132c8c10c6427025868dca16c801087c28d5a411f1" dependencies = [ "serde", ] @@ -857,10 +857,11 @@ dependencies = [ [[package]] name = "borsh" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1da5ab77c1437701eeff7c88d968729e7766172279eab0676857b3d63af7a6f" +checksum = "cfd1e3f8955a5d7de9fab72fc8373fade9fb8a703968cb200ae3dc6cf08e185a" dependencies = [ + "bytes", "cfg_aliases", ] @@ -938,9 +939,9 @@ checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cc" -version = "1.2.56" +version = "1.2.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1" dependencies = [ "find-msvc-tools", "jobserver", @@ -1009,9 +1010,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -1019,9 +1020,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstream", "anstyle", @@ -1032,18 +1033,18 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.66" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c757a3b7e39161a4e56f9365141ada2a6c915a8622c408ab6bb4b5d047371031" +checksum = "19c9f1dde76b736e3681f28cec9d5a61299cbaae0fce80a68e43724ad56031eb" dependencies = [ "clap", ] [[package]] name = "clap_derive" -version = "4.5.55" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -1053,9 +1054,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" @@ -1138,9 +1139,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "combine" @@ -1349,9 +1350,9 @@ dependencies = [ [[package]] name = "crypto-common" -version = "0.1.7" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", "typenum", @@ -1391,9 +1392,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ "darling_core", "darling_macro", @@ -1401,11 +1402,10 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" dependencies = [ - "fnv", "ident_case", "proc-macro2", "quote", @@ -1415,9 +1415,9 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ "darling_core", "quote", @@ -1478,9 +1478,9 @@ dependencies = [ [[package]] name = "derive-where" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f" +checksum = "d08b3a0bcc0d079199cd476b2cae8435016ec11d1c0986c6901c5ac223041534" dependencies = [ "proc-macro2", "quote", @@ -1719,9 +1719,9 @@ dependencies = [ [[package]] name = "env_filter" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -1729,9 +1729,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "env_filter", "log", @@ -1848,7 +1848,7 @@ checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", - "libredox 0.1.14", + "libredox 0.1.15", ] [[package]] @@ -2074,9 +2074,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.7" +version = "0.14.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" dependencies = [ "typenum", "version_check", @@ -2153,7 +2153,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.13.0", + "indexmap 2.13.1", "slab", "tokio", "tokio-util", @@ -2311,9 +2311,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" dependencies = [ "atomic-waker", "bytes", @@ -2326,7 +2326,6 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "pin-utils", "smallvec", "tokio", "want", @@ -2416,12 +2415,13 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" dependencies = [ "displaydoc", "potential_utf", + "utf8_iter", "yoke", "zerofrom", "zerovec", @@ -2429,9 +2429,9 @@ dependencies = [ [[package]] name = "icu_locale_core" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", @@ -2442,9 +2442,9 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" dependencies = [ "icu_collections", "icu_normalizer_data", @@ -2456,15 +2456,15 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" [[package]] name = "icu_properties" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" dependencies = [ "icu_collections", "icu_locale_core", @@ -2476,15 +2476,15 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" [[package]] name = "icu_provider" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", @@ -2591,9 +2591,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.13.0" +version = "2.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +checksum = "45a8a2b9cb3e0b0c1803dbb0758ffac5de2f425b23c28f518faabd9d805342ff" dependencies = [ "equivalent", "hashbrown 0.16.1", @@ -2638,9 +2638,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" dependencies = [ "memchr", "serde", @@ -2697,9 +2697,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jni" @@ -2710,7 +2710,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -2719,9 +2719,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.117", +] [[package]] name = "jobserver" @@ -2918,9 +2940,9 @@ checksum = "82903360c009b816f5ab72a9b68158c27c301ee2c3f20655b55c5e589e7d3bb7" [[package]] name = "libc" -version = "0.2.182" +version = "0.2.184" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" [[package]] name = "libm" @@ -2941,9 +2963,9 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +checksum = "7ddbf48fd451246b1f8c2610bd3b4ac0cc6e149d89832867093ab69a17194f08" dependencies = [ "bitflags 2.11.0", "libc", @@ -2965,9 +2987,9 @@ checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" [[package]] name = "lock_api" @@ -3141,9 +3163,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.1" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" dependencies = [ "libc", "wasi", @@ -3715,9 +3737,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num-integer" @@ -3751,9 +3773,9 @@ dependencies = [ [[package]] name = "num_enum" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" dependencies = [ "num_enum_derive", "rustversion", @@ -3761,9 +3783,9 @@ dependencies = [ [[package]] name = "num_enum_derive" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -3797,9 +3819,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" @@ -3820,9 +3842,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.75" +version = "0.10.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" dependencies = [ "bitflags 2.11.0", "cfg-if", @@ -3852,9 +3874,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.111" +version = "0.9.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" dependencies = [ "cc", "libc", @@ -3864,9 +3886,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "5.1.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e" dependencies = [ "num-traits", "rand 0.8.5", @@ -4162,7 +4184,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.13.0", + "indexmap 2.13.1", ] [[package]] @@ -4173,7 +4195,7 @@ checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.13.0", + "indexmap 2.13.1", "serde", ] @@ -4203,12 +4225,6 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - [[package]] name = "pkcs1" version = "0.7.5" @@ -4266,9 +4282,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ "zerovec", ] @@ -4349,7 +4365,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit 0.25.4+spec-1.1.0", + "toml_edit 0.25.10+spec-1.1.0", ] [[package]] @@ -4363,9 +4379,9 @@ dependencies = [ [[package]] name = "proptest" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37566cb3fdacef14c0737f9546df7cfeadbfbc9fef10991038bf5015d0c80532" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" dependencies = [ "bit-set", "bit-vec", @@ -4396,7 +4412,7 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ - "heck 0.5.0", + "heck 0.4.1", "itertools 0.14.0", "log", "multimap", @@ -4614,6 +4630,7 @@ dependencies = [ "ark-ff 0.5.0", "ark-std 0.5.0", "bn254_blackbox_solver", + "hex", "mavros-artifacts", "mavros-vm", "nargo", @@ -4977,7 +4994,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom 0.2.17", - "libredox 0.1.14", + "libredox 0.1.15", "thiserror 1.0.69", ] @@ -5226,9 +5243,9 @@ checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d" [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustc-hex" @@ -5354,9 +5371,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "ring", "rustls-pki-types", @@ -5488,9 +5505,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ "windows-sys 0.61.2", ] @@ -5706,15 +5723,15 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "381b283ce7bc6b476d903296fb59d0d36633652b633b27f64db4fb46dcbfc3b9" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.13.0", + "indexmap 2.13.1", "schemars 0.9.0", "schemars 1.2.1", "serde_core", @@ -5725,9 +5742,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ "darling", "proc-macro2", @@ -5821,9 +5838,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "similar" @@ -5936,12 +5953,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6140,9 +6157,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.26.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", "getrandom 0.4.2", @@ -6173,12 +6190,12 @@ dependencies = [ [[package]] name = "terminal_size" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0" +checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix 1.1.4", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6339,9 +6356,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", "zerovec", @@ -6349,13 +6366,13 @@ dependencies = [ [[package]] name = "tokio" -version = "1.50.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +checksum = "2bd1c4c0fc4a7ab90fc15ef6daaa3ec3b893f004f915f2392557ed23237820cd" dependencies = [ "bytes", "libc", - "mio 1.1.1", + "mio 1.2.0", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -6366,9 +6383,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.6.1" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" dependencies = [ "proc-macro2", "quote", @@ -6444,9 +6461,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] @@ -6457,7 +6474,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.13.0", + "indexmap 2.13.1", "serde", "serde_spanned", "toml_datetime 0.6.11", @@ -6470,33 +6487,33 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.13.0", + "indexmap 2.13.1", "serde", "serde_spanned", "toml_datetime 0.6.11", "toml_write", - "winnow 0.7.14", + "winnow 0.7.15", ] [[package]] name = "toml_edit" -version = "0.25.4+spec-1.1.0" +version = "0.25.10+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +checksum = "a82418ca169e235e6c399a84e395ab6debeb3bc90edc959bf0f48647c6a32d1b" dependencies = [ - "indexmap 2.13.0", - "toml_datetime 1.0.0+spec-1.1.0", + "indexmap 2.13.1", + "toml_datetime 1.1.1+spec-1.1.0", "toml_parser", - "winnow 0.7.14", + "winnow 1.0.1", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow 0.7.14", + "winnow 1.0.1", ] [[package]] @@ -6631,9 +6648,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "matchers", "nu-ansi-term", @@ -6721,9 +6738,9 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "typewit" -version = "1.14.2" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" +checksum = "06fee3a8df48c50c55ad646a4e03b00a370da6fe1850ebf467a8d0165dfcafae" dependencies = [ "typewit_proc_macros", ] @@ -6784,9 +6801,9 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-width" @@ -6839,9 +6856,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.22.0" +version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" dependencies = [ "getrandom 0.4.2", ] @@ -7070,7 +7087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" dependencies = [ "anyhow", - "indexmap 2.13.0", + "indexmap 2.13.1", "wasm-encoder", "wasmparser", ] @@ -7094,7 +7111,7 @@ checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ "bitflags 2.11.0", "hashbrown 0.15.5", - "indexmap 2.13.0", + "indexmap 2.13.1", "semver 1.0.27", ] @@ -7564,9 +7581,18 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + +[[package]] +name = "winnow" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" dependencies = [ "memchr", ] @@ -7599,7 +7625,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" dependencies = [ "anyhow", "heck 0.5.0", - "indexmap 2.13.0", + "indexmap 2.13.1", "prettyplease", "syn 2.0.117", "wasm-metadata", @@ -7630,7 +7656,7 @@ checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" dependencies = [ "anyhow", "bitflags 2.11.0", - "indexmap 2.13.0", + "indexmap 2.13.1", "log", "serde", "serde_derive", @@ -7649,7 +7675,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ "anyhow", "id-arena", - "indexmap 2.13.0", + "indexmap 2.13.1", "log", "semver 1.0.27", "serde", @@ -7661,9 +7687,9 @@ dependencies = [ [[package]] name = "writeable" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "wyz" @@ -7702,9 +7728,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -7713,9 +7739,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", @@ -7725,18 +7751,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.40" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.40" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", @@ -7745,18 +7771,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", @@ -7786,9 +7812,9 @@ dependencies = [ [[package]] name = "zerotrie" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" dependencies = [ "displaydoc", "yoke", @@ -7797,9 +7823,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ "yoke", "zerofrom", @@ -7808,9 +7834,9 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index a6fcf7e4e..1ad285751 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,6 +108,7 @@ provekit-wasm = { path = "tooling/provekit-wasm" } # 3rd party anyhow = "1.0.93" +ciborium = "0.2.2" argh = "0.1.12" axum = "0.8.4" base64 = "0.22.1" @@ -190,4 +191,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "ark-ff", "sha2", ], rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish", rev = "fcc277f8a857fdeeadd7cca92ab08de63b1ff1a1" } -whir = { git ="https://github.com/WizardOfMenlo/whir/", rev="0aeaa7f337c743d9ddfcb9d909628d6491e3355c", features = ["tracing", "rs_in_order"] } +whir = { git = "https://github.com/WizardOfMenlo/whir/", rev = "0aeaa7f337c743d9ddfcb9d909628d6491e3355c", features = ["tracing", "rs_in_order"] } diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index eca2c2ae9..80d8b78ba 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -5,5 +5,5 @@ mod whir; pub use self::{ pow::SkyscraperPoW, sponge::SkyscraperSponge, - whir::{SkyscraperHashEngine, SKYSCRAPER}, + whir::{SkyscraperHashEngine, SKYSCRAPER, SKYSCRAPER_ENGINE_ID}, }; diff --git a/provekit/common/src/skyscraper/whir.rs b/provekit/common/src/skyscraper/whir.rs index 0095cc877..5c8c22232 100644 --- a/provekit/common/src/skyscraper/whir.rs +++ b/provekit/common/src/skyscraper/whir.rs @@ -11,13 +11,16 @@ use { }, }; -/// Pre-computed `EngineId` for the Skyscraper hash engine. +/// Raw 32-byte engine ID for the Skyscraper hash engine. /// -/// Derived as `SHA3-256("whir::hash" || "skyscraper")`. -pub const SKYSCRAPER: EngineId = EngineId::new([ +/// Derived as `SHA3-256("whir::hash" || "skyscraper")`. Use as protocol_id etc. +pub const SKYSCRAPER_ENGINE_ID: [u8; 32] = [ 0xa5, 0x0d, 0x5e, 0xe2, 0xa3, 0xfc, 0x52, 0xe9, 0x6f, 0x11, 0x10, 0x3c, 0xbb, 0x8a, 0x65, 0xa3, 0x77, 0xb5, 0x82, 0xb0, 0xb2, 0xdd, 0x42, 0x1c, 0x66, 0x19, 0x13, 0xe6, 0xa5, 0x63, 0xf8, 0xa1, -]); +]; + +/// Pre-computed `EngineId` for the Skyscraper hash engine. +pub const SKYSCRAPER: EngineId = EngineId::new(SKYSCRAPER_ENGINE_ID); // ============================================================================ // WHIR 2.0 HashEngine Implementation diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index 82f848326..d289f7c2e 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -25,6 +25,7 @@ noirc_abi.workspace = true ark-ff.workspace = true ark-std.workspace = true whir.workspace = true +hex.workspace = true # 3rd party anyhow.workspace = true diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index ec73218ef..c3bcef1a8 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,6 +7,7 @@ use { }, acir::native_types::{Witness, WitnessMap}, anyhow::{Context, Result}, + hex, provekit_common::{ utils::noir_to_native, FieldElement, NoirElement, NoirProof, NoirProver, Prover, PublicInputs, TranscriptSponge, @@ -138,6 +139,7 @@ impl Prove for NoirProver { .whir_for_witness .create_domain_separator() .instance(&instance); + let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); let mut witness: Vec> = vec![None; num_witnesses]; diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 9646ebdb8..76380c0ed 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -19,6 +19,7 @@ use { }, }; +#[derive(Debug)] pub struct DataFromSumcheckVerifier { r: Vec, alpha: Vec, @@ -79,7 +80,6 @@ impl WhirR1CSVerifier for WhirR1CSScheme { } else { (None, None) }; - let (transposed, sumcheck_result) = rayon::join( || transpose_r1cs_matrices(r1cs), || run_sumcheck_verifier(&mut arthur, self.m_0), diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 006924ff0..162a823ea 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -3,13 +3,14 @@ package circuit import ( "fmt" "log" + "math/big" "os" "path/filepath" "time" "reilabs/whir-verifier-circuit/app/common" - "reilabs/whir-verifier-circuit/app/typeConverters" "reilabs/whir-verifier-circuit/app/utilities" + "reilabs/whir-verifier-circuit/app/whir" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -18,386 +19,370 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/math/uints" + gnark_nimue "github.com/reilabs/gnark-nimue" ) -type Circuit struct { - // Inputs - WitnessLinearStatementEvaluations []frontend.Variable - HidingSpartanLinearStatementEvaluations []frontend.Variable - LogNumConstraints int - LogNumVariables int - LogANumTerms int - HidingSpartanFirstRound Merkle - HidingSpartanMerkle Merkle - WHIRParamsWitness WHIRParams - WHIRParamsHidingSpartan WHIRParams - NumChallenges int - W1Size int - - // Witness commitments (length 1 for single mode, N for batch mode) - WitnessFirstRounds []Merkle - WitnessClaimedEvaluations [][]frontend.Variable // [commitment_idx][eval_idx] - WitnessBlindingEvaluations [][]frontend.Variable - - // For public_f_sum and public_g_sum - PubWitnessEvaluations []frontend.Variable - - // Batch mode only: batched polynomial for rounds 1+ - WitnessMerkle Merkle +type NimueInit = gnark_nimue.NimueInit +type Circuit struct { + InitializationData NimueInit + LogNumConstraints int + + SessionID [32]uints.U8 `gnark:",public"` + Transcript []uints.U8 `gnark:",public"` + BlindingCommitmentWhirConfig WHIRParams + BlindedCommitmentWhirConfig WHIRParams + NumChallenges int + ChallengeOffsets []int + W1Size int + PublicInputs PublicInputs + + // Merkle proof data for WHIR commitment verification (commitment 1 / single). + BlindedMerkleData whir.WhirMerkleData + BlindingMerkleData whir.WhirMerkleData + // Merkle proof data for commitment 2 (dual mode only). + BlindedMerkleData2 whir.WhirMerkleData + BlindingMerkleData2 whir.WhirMerkleData + + // R1CS matrices as sparse cell lists. Used to compute the weight MLE + // evaluations for the FinalClaim binding check. MatrixA []MatrixCell MatrixB []MatrixCell MatrixC []MatrixCell +} - IO []byte - Transcript []uints.U8 `gnark:",public"` - PublicInputs PublicInputs +type Commitment struct { + RootHash frontend.Variable + InitialOODQueries []frontend.Variable + InitialOODAnswers [][]frontend.Variable } func (circuit *Circuit) Define(api frontend.API) error { - sc, arthur, uapi, err := initializeComponents(api, circuit) + sc, nimue, uapi, err := initializeComponents(api, circuit) if err != nil { return err } - // Parse first commitment (C1) - needed to consume transcript - rootHash1, batchingRandomness1, initialOODQueries1, initialOODAnswers1, err := parseBatchedCommitment(arthur, circuit.WHIRParamsWitness) + blindedCommitments, blindingCommitment, err := zkWHIRCommitmentParsing(api, nimue, circuit.BlindedCommitmentWhirConfig, circuit.BlindingCommitmentWhirConfig, 1) + // api.Println("blindedCommitments", blindedCommitments) + // api.Println("blindingCommitment", blindingCommitment) if err != nil { return err } + numPolynomials := 1 + isDualMode := circuit.NumChallenges > 0 - // Variables for second commitment (only used in dual mode) - var rootHash2, batchingRandomness2 frontend.Variable - var initialOODQueries2 []frontend.Variable - var initialOODAnswers2 [][]frontend.Variable - - if circuit.NumChallenges > 0 { - // Squeeze logup challenges + var blindedCommitments2 []Commitment + var blindingCommitment2 Commitment + if isDualMode { logupChallenges := make([]frontend.Variable, circuit.NumChallenges) - if err = arthur.FillChallengeScalars(logupChallenges); err != nil { + if err = nimue.FillChallengeScalars(logupChallenges); err != nil { return err } - // Parse second commitment (C2) - rootHash2, batchingRandomness2, initialOODQueries2, initialOODAnswers2, err = parseBatchedCommitment(arthur, circuit.WHIRParamsWitness) + blindedCommitments2, blindingCommitment2, err = zkWHIRCommitmentParsing(api, nimue, circuit.BlindedCommitmentWhirConfig, circuit.BlindingCommitmentWhirConfig, 1) + // api.Println("blindedCommitments2", blindedCommitments2) + // api.Println("blindingCommitment2", blindingCommitment2) if err != nil { return err } } - // Squeeze tRand for Spartan - tRand := make([]frontend.Variable, circuit.LogNumConstraints) - err = arthur.FillChallengeScalars(tRand) + tRand, alpha, fAtAlpha, blindingEval, err := runZKSumcheck(api, sc, uapi, circuit, nimue, frontend.Variable(0), circuit.LogNumConstraints, 4) if err != nil { return err } - // Run ZK sumcheck - spartanSumcheckRand, spartanSumcheckLastValue, err := runZKSumcheck(api, sc, uapi, circuit, arthur, frontend.Variable(0), circuit.LogNumConstraints, 4, circuit.WHIRParamsHidingSpartan) + // Public inputs hash check + x challenge + err = publicInputsHashCheck(api, sc, nimue, circuit.PublicInputs) if err != nil { return err } - // Read public inputs hash from transcript - publicInputsHashBuf := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { - return fmt.Errorf("failed to read public inputs hash: %w", err) + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(publicWeightsChallenge); err != nil { + return fmt.Errorf("failed to read public weights challenge: %w", err) } - expectedHash, err := hashPublicInputs(sc, circuit.PublicInputs) - if err != nil { - return fmt.Errorf("failed to compute public inputs hash: %w", err) + // Read evaluations from transcript (prover_message in Rust) + evals1 := make([]frontend.Variable, 3) + if err := nimue.FillNextScalars(evals1); err != nil { + return fmt.Errorf("failed to read evals_1 from transcript: %w", err) } + evals1Az := evals1[0] + evals1Bz := evals1[1] + evals1Cz := evals1[2] - api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) - - // Squeeze rand for public weights - publicWeightsChallenge := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { - return fmt.Errorf("failed to read public weights challenge: %w", err) + // In dual mode, evals_2 follows evals_1 in the transcript (before public_eval). + var evals2Az, evals2Bz, evals2Cz frontend.Variable + if isDualMode { + evals2 := make([]frontend.Variable, 3) + if err := nimue.FillNextScalars(evals2); err != nil { + return fmt.Errorf("failed to read evals_2 from transcript: %w", err) + } + evals2Az = evals2[0] + evals2Bz = evals2[1] + evals2Cz = evals2[2] } - // WHIR verification - var whirFoldingRandomness []frontend.Variable - var az, bz, cz frontend.Variable + hasPublicInputs := !circuit.PublicInputs.IsEmpty() - if circuit.NumChallenges > 0 { - // Only statement_1 (first commitment) gets extended with public weights, statement_2 remains unchanged - extendedLinearStatementEvalsBatch := make([][][]frontend.Variable, 2) + var publicEval frontend.Variable + if hasPublicInputs { + publicEvalSlice := make([]frontend.Variable, 1) + if err := nimue.FillNextScalars(publicEvalSlice); err != nil { + return fmt.Errorf("failed to read public_eval from transcript: %w", err) + } + publicEval = publicEvalSlice[0] + + // Verify public input binding (Rust: verify_public_input_binding). + // expected = 1 + x*pi[0] + x²*pi[1] + ... + // where x = publicWeightsChallenge and position 0 is the constant 1. + expectedPublicEval := frontend.Variable(1) + xPow := publicWeightsChallenge[0] + for _, pi := range circuit.PublicInputs.Values { + expectedPublicEval = api.Add(expectedPublicEval, api.Mul(xPow, pi)) + xPow = api.Mul(xPow, publicWeightsChallenge[0]) + } + api.AssertIsEqual(publicEval, expectedPublicEval) + } - if !circuit.PublicInputs.IsEmpty() { - extendedLinearStatementEvalsBatch[0] = extendLinearStatement( - circuit, - [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, - circuit.PubWitnessEvaluations, - ) + // Challenge binding: in dual mode, read challenge_eval from transcript. + var challengeEval frontend.Variable + if isDualMode { + ceSlice := make([]frontend.Variable, 1) + if err := nimue.FillNextScalars(ceSlice); err != nil { + return fmt.Errorf("failed to read challenge_eval from transcript: %w", err) + } + challengeEval = ceSlice[0] + } - extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ - circuit.WitnessClaimedEvaluations[1], - circuit.WitnessBlindingEvaluations[1], - } - } else { - // Use original arrays as before, no public inputs - extendedLinearStatementEvalsBatch[0] = [][]frontend.Variable{ - circuit.WitnessClaimedEvaluations[0], - circuit.WitnessBlindingEvaluations[0], - } - extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ - circuit.WitnessClaimedEvaluations[1], - circuit.WitnessBlindingEvaluations[1], - } + var whirEvaluations []frontend.Variable + if hasPublicInputs { + whirEvaluations = []frontend.Variable{publicEval, evals1Az, evals1Bz, evals1Cz, blindingEval} + } else { + whirEvaluations = []frontend.Variable{evals1Az, evals1Bz, evals1Cz, blindingEval} + } + + weightsLen := 4 + if hasPublicInputs { + weightsLen = 5 + } + + blindedCommitmentNimue := ParsedCommitmentNimue{ + Root: blindedCommitments[0].RootHash, + OodPoints: blindedCommitments[0].InitialOODQueries, + OodAnswers: flattenOODAnswers(blindedCommitments[0].InitialOODAnswers), + } + blindingCommitmentNimue := ParsedCommitmentNimue{ + Root: blindingCommitment.RootHash, + OodPoints: blindingCommitment.InitialOODQueries, + OodAnswers: flattenOODAnswers(blindingCommitment.InitialOODAnswers), + } + + mode := SingleCommitment + if isDualMode { + mode = DualCommitment1 + } + + err = ZKWhirVerify( + api, sc, nimue, + blindedCommitmentNimue, + blindingCommitmentNimue, + circuit.BlindedCommitmentWhirConfig, + circuit.BlindingCommitmentWhirConfig, + whirEvaluations, + weightsLen, + numPolynomials, + &circuit.BlindedMerkleData, + &circuit.BlindingMerkleData, + R1CSWeightParams{ + Circuit: circuit, + Alpha: alpha, + PublicWeightsChallenge: publicWeightsChallenge[0], + HasPublicInputs: hasPublicInputs, + Mode: mode, + }, + ) + if err != nil { + return fmt.Errorf("ZK-WHIR verification failed for commitment 1: %w", err) + } + + var azAtAlpha, bzAtAlpha, czAtAlpha frontend.Variable + if isDualMode { + // Commitment 2 has 3 base weights (A,B,C) + 1 challenge weight if challenge_eval exists. + whirEvaluations2 := []frontend.Variable{evals2Az, evals2Bz, evals2Cz} + weightsLen2 := 3 + if circuit.NumChallenges > 0 { + whirEvaluations2 = append(whirEvaluations2, challengeEval) + weightsLen2 = 4 } - whirFoldingRandomness, err = RunZKWhirBatch( - api, arthur, uapi, sc, - circuit.WitnessFirstRounds, // firstRounds []Merkle - []frontend.Variable{batchingRandomness1, batchingRandomness2}, // batchingRandomnesses - [][]frontend.Variable{initialOODQueries1, initialOODQueries2}, // initialOODQueries - [][][]frontend.Variable{initialOODAnswers1, initialOODAnswers2}, // initialOODAnswers - []frontend.Variable{rootHash1, rootHash2}, // rootHashes - circuit.WitnessMerkle, // batchedMerkle - extendedLinearStatementEvalsBatch, // linearStatementEvals (extended for first commitment) - circuit.WHIRParamsWitness, // whirParams - circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints - circuit.PublicInputs, // publicInputs - ) - if err != nil { - return err + blindedCommitmentNimue2 := ParsedCommitmentNimue{ + Root: blindedCommitments2[0].RootHash, + OodPoints: blindedCommitments2[0].InitialOODQueries, + OodAnswers: flattenOODAnswers(blindedCommitments2[0].InitialOODAnswers), + } + blindingCommitmentNimue2 := ParsedCommitmentNimue{ + Root: blindingCommitment2.RootHash, + OodPoints: blindingCommitment2.InitialOODQueries, + OodAnswers: flattenOODAnswers(blindingCommitment2.InitialOODAnswers), } - // Sum evaluations from both commitments - az = api.Add(circuit.WitnessClaimedEvaluations[0][0], circuit.WitnessClaimedEvaluations[1][0]) - bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) - cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) - } else { - extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) - - // Single commitment mode - whirFoldingRandomness, err = RunZKWhir( - api, arthur, uapi, sc, - circuit.WitnessMerkle, circuit.WitnessFirstRounds[0], - circuit.WHIRParamsWitness, - extendedLinearStatementEvals, - circuit.WitnessLinearStatementEvaluations, - batchingRandomness1, - initialOODQueries1, - initialOODAnswers1, - rootHash1, + err = ZKWhirVerify( + api, sc, nimue, + blindedCommitmentNimue2, + blindingCommitmentNimue2, + circuit.BlindedCommitmentWhirConfig, + circuit.BlindingCommitmentWhirConfig, + whirEvaluations2, + weightsLen2, + numPolynomials, + &circuit.BlindedMerkleData2, + &circuit.BlindingMerkleData2, + R1CSWeightParams{ + Circuit: circuit, + Alpha: alpha, + PublicWeightsChallenge: publicWeightsChallenge[0], + HasPublicInputs: false, + ChallengeOffsets: circuit.ChallengeOffsets, + Mode: DualCommitment2, + }, ) if err != nil { - return err + return fmt.Errorf("ZK-WHIR verification failed for commitment 2: %w", err) } - az = circuit.WitnessClaimedEvaluations[0][0] - bz = circuit.WitnessClaimedEvaluations[0][1] - cz = circuit.WitnessClaimedEvaluations[0][2] + // az_at_alpha = evals_1 + evals_2 (Rust verifier sums the two) + azAtAlpha = api.Add(evals1Az, evals2Az) + bzAtAlpha = api.Add(evals1Bz, evals2Bz) + czAtAlpha = api.Add(evals1Cz, evals2Cz) + } else { + azAtAlpha = evals1Az + bzAtAlpha = evals1Bz + czAtAlpha = evals1Cz } - // Spartan sumcheck relation check (common to both modes) - x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) - api.AssertIsEqual(spartanSumcheckLastValue, x) + eqRA := calculateEqCircuit(api, tRand, alpha) + rhs := api.Mul(api.Sub(api.Mul(azAtAlpha, bzAtAlpha), czAtAlpha), eqRA) + api.AssertIsEqual(fAtAlpha, rhs) - offset := 0 - if !circuit.PublicInputs.IsEmpty() { - // can be generalized later on if we have more different kinds of statements - offset = 1 + if remaining := nimue.RemainingTranscriptLen(); remaining != 0 { + return fmt.Errorf("transcript not fully consumed: %d bytes remaining", remaining) } - if circuit.NumChallenges > 0 { - // Batch mode - check 6 deferred values - matrixExtensionEvals := evaluateR1CSMatrixExtensionBatch(api, circuit, spartanSumcheckRand, whirFoldingRandomness, circuit.W1Size) - for i := 0; i < 6; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) - } - } else { - - // Single mode - existing logic - matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) - for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) - } - } - - // Geometric weights for public inputs - if !circuit.PublicInputs.IsEmpty() { - publicWeightEval := computePublicWeightEvaluation( - api, circuit.PublicInputs, whirFoldingRandomness, publicWeightsChallenge[0], - ) + return nil +} - api.AssertIsEqual(publicWeightEval, circuit.WitnessLinearStatementEvaluations[0]) +// flattenOODAnswers converts [][]frontend.Variable (each inner slice is a +// single-element answer) into a flat []frontend.Variable. +func flattenOODAnswers(answers [][]frontend.Variable) []frontend.Variable { + var flat []frontend.Variable + for _, ans := range answers { + flat = append(flat, ans...) } + return flat +} - return nil +// configToNimueInit returns (circuit placeholder, assignment) for NimueInit. +// Circuit placeholder has all fields zeroed. Assignment is filled from cfg: +// - ProtocolID[0]: little-endian field element from cfg.ProtocolID bytes 0..31 +// - ProtocolID[1]: little-endian field element from cfg.ProtocolID bytes 32..63 +// - SessionID: little-endian field element from cfg.SessionID bytes 0..31 +// +// InstanceID is computed in-circuit from PublicInputs (see initializeComponents). +func configToNimueInit(cfg Config) (circuit, assign NimueInit) { + var pid [64]byte + copy(pid[:], cfg.ProtocolID) + var sid [32]byte + copy(sid[:], cfg.SessionID) + + // Compute instance = public_inputs.hash_bytes() for the assignment. + // The circuit recomputes this in-circuit (see initializeComponents), but + // gnark requires all witness fields to have concrete values. + piValues := make([]*big.Int, len(cfg.PublicInputs.Values)) + for i, v := range cfg.PublicInputs.Values { + piValues[i] = v.(*big.Int) + } + instance := nativePublicInputsHashBytes(piValues) + + assign = NimueInit{ + ProtocolID: [2]frontend.Variable{ + leBytesToNativeBigInt(pid[:32]), + leBytesToNativeBigInt(pid[32:]), + }, + SessionID: leBytesToNativeBigInt(sid[:]), + InstanceID: leBytesToNativeBigInt(instance[:]), + } + return circuit, assign } -func computePublicWeightEvaluation( - api frontend.API, - publicInputs PublicInputs, - foldingRandomness []frontend.Variable, - x frontend.Variable, -) frontend.Variable { - return geometricTill(api, x, len(publicInputs.Values), foldingRandomness) +// DualCommitmentData holds the additional data needed for dual-commitment mode. +type DualCommitmentData struct { + Evals2BigInt []*big.Int + BlindedMerkleData whir.WhirMerkleData + BlindingMerkleData whir.WhirMerkleData } +// verifyCircuit builds the gnark circuit and runs Groth16 proving + verification. func verifyCircuit( - deferred []Fp256, cfg Config, - hints Hints, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, - claimedEvaluations ClaimedEvaluations, - claimedEvaluations2 ClaimedEvaluations, - publicWeightsClaimedEvaluation [2]Fp256, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, publicInputs PublicInputs, + blindedMerkleData whir.WhirMerkleData, + blindingMerkleData whir.WhirMerkleData, + dualData *DualCommitmentData, // nil for single-commitment mode ) error { - transcriptT := make([]uints.U8, cfg.TranscriptLen) - contTranscript := make([]uints.U8, cfg.TranscriptLen) + transcriptT := make([]uints.U8, len(cfg.NargString)) + contTranscript := make([]uints.U8, len(cfg.NargString)) - for i := range cfg.Transcript { - transcriptT[i] = uints.NewU8(cfg.Transcript[i]) + for i := range cfg.NargString { + transcriptT[i] = uints.NewU8(cfg.NargString[i]) } - // Determine witness linear statement evals size based on mode - var witnessLinearStatementEvalsSize int - if cfg.NumChallenges > 0 { - if !cfg.PublicInputs.IsEmpty() { - // 3 per commitment in batch mode + 1 public_input (geometric statement as a subset of linear statement) - witnessLinearStatementEvalsSize = 7 - } else { - witnessLinearStatementEvalsSize = 6 - } - } else { - if !cfg.PublicInputs.IsEmpty() { - witnessLinearStatementEvalsSize = 4 - } else { - witnessLinearStatementEvalsSize = 3 - } - } - - witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) - hidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) - contWitnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) - contHidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) - - if len(deferred) < 1+witnessLinearStatementEvalsSize { - return fmt.Errorf("deferred array too short: expected at least %d elements, got %d", 1+witnessLinearStatementEvalsSize, len(deferred)) - } - hidingSpartanLinearStatementEvaluations[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) - for i := 0; i < witnessLinearStatementEvalsSize; i++ { - witnessLinearStatementEvaluations[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) - } + nimueInitCircuit, nimueInitAssign := configToNimueInit(cfg) - colIndicesA := internedR1CS.A.DecodeColIndices() - if colIndicesA == nil { - return fmt.Errorf("failed to decode column indices for matrix A: inconsistent data") - } - matrixA := make([]MatrixCell, len(internedR1CS.A.Values)) - for i := range len(internedR1CS.A.RowIndices) { - end := len(internedR1CS.A.Values) - 1 - if i < len(internedR1CS.A.RowIndices)-1 { - end = int(internedR1CS.A.RowIndices[i+1] - 1) - } - for j := int(internedR1CS.A.RowIndices[i]); j <= end; j++ { - matrixA[j] = MatrixCell{ - row: i, - column: int(colIndicesA[j]), - value: typeConverters.LimbsToBigIntMod(interner.Values[internedR1CS.A.Values[j]].Limbs), - } - } - } - - colIndicesB := internedR1CS.B.DecodeColIndices() - if colIndicesB == nil { - return fmt.Errorf("failed to decode column indices for matrix B: inconsistent data") - } - matrixB := make([]MatrixCell, len(internedR1CS.B.Values)) - for i := range len(internedR1CS.B.RowIndices) { - end := len(internedR1CS.B.Values) - 1 - if i < len(internedR1CS.B.RowIndices)-1 { - end = int(internedR1CS.B.RowIndices[i+1] - 1) - } - for j := int(internedR1CS.B.RowIndices[i]); j <= end; j++ { - matrixB[j] = MatrixCell{ - row: i, - column: int(colIndicesB[j]), - value: typeConverters.LimbsToBigIntMod(interner.Values[internedR1CS.B.Values[j]].Limbs), - } - } - } - - colIndicesC := internedR1CS.C.DecodeColIndices() - if colIndicesC == nil { - return fmt.Errorf("failed to decode column indices for matrix C: inconsistent data") - } - matrixC := make([]MatrixCell, len(internedR1CS.C.Values)) - for i := range len(internedR1CS.C.RowIndices) { - end := len(internedR1CS.C.Values) - 1 - if i < len(internedR1CS.C.RowIndices)-1 { - end = int(internedR1CS.C.RowIndices[i+1] - 1) - } - for j := int(internedR1CS.C.RowIndices[i]); j <= end; j++ { - matrixC[j] = MatrixCell{ - row: i, - column: int(colIndicesC[j]), - value: typeConverters.LimbsToBigIntMod(interner.Values[internedR1CS.C.Values[j]].Limbs), - } - } + matrixA, matrixB, matrixC, err := buildR1CSMatrixCells(internedR1CS, interner) + if err != nil { + return err } - // Parse claimed evaluations for first commitment - fSums, gSums := parseClaimedEvaluations(claimedEvaluations, true) - - // Parse claimed evaluations for second commitment (if dual mode) - var fSums2, gSums2 []frontend.Variable - if cfg.NumChallenges > 0 { - fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(publicInputs.Values)), } - // Parse public weights claimed evaluation - fSumPublicWeights, gSumPublicWeights := parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, true) - pubWitnessEvaluations := []frontend.Variable{fSumPublicWeights, gSumPublicWeights} - - // Build witness slices conditionally - var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable - if cfg.NumChallenges > 0 { - witnessClaimedEvals = [][]frontend.Variable{fSums, fSums2} - witnessBlindingEvals = [][]frontend.Variable{gSums, gSums2} - } else { - witnessClaimedEvals = [][]frontend.Variable{fSums} - witnessBlindingEvals = [][]frontend.Variable{gSums} - } + // Circuit template: placeholder (zero-valued) fields for compilation. + blindedMerkleTemplate := allocateZeroWhirMerkleData(blindedMerkleData) + blindingMerkleTemplate := allocateZeroWhirMerkleData(blindingMerkleData) - // Empty container while circuit creation - publicInputsContainer := PublicInputs{ - Values: make([]frontend.Variable, len(publicInputs.Values)), + // Dual-commitment templates + var blindedMerkleTemplate2, blindingMerkleTemplate2 whir.WhirMerkleData + if dualData != nil { + blindedMerkleTemplate2 = allocateZeroWhirMerkleData(dualData.BlindedMerkleData) + blindingMerkleTemplate2 = allocateZeroWhirMerkleData(dualData.BlindingMerkleData) } circuit := Circuit{ - IO: []byte(cfg.IOPattern), - Transcript: contTranscript, - LogNumConstraints: cfg.LogNumConstraints, - LogNumVariables: cfg.LogNumVariables, - LogANumTerms: cfg.LogANumTerms, - WitnessClaimedEvaluations: witnessClaimedEvals, - WitnessBlindingEvaluations: witnessBlindingEvals, - PubWitnessEvaluations: pubWitnessEvaluations, - WitnessLinearStatementEvaluations: contWitnessLinearStatementEvaluations, - HidingSpartanLinearStatementEvaluations: contHidingSpartanLinearStatementEvaluations, - HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), - HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, true), - WitnessFirstRounds: witnessFirstRounds(hints, true), - WitnessMerkle: newMerkle(hints.WitnessRoundHints.roundHints, true), - NumChallenges: cfg.NumChallenges, - W1Size: cfg.W1Size, - WHIRParamsWitness: NewWhirParams(cfg.WHIRConfigWitness), - WHIRParamsHidingSpartan: NewWhirParams(cfg.WHIRConfigHidingSpartan), - MatrixA: matrixA, - MatrixB: matrixB, - MatrixC: matrixC, - PublicInputs: publicInputsContainer, + InitializationData: nimueInitCircuit, + Transcript: contTranscript, + LogNumConstraints: cfg.LogNumConstraints, + NumChallenges: cfg.NumChallenges, + ChallengeOffsets: cfg.ChallengeOffsets, + W1Size: cfg.W1Size, + BlindingCommitmentWhirConfig: NewWhirParams(cfg.BlindingCommitmentWhirConfig), + BlindedCommitmentWhirConfig: NewWhirParams(cfg.BlindedCommitmentWhirConfig), + PublicInputs: publicInputsContainer, + BlindedMerkleData: blindedMerkleTemplate, + BlindingMerkleData: blindingMerkleTemplate, + BlindedMerkleData2: blindedMerkleTemplate2, + BlindingMerkleData2: blindingMerkleTemplate2, + MatrixA: matrixA, + MatrixB: matrixB, + MatrixC: matrixC, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -427,15 +412,12 @@ func verifyCircuit( vk = &unsafeVk if buildOps.ShouldSaveKeys() { - // Create the save keys directory if it doesn't exist if err := os.MkdirAll(buildOps.SaveKeys, 0o755); err != nil { log.Printf("Failed to create save keys directory %s: %v", buildOps.SaveKeys, err) } - // Generate timestamp for filenames timestamp := time.Now().Format("02Jan_15-04-05") - // Save proving key to file pkFilename := filepath.Join(buildOps.SaveKeys, fmt.Sprintf("pk_%s.bin", timestamp)) pkFile, err := os.Create(pkFilename) if err != nil { @@ -446,15 +428,13 @@ func verifyCircuit( log.Printf("Failed to close PK file: %v", err) } }() - _, err = (*pk).WriteTo(pkFile) // Dereference with (*pk) - if err != nil { + if _, err = (*pk).WriteTo(pkFile); err != nil { log.Printf("Failed to write PK to file: %v", err) } else { log.Printf("Proving key saved to %s", pkFilename) } } - // Save verifying key to file vkFilename := filepath.Join(buildOps.SaveKeys, fmt.Sprintf("vk_%s.bin", timestamp)) vkFile, err := os.Create(vkFilename) if err != nil { @@ -465,8 +445,7 @@ func verifyCircuit( log.Printf("Failed to close VK file: %v", err) } }() - _, err = (*vk).WriteTo(vkFile) // Dereference with (*vk) - if err != nil { + if _, err = (*vk).WriteTo(vkFile); err != nil { log.Printf("Failed to write VK to file: %v", err) } else { log.Printf("Verifying key saved to %s", vkFilename) @@ -475,46 +454,37 @@ func verifyCircuit( } } - // Parse actual values for assignment - fSums, gSums = parseClaimedEvaluations(claimedEvaluations, false) - if cfg.NumChallenges > 0 { - fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, false) - witnessClaimedEvals = [][]frontend.Variable{fSums, fSums2} - witnessBlindingEvals = [][]frontend.Variable{gSums, gSums2} - } else { - witnessClaimedEvals = [][]frontend.Variable{fSums} - witnessBlindingEvals = [][]frontend.Variable{gSums} + // Build dual-commitment assignment data + var blindedMerkleAssign2, blindingMerkleAssign2 whir.WhirMerkleData + if dualData != nil { + blindedMerkleAssign2 = dualData.BlindedMerkleData + blindingMerkleAssign2 = dualData.BlindingMerkleData } - fSumPublicWeights, gSumPublicWeights = parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, false) - pubWitnessEvaluations = []frontend.Variable{fSumPublicWeights, gSumPublicWeights} - assignment := Circuit{ - IO: []byte(cfg.IOPattern), - Transcript: transcriptT, - LogNumConstraints: cfg.LogNumConstraints, - LogNumVariables: cfg.LogNumVariables, - LogANumTerms: cfg.LogANumTerms, - WitnessClaimedEvaluations: witnessClaimedEvals, - WitnessBlindingEvaluations: witnessBlindingEvals, - WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, - PubWitnessEvaluations: pubWitnessEvaluations, - HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, - HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), - HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), - WitnessFirstRounds: witnessFirstRounds(hints, false), - WitnessMerkle: newMerkle(hints.WitnessRoundHints.roundHints, false), - NumChallenges: cfg.NumChallenges, - W1Size: cfg.W1Size, - WHIRParamsWitness: NewWhirParams(cfg.WHIRConfigWitness), - WHIRParamsHidingSpartan: NewWhirParams(cfg.WHIRConfigHidingSpartan), - MatrixA: matrixA, - MatrixB: matrixB, - MatrixC: matrixC, - PublicInputs: publicInputs, - } - - witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + InitializationData: nimueInitAssign, + Transcript: transcriptT, + LogNumConstraints: cfg.LogNumConstraints, + NumChallenges: cfg.NumChallenges, + ChallengeOffsets: cfg.ChallengeOffsets, + W1Size: cfg.W1Size, + BlindingCommitmentWhirConfig: NewWhirParams(cfg.BlindingCommitmentWhirConfig), + BlindedCommitmentWhirConfig: NewWhirParams(cfg.BlindedCommitmentWhirConfig), + PublicInputs: publicInputs, + BlindedMerkleData: blindedMerkleData, + BlindingMerkleData: blindingMerkleData, + BlindedMerkleData2: blindedMerkleAssign2, + BlindingMerkleData2: blindingMerkleAssign2, + MatrixA: matrixA, + MatrixB: matrixB, + MatrixC: matrixC, + } + + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + log.Printf("Failed to create witness: %v", err) + return err + } publicWitness, err := witness.Public() if err != nil { log.Printf("Failed witness, Public(): %v", err) @@ -526,7 +496,11 @@ func verifyCircuit( backend.WithIcicleAcceleration(), } - proof, _ := groth16.Prove(ccs, *pk, witness, opts...) + proof, err := groth16.Prove(ccs, *pk, witness, opts...) + if err != nil { + log.Printf("Failed to prove: %v", err) + return err + } err = groth16.Verify(proof, *vk, publicWitness) if err != nil { log.Printf("Failed to verify proof: %v", err) @@ -535,63 +509,30 @@ func verifyCircuit( return nil } -func parseClaimedEvaluations(claimedEvaluations ClaimedEvaluations, isContainer bool) ([]frontend.Variable, []frontend.Variable) { - fSums := make([]frontend.Variable, len(claimedEvaluations.FSums)) - gSums := make([]frontend.Variable, len(claimedEvaluations.GSums)) - - if !isContainer { - for i := range claimedEvaluations.FSums { - fSums[i] = typeConverters.LimbsToBigIntMod(claimedEvaluations.FSums[i].Limbs) - gSums[i] = typeConverters.LimbsToBigIntMod(claimedEvaluations.GSums[i].Limbs) +// allocateZeroWhirMerkleData creates a zero-valued copy of a WhirMerkleData +// with the same shape. Used as the circuit template for gnark compilation; +// the actual values go in the assignment only. +func allocateZeroWhirMerkleData(src whir.WhirMerkleData) whir.WhirMerkleData { + dst := whir.WhirMerkleData{ + Rounds: make([]whir.RoundMerkleEntry, len(src.Rounds)), + } + for r, rd := range src.Rounds { + nq := len(rd.Leaves) + entry := whir.RoundMerkleEntry{ + Leaves: make([][]frontend.Variable, nq), + SiblingHashes: make([]frontend.Variable, nq), + AuthPaths: make([][]frontend.Variable, nq), + LeafIndexes: make([]frontend.Variable, nq), } + for q := range nq { + if len(rd.Leaves[q]) > 0 { + entry.Leaves[q] = make([]frontend.Variable, len(rd.Leaves[q])) + } + if len(rd.AuthPaths[q]) > 0 { + entry.AuthPaths[q] = make([]frontend.Variable, len(rd.AuthPaths[q])) + } + } + dst.Rounds[r] = entry } - - return fSums, gSums -} - -func witnessFirstRounds(hints Hints, isContainer bool) []Merkle { - result := make([]Merkle, len(hints.WitnessFirstRoundHints)) - for i, hint := range hints.WitnessFirstRoundHints { - result[i] = newMerkle(hint.path, isContainer) - } - return result -} - -func parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation [2]Fp256, isContainer bool) (frontend.Variable, frontend.Variable) { - var fSumPublicWeights, gSumPublicWeights frontend.Variable - - if !isContainer { - fSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[0].Limbs) - gSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[1].Limbs) - } - - return fSumPublicWeights, gSumPublicWeights -} - -func extendLinearStatement( - circuit *Circuit, - linearStatementEvaluations [][]frontend.Variable, - pubWitnessEvaluations []frontend.Variable, -) [][]frontend.Variable { - var extendedLinearStatementEvals [][]frontend.Variable - - if !circuit.PublicInputs.IsEmpty() { - // Extend the statement equivalent array by prepending the public constraint (public constraint is added in starting at prover side) - extendedLinearStatementEvals = make([][]frontend.Variable, 2) - - // f_sums: [public_f_sum, f_sums[0], f_sums[1]... ] - extendedLinearStatementEvals[0] = make([]frontend.Variable, len(linearStatementEvaluations[0])+1) - extendedLinearStatementEvals[0][0] = pubWitnessEvaluations[0] - copy(extendedLinearStatementEvals[0][1:], linearStatementEvaluations[0]) - - // g_sums: [public_g_sum, g_sums[0], g_sums[1]... ] - extendedLinearStatementEvals[1] = make([]frontend.Variable, len(linearStatementEvaluations[1])+1) - extendedLinearStatementEvals[1][0] = pubWitnessEvaluations[1] - copy(extendedLinearStatementEvals[1][1:], linearStatementEvaluations[1]) - } else { - // No public inputs, use original arrays - extendedLinearStatementEvals = linearStatementEvaluations - } - - return extendedLinearStatementEvals + return dst } diff --git a/recursive-verifier/app/circuit/circuit_test.go b/recursive-verifier/app/circuit/circuit_test.go index 2c62803a7..2d02d34ae 100644 --- a/recursive-verifier/app/circuit/circuit_test.go +++ b/recursive-verifier/app/circuit/circuit_test.go @@ -1,16 +1,14 @@ package circuit import ( - "bytes" - "encoding/binary" - "encoding/hex" "encoding/json" + "fmt" + "math/big" "os" - "strings" "testing" - "reilabs/whir-verifier-circuit/app/typeConverters" "reilabs/whir-verifier-circuit/app/utilities" + "reilabs/whir-verifier-circuit/app/whir" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -19,15 +17,12 @@ import ( "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/math/uints" "github.com/consensys/gnark/test" - gnarkNimue "github.com/reilabs/gnark-nimue" - arkSerialize "github.com/reilabs/go-ark-serialize" ) // TestCircuitConstraints checks that the circuit constraints are satisfied // without generating/verifying a full Groth16 proof. // This is much faster for testing purposes. func TestCircuitConstraints(t *testing.T) { - // Skip if test fixtures don't exist configPath := os.Getenv("TEST_CONFIG_PATH") r1csPath := os.Getenv("TEST_R1CS_PATH") @@ -35,35 +30,13 @@ func TestCircuitConstraints(t *testing.T) { t.Skip("Skipping test: TEST_CONFIG_PATH and TEST_R1CS_PATH env vars not set") } - // Load config - configFile, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read config file: %v", err) - } - - var config Config - if err := json.Unmarshal(configFile, &config); err != nil { - t.Fatalf("Failed to unmarshal config JSON: %v", err) - } - - // Load R1CS - r1csFile, err := os.ReadFile(r1csPath) - if err != nil { - t.Fatalf("Failed to read r1cs file: %v", err) - } - - var r1csData R1CS - if err := json.Unmarshal(r1csFile, &r1csData); err != nil { - t.Fatalf("Failed to unmarshal r1cs JSON: %v", err) - } + config, r1csData := loadTestData(t, configPath, r1csPath) - // Build circuit and assignment circuit, assignment, err := buildCircuitAndAssignment(config, r1csData) if err != nil { t.Fatalf("Failed to build circuit and assignment: %v", err) } - // Use gnark's test framework to check constraint satisfaction assert := test.NewAssert(t) assert.CheckCircuit( circuit, @@ -84,32 +57,13 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { t.Skip("Skipping test: TEST_CONFIG_PATH and TEST_R1CS_PATH env vars not set") } - configFile, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read config file: %v", err) - } - - var config Config - if err := json.Unmarshal(configFile, &config); err != nil { - t.Fatalf("Failed to unmarshal config JSON: %v", err) - } - - r1csFile, err := os.ReadFile(r1csPath) - if err != nil { - t.Fatalf("Failed to read r1cs file: %v", err) - } - - var r1csData R1CS - if err := json.Unmarshal(r1csFile, &r1csData); err != nil { - t.Fatalf("Failed to unmarshal r1cs JSON: %v", err) - } + config, r1csData := loadTestData(t, configPath, r1csPath) circuit, assignment, err := buildCircuitAndAssignment(config, r1csData) if err != nil { t.Fatalf("Failed to build circuit and assignment: %v", err) } - // Compile circuit ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit) if err != nil { t.Fatalf("Failed to compile circuit: %v", err) @@ -117,13 +71,11 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { t.Logf("Circuit compiled: %d constraints", ccs.GetNbConstraints()) - // Create witness witness, err := frontend.NewWitness(assignment, ecc.BN254.ScalarField()) if err != nil { t.Fatalf("Failed to create witness: %v", err) } - // Solve the constraint system _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf)) if err != nil { t.Fatalf("Constraint system not satisfied: %v", err) @@ -132,330 +84,211 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { t.Log("All constraints satisfied!") } -// buildCircuitAndAssignment constructs both the circuit definition and the witness assignment -// from the config and r1cs data. This mirrors what verifyCircuit does but separates -// circuit (placeholder) from assignment (actual values). +func loadTestData(t *testing.T, configPath, r1csPath string) (Config, R1CS) { + t.Helper() + + configFile, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read config file: %v", err) + } + var config Config + if err := json.Unmarshal(configFile, &config); err != nil { + t.Fatalf("Failed to unmarshal config JSON: %v", err) + } + + r1csFile, err := os.ReadFile(r1csPath) + if err != nil { + t.Fatalf("Failed to read r1cs file: %v", err) + } + var r1csData R1CS + if err := json.Unmarshal(r1csFile, &r1csData); err != nil { + t.Fatalf("Failed to unmarshal r1cs JSON: %v", err) + } + + return config, r1csData +} + +// buildCircuitAndAssignment replays the Fiat-Shamir transcript natively +// (like PrepareAndVerifyCircuit) and then builds the gnark Circuit template +// and assignment (like verifyCircuit), returning both for test use. func buildCircuitAndAssignment(config Config, r1csData R1CS) (*Circuit, *Circuit, error) { - // Parse transcript and extract hints - io := gnarkNimue.IOPattern{} - if err := io.Parse([]byte(config.IOPattern)); err != nil { - return nil, nil, err - } - - var pointer uint64 - var truncated []byte - - var merklePaths []FullMultiPath[KeccakDigest] - var stirAnswers [][][]Fp256 - var deferred []Fp256 - var claimedEvaluations ClaimedEvaluations - var claimedEvaluations2 ClaimedEvaluations - var publicWeightsEvaluations [2]Fp256 - - for _, op := range io.Ops { - switch op.Kind { - case gnarkNimue.Hint: - if pointer+4 > uint64(len(config.Transcript)) { - return nil, nil, nil - } - hintLen := binary.LittleEndian.Uint32(config.Transcript[pointer : pointer+4]) - start := pointer + 4 - end := start + uint64(hintLen) - - switch string(op.Label) { - default: - // Handle batch-mode hints: stir_answers_witness_X and merkle_proof_witness_X - label := string(op.Label) - if strings.HasPrefix(label, "merkle_proof_witness_") { - var path FullMultiPath[KeccakDigest] - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &path, - false, false, - ) - if err != nil { - return nil, nil, err - } - merklePaths = append(merklePaths, path) - } else if strings.HasPrefix(label, "stir_answers_witness_") { - var stirAnswersTemporary [][]Fp256 - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &stirAnswersTemporary, - false, false, - ) - if err != nil { - return nil, nil, err - } - stirAnswers = append(stirAnswers, stirAnswersTemporary) - } - - case "merkle_proof": - var path FullMultiPath[KeccakDigest] - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &path, false, false, - ) - if err != nil { - return nil, nil, err - } - merklePaths = append(merklePaths, path) - - case "stir_answers": - var stirAnswersTemporary [][]Fp256 - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &stirAnswersTemporary, false, false, - ) - if err != nil { - return nil, nil, err - } - stirAnswers = append(stirAnswers, stirAnswersTemporary) - - case "deferred_weight_evaluations": - var deferredTemporary []Fp256 - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &deferredTemporary, false, false, - ) - if err != nil { - return nil, nil, err - } - deferred = append(deferred, deferredTemporary...) - - case "claimed_evaluations": - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations, false, false, - ) - if err != nil { - return nil, nil, err - } - case "claimed_evaluations_1": - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations, false, false, - ) - if err != nil { - return nil, nil, err - } - - case "claimed_evaluations_2": - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations2, false, false, - ) - if err != nil { - return nil, nil, err - } - - case "public_weights_evaluations": - _, err := arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &publicWeightsEvaluations, false, false, - ) - if err != nil { - return nil, nil, err - } - } - pointer = end - - case gnarkNimue.Absorb: - start := pointer - if string(op.Label) == "pow-nonce" { - pointer += op.Size - } else { - pointer += op.Size * 32 - } - truncated = append(truncated, config.Transcript[start:pointer]...) - } + if len(config.ProtocolID) < 64 { + return nil, nil, fmt.Errorf("protocol_id must be 64 bytes, got %d", len(config.ProtocolID)) + } + var pid [64]byte + copy(pid[:], config.ProtocolID[:64]) + + // Compute instance = public_inputs.hash_bytes() to bind public inputs to the transcript. + piValues := make([]*big.Int, len(config.PublicInputs.Values)) + for i, v := range config.PublicInputs.Values { + piValues[i] = v.(*big.Int) } + instance := nativePublicInputsHashBytes(piValues) - config.Transcript = truncated + nimue := NewNativeNimue(pid, config.SessionID, instance, config.NargString, config.Hints) + blindedCommitmentWhirConfig := NewWhirParams(config.BlindedCommitmentWhirConfig) + blindingCommitmentWhirConfig := NewWhirParams(config.BlindingCommitmentWhirConfig) - // Parse interner - internerBytes, err := hex.DecodeString(r1csData.Interner.Values) + // 1. Parse commitment 1 + _, blindedOODPoints, blindedOODMatrix, err := nativeParseBatchedCommitment(nimue, blindedCommitmentWhirConfig) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse blinded commitment: %w", err) } + blindedCommitment := NativeCommitmentFromParsed(blindedOODPoints, blindedOODMatrix) - var interner Interner - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(internerBytes), &interner, false, false, - ) + _, blindingOODPoints, blindingOODMatrix, err := nativeParseBatchedCommitment(nimue, blindingCommitmentWhirConfig) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("parse blinding commitment: %w", err) } + blindingCommitment := NativeCommitmentFromParsed(blindingOODPoints, blindingOODMatrix) - // Build hints - hidingSpartanData := consumeWhirData(config.WHIRConfigHidingSpartan, &merklePaths, &stirAnswers) + // 2. If dual mode: squeeze logup challenges, parse commitment 2 + if config.NumChallenges > 0 { + if _, err = nimue.FillChallengeScalars(config.NumChallenges); err != nil { + return nil, nil, fmt.Errorf("logup challenges: %w", err) + } + if _, _, _, err = nativeParseBatchedCommitment(nimue, blindedCommitmentWhirConfig); err != nil { + return nil, nil, fmt.Errorf("parse commitment 2 blinded: %w", err) + } + if _, _, _, err = nativeParseBatchedCommitment(nimue, blindingCommitmentWhirConfig); err != nil { + return nil, nil, fmt.Errorf("parse commitment 2 blinding: %w", err) + } + } - var witnessFirstRoundHints []FirstRoundHint - var witnessRoundHints ZKHint + // 3. Sumcheck + if _, err = nativeRunSumcheckVerifier(nimue, config.LogNumConstraints); err != nil { + return nil, nil, fmt.Errorf("sumcheck verifier: %w", err) + } + // 4. Public inputs hash + x challenge + if _, err = nimue.FillNextScalars(1); err != nil { + return nil, nil, fmt.Errorf("public inputs hash: %w", err) + } + if _, err = nimue.FillChallengeScalars(1); err != nil { + return nil, nil, fmt.Errorf("x challenge: %w", err) + } + + // 5. Read evaluation hints + var evals1 []Fp256 + if err = nimue.ProverHintArk(&evals1); err != nil { + return nil, nil, fmt.Errorf("evals_1: %w", err) + } + evals1BigInt := fp256SliceToBigInt(evals1) + + var evals2BigInt []*big.Int if config.NumChallenges > 0 { - numCommitments := 2 - witnessFirstRoundHints = make([]FirstRoundHint, numCommitments) - for i := 0; i < numCommitments; i++ { - witnessFirstRoundHints[i] = consumeFirstRoundOnly(&merklePaths, &stirAnswers) + var evals2 []Fp256 + if err = nimue.ProverHintArk(&evals2); err != nil { + return nil, nil, fmt.Errorf("evals_2: %w", err) } - witnessRoundHints = consumeWhirDataRoundsOnly(config.WHIRConfigWitness, &merklePaths, &stirAnswers) - } else { - witnessData := consumeWhirData(config.WHIRConfigWitness, &merklePaths, &stirAnswers) - witnessFirstRoundHints = []FirstRoundHint{witnessData.firstRoundMerklePaths} - witnessRoundHints = witnessData + evals2BigInt = fp256SliceToBigInt(evals2) } - hints := Hints{ - spartanHidingHint: hidingSpartanData, - WitnessFirstRoundHints: witnessFirstRoundHints, - WitnessRoundHints: witnessRoundHints, + hasPublicInputs := !config.PublicInputs.IsEmpty() + if hasPublicInputs { + if _, err = nimue.FillNextScalars(1); err != nil { + return nil, nil, fmt.Errorf("public_eval: %w", err) + } } - // Build matrices - matrixA := buildMatrix(r1csData.A, interner) - matrixB := buildMatrix(r1csData.B, interner) - matrixC := buildMatrix(r1csData.C, interner) - - // Parse evaluations - need separate values for circuit (placeholder) and assignment (actual) - hidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) - hidingSpartanLinearStatementEvaluationsAssign := make([]frontend.Variable, 1) - hidingSpartanLinearStatementEvaluationsAssign[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) + // 6. zkWHIR verify (commitment 1) + zkWhirParams := newZKWhirVerifyParams(1, hasPublicInputs) + zkWhirData1, err := nativeZKWhirVerify(nimue, config, blindedCommitmentWhirConfig, blindingCommitmentWhirConfig, zkWhirParams, blindedCommitment, blindingCommitment, evals1BigInt) + if err != nil { + return nil, nil, fmt.Errorf("zkWHIR verify commitment 1: %w", err) + } - var witnessLinearStatementEvalsSize int + // 7. If dual mode: zkWHIR verify (commitment 2) + var dualData *DualCommitmentData if config.NumChallenges > 0 { - if !config.PublicInputs.IsEmpty() { - witnessLinearStatementEvalsSize = 7 - } else { - witnessLinearStatementEvalsSize = 6 + zkWhirParams2 := ZKWhirVerifyParams{NumPolynomials: 1, WeightsLen: 3} + zkWhirData2, err := nativeZKWhirVerify(nimue, config, blindedCommitmentWhirConfig, blindingCommitmentWhirConfig, zkWhirParams2, blindedCommitment, blindingCommitment, evals2BigInt) + if err != nil { + return nil, nil, fmt.Errorf("zkWHIR verify commitment 2: %w", err) } - } else { - if !config.PublicInputs.IsEmpty() { - witnessLinearStatementEvalsSize = 4 - } else { - witnessLinearStatementEvalsSize = 3 + dualData = &DualCommitmentData{ + Evals2BigInt: evals2BigInt, + BlindedMerkleData: *zkWhirData2.BlindedMerkleData, + BlindingMerkleData: *zkWhirData2.BlindingMerkleData, } } - // For circuit definition: placeholder values (nil) - witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) + // 8. Parse interner and build matrices + interner, err := ParseInterner(r1csData) + if err != nil { + return nil, nil, fmt.Errorf("parse interner: %w", err) + } - // For assignment: actual values - witnessLinearStatementEvaluationsAssign := make([]frontend.Variable, witnessLinearStatementEvalsSize) - for i := 0; i < witnessLinearStatementEvalsSize; i++ { - witnessLinearStatementEvaluationsAssign[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) + matrixA, matrixB, matrixC, err := buildR1CSMatrixCells(r1csData, interner) + if err != nil { + return nil, nil, fmt.Errorf("build matrices: %w", err) } - // Build transcript - transcriptT := make([]uints.U8, config.TranscriptLen) - contTranscript := make([]uints.U8, config.TranscriptLen) - for i := range config.Transcript { - transcriptT[i] = uints.NewU8(config.Transcript[i]) + // 9. Build circuit template and assignment (mirrors verifyCircuit) + nimueInitCircuit, nimueInitAssign := configToNimueInit(config) + + transcriptT := make([]uints.U8, len(config.NargString)) + contTranscript := make([]uints.U8, len(config.NargString)) + for i := range config.NargString { + transcriptT[i] = uints.NewU8(config.NargString[i]) } - // Parse claimed evaluations - fSums, gSums := parseClaimedEvaluations(claimedEvaluations, true) - var fSums2, gSums2 []frontend.Variable - if config.NumChallenges > 0 { - fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(config.PublicInputs.Values)), } - // Build the slices conditionally - var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable - if config.NumChallenges > 0 { - witnessClaimedEvals = [][]frontend.Variable{fSums, fSums2} - witnessBlindingEvals = [][]frontend.Variable{gSums, gSums2} - } else { - witnessClaimedEvals = [][]frontend.Variable{fSums} - witnessBlindingEvals = [][]frontend.Variable{gSums} + blindedMerkleTemplate := allocateZeroWhirMerkleData(*zkWhirData1.BlindedMerkleData) + blindingMerkleTemplate := allocateZeroWhirMerkleData(*zkWhirData1.BlindingMerkleData) + + var blindedMerkleTemplate2, blindingMerkleTemplate2 whir.WhirMerkleData + if dualData != nil { + blindedMerkleTemplate2 = allocateZeroWhirMerkleData(dualData.BlindedMerkleData) + blindingMerkleTemplate2 = allocateZeroWhirMerkleData(dualData.BlindingMerkleData) } - // Build circuit definition (with placeholder values) circuit := &Circuit{ - IO: []byte(config.IOPattern), - Transcript: contTranscript, - LogNumConstraints: config.LogNumConstraints, - LogNumVariables: config.LogNumVariables, - LogANumTerms: config.LogANumTerms, - WitnessClaimedEvaluations: witnessClaimedEvals, - WitnessBlindingEvaluations: witnessBlindingEvals, - WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, - HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, - HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), - HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, true), - WitnessFirstRounds: witnessFirstRounds(hints, true), - WitnessMerkle: newMerkle(hints.WitnessRoundHints.roundHints, true), - NumChallenges: config.NumChallenges, - W1Size: config.W1Size, - WHIRParamsWitness: NewWhirParams(config.WHIRConfigWitness), - WHIRParamsHidingSpartan: NewWhirParams(config.WHIRConfigHidingSpartan), - MatrixA: matrixA, - MatrixB: matrixB, - MatrixC: matrixC, - PublicInputs: PublicInputs{Values: make([]frontend.Variable, len(config.PublicInputs.Values))}, - PubWitnessEvaluations: make([]frontend.Variable, 2), - } - - // Build assignment (with actual values) - fSumsAssign, gSumsAssign := parseClaimedEvaluations(claimedEvaluations, false) - var fSums2Assign, gSums2Assign []frontend.Variable - var witnessClaimedEvalsAssign, witnessBlindingEvalsAssign [][]frontend.Variable - if config.NumChallenges > 0 { - fSums2Assign, gSums2Assign = parseClaimedEvaluations(claimedEvaluations2, false) - witnessClaimedEvalsAssign = [][]frontend.Variable{fSumsAssign, fSums2Assign} - witnessBlindingEvalsAssign = [][]frontend.Variable{gSumsAssign, gSums2Assign} - } else { - witnessClaimedEvalsAssign = [][]frontend.Variable{fSumsAssign} - witnessBlindingEvalsAssign = [][]frontend.Variable{gSumsAssign} + InitializationData: nimueInitCircuit, + Transcript: contTranscript, + LogNumConstraints: config.LogNumConstraints, + NumChallenges: config.NumChallenges, + ChallengeOffsets: config.ChallengeOffsets, + W1Size: config.W1Size, + BlindingCommitmentWhirConfig: NewWhirParams(config.BlindingCommitmentWhirConfig), + BlindedCommitmentWhirConfig: NewWhirParams(config.BlindedCommitmentWhirConfig), + PublicInputs: publicInputsContainer, + BlindedMerkleData: blindedMerkleTemplate, + BlindingMerkleData: blindingMerkleTemplate, + BlindedMerkleData2: blindedMerkleTemplate2, + BlindingMerkleData2: blindingMerkleTemplate2, + MatrixA: matrixA, + MatrixB: matrixB, + MatrixC: matrixC, + } + + var blindedMerkleAssign2, blindingMerkleAssign2 whir.WhirMerkleData + if dualData != nil { + blindedMerkleAssign2 = dualData.BlindedMerkleData + blindingMerkleAssign2 = dualData.BlindingMerkleData } assignment := &Circuit{ - IO: []byte(config.IOPattern), - Transcript: transcriptT, - LogNumConstraints: config.LogNumConstraints, - WitnessClaimedEvaluations: witnessClaimedEvalsAssign, - WitnessBlindingEvaluations: witnessBlindingEvalsAssign, - WitnessLinearStatementEvaluations: witnessLinearStatementEvaluationsAssign, - HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluationsAssign, - HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), - HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), - WitnessFirstRounds: witnessFirstRounds(hints, false), - WitnessMerkle: newMerkle(hints.WitnessRoundHints.roundHints, false), - NumChallenges: config.NumChallenges, - W1Size: config.W1Size, - WHIRParamsWitness: NewWhirParams(config.WHIRConfigWitness), - WHIRParamsHidingSpartan: NewWhirParams(config.WHIRConfigHidingSpartan), - MatrixA: matrixA, - MatrixB: matrixB, - MatrixC: matrixC, - PublicInputs: config.PublicInputs, - PubWitnessEvaluations: []frontend.Variable{ - typeConverters.LimbsToBigIntMod(publicWeightsEvaluations[0].Limbs), - typeConverters.LimbsToBigIntMod(publicWeightsEvaluations[1].Limbs), - }, + InitializationData: nimueInitAssign, + Transcript: transcriptT, + LogNumConstraints: config.LogNumConstraints, + NumChallenges: config.NumChallenges, + ChallengeOffsets: config.ChallengeOffsets, + W1Size: config.W1Size, + BlindingCommitmentWhirConfig: NewWhirParams(config.BlindingCommitmentWhirConfig), + BlindedCommitmentWhirConfig: NewWhirParams(config.BlindedCommitmentWhirConfig), + PublicInputs: config.PublicInputs, + BlindedMerkleData: *zkWhirData1.BlindedMerkleData, + BlindingMerkleData: *zkWhirData1.BlindingMerkleData, + BlindedMerkleData2: blindedMerkleAssign2, + BlindingMerkleData2: blindingMerkleAssign2, + MatrixA: matrixA, + MatrixB: matrixB, + MatrixC: matrixC, } return circuit, assignment, nil } - -func buildMatrix(sparse SparseMatrix, interner Interner) []MatrixCell { - colIndices := sparse.DecodeColIndices() - if colIndices == nil { - panic("failed to decode column indices: inconsistent matrix data") - } - matrix := make([]MatrixCell, len(sparse.Values)) - for i := range len(sparse.RowIndices) { - end := len(sparse.Values) - 1 - if i < len(sparse.RowIndices)-1 { - end = int(sparse.RowIndices[i+1] - 1) - } - for j := int(sparse.RowIndices[i]); j <= end; j++ { - matrix[j] = MatrixCell{ - row: i, - column: int(colIndices[j]), - value: typeConverters.LimbsToBigIntMod(interner.Values[sparse.Values[j]].Limbs), - } - } - } - return matrix -} diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 707286738..5ec0a4657 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -2,226 +2,643 @@ package circuit import ( "bytes" - "encoding/binary" "encoding/hex" "fmt" + "io" "log" - "strings" + "math/big" + "math/bits" + "net/http" + "os" + "sort" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" - gnarkNimue "github.com/reilabs/gnark-nimue" - arkSerialize "github.com/reilabs/go-ark-serialize" "reilabs/whir-verifier-circuit/app/common" + "reilabs/whir-verifier-circuit/app/whir" ) +func FrDecimalToHexLE(decimal string) string { + n := new(big.Int) + _, ok := n.SetString(decimal, 10) + if !ok { + return "" + } + + be := n.Bytes() // big-endian + + // pad to 32 bytes + buf := make([]byte, 32) + copy(buf[32-len(be):], be) + + // convert to little-endian + for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 { + buf[i], buf[j] = buf[j], buf[i] + } + + return hex.EncodeToString(buf) +} + +// --------------------------------------------------------------------------- +// PrepareAndVerifyCircuit: replays the spongefish Fiat-Shamir transcript +// natively to determine hint offsets, reads hints, and builds the data +// structures needed by the gnark circuit. Currently skips the actual +// circuit call (the goal is to make parameter passing functional first). +// --------------------------------------------------------------------------- + func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, buildOps common.BuildOps) error { - io := gnarkNimue.IOPattern{} - err := io.Parse([]byte(config.IOPattern)) + if len(config.ProtocolID) < 64 { + return fmt.Errorf("protocol_id must be 64 bytes, got %d", len(config.ProtocolID)) + } + var pid [64]byte + copy(pid[:], config.ProtocolID[:64]) + + // Compute instance = public_inputs.hash_bytes() to bind public inputs to the transcript. + piValues := make([]*big.Int, len(config.PublicInputs.Values)) + for i, v := range config.PublicInputs.Values { + piValues[i] = v.(*big.Int) + } + instance := nativePublicInputsHashBytes(piValues) + + nimue := NewNativeNimue(pid, config.SessionID, instance, config.NargString, config.Hints) + blindedCommitmentWhirConfig := NewWhirParams(config.BlindedCommitmentWhirConfig) + blindingCommitmentWhirConfig := NewWhirParams(config.BlindingCommitmentWhirConfig) + + _, blindedCommitmentOODPoint, blindedCommitmentOODMatrix, err := nativeParseBatchedCommitment(nimue, blindedCommitmentWhirConfig) + if err != nil { - return fmt.Errorf("failed to parse IO pattern: %w", err) + return fmt.Errorf("parse blinded commitment: %w", err) } + blindedCommitment := NativeCommitmentFromParsed(blindedCommitmentOODPoint, blindedCommitmentOODMatrix) - var pointer uint64 - var truncated []byte + _, blindingCommitmentOODPoint, blindingCommitmentOODMatrix, err := nativeParseBatchedCommitment(nimue, blindingCommitmentWhirConfig) - var merklePaths []FullMultiPath[KeccakDigest] - var stirAnswers [][][]Fp256 - var deferred []Fp256 - var claimedEvaluations ClaimedEvaluations - var claimedEvaluations2 ClaimedEvaluations - var publicWeightsEvaluations [2]Fp256 + if err != nil { + return fmt.Errorf("parse blinding commitment: %w", err) + } + blindingCommitment := NativeCommitmentFromParsed(blindingCommitmentOODPoint, blindingCommitmentOODMatrix) - for _, op := range io.Ops { - switch op.Kind { - case gnarkNimue.Hint: - if pointer+4 > uint64(len(config.Transcript)) { - return fmt.Errorf("insufficient bytes for hint length") - } - hintLen := binary.LittleEndian.Uint32(config.Transcript[pointer : pointer+4]) - start := pointer + 4 - end := start + uint64(hintLen) + if config.NumChallenges > 0 { + _, err := nimue.FillChallengeScalars(config.NumChallenges) + if err != nil { + return fmt.Errorf("logup challenges: %w", err) + } + _, _, _, err = nativeParseBatchedCommitment(nimue, blindedCommitmentWhirConfig) + if err != nil { + return fmt.Errorf("parse commitment 2 blinded: %w", err) + } + _, _, _, err = nativeParseBatchedCommitment(nimue, blindingCommitmentWhirConfig) + if err != nil { + return fmt.Errorf("parse commitment 2: %w", err) + } + } - if end > uint64(len(config.Transcript)) { - return fmt.Errorf("insufficient bytes for merkle proof") - } + // Spartan sumcheck: squeeze tRand, then run ZK sumcheck + sumcheckData, err := nativeRunSumcheckVerifier(nimue, config.LogNumConstraints) + if err != nil { + return fmt.Errorf("sumcheck verifier: %w", err) + } - switch string(op.Label) { - default: - // Handle batch-mode hints: stir_answers_witness_X and merkle_proof_witness_X - label := string(op.Label) - if strings.HasPrefix(label, "merkle_proof_witness_") { - var path FullMultiPath[KeccakDigest] - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &path, - false, false, - ) - merklePaths = append(merklePaths, path) - } else if strings.HasPrefix(label, "stir_answers_witness_") { - var stirAnswersTemporary [][]Fp256 - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &stirAnswersTemporary, - false, false, - ) - stirAnswers = append(stirAnswers, stirAnswersTemporary) - } - - case "merkle_proof": - var path FullMultiPath[KeccakDigest] - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &path, - false, false, - ) - merklePaths = append(merklePaths, path) - - case "stir_answers": - var stirAnswersTemporary [][]Fp256 - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &stirAnswersTemporary, - false, false, - ) - stirAnswers = append(stirAnswers, stirAnswersTemporary) - - case "deferred_weight_evaluations": - var deferredTemporary []Fp256 - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &deferredTemporary, - false, false, - ) - if err != nil { - return fmt.Errorf("failed to deserialize deferred hint: %w", err) - } - deferred = append(deferred, deferredTemporary...) - - // Single mode hint - case "claimed_evaluations": - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations, - false, false, - ) - if err != nil { - return fmt.Errorf("failed to deserialize claimed_evaluations: %w", err) - } - - // Dual mode hints - case "claimed_evaluations_1": - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations, - false, false, - ) - if err != nil { - return fmt.Errorf("failed to deserialize claimed_evaluations_1: %w", err) - } - - case "claimed_evaluations_2": - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &claimedEvaluations2, - false, false, - ) - if err != nil { - return fmt.Errorf("failed to deserialize claimed_evaluations_2: %w", err) - } - - case "public_weights_evaluations": - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(config.Transcript[start:end]), - &publicWeightsEvaluations, - false, false, - ) - if err != nil { - return fmt.Errorf("failed to deserialize public_weights_evaluations: %w", err) - } - } + _, err = nimue.FillNextScalars(1) + if err != nil { + return fmt.Errorf("public inputs hash: %w", err) + } - if err != nil { - return fmt.Errorf("failed to deserialize merkle proof: %w", err) - } + _, err = nimue.FillChallengeScalars(1) + if err != nil { + return fmt.Errorf("x challenge: %w", err) + } - pointer = end + // Read evaluations from transcript (prover_message in Rust) + evals1Scalars, err := nimue.FillNextScalars(3) + if err != nil { + return fmt.Errorf("evals_1: %w", err) + } + evals1BigInt := evals1Scalars - case gnarkNimue.Absorb: - start := pointer - if string(op.Label) == "pow-nonce" { - pointer += op.Size - } else { - pointer += op.Size * 32 - } + var evals2BigInt []*big.Int + if config.NumChallenges > 0 { + evals2Scalars, err := nimue.FillNextScalars(3) + if err != nil { + return fmt.Errorf("evals_2: %w", err) + } + evals2BigInt = evals2Scalars + } - if pointer > uint64(len(config.Transcript)) { - return fmt.Errorf("absorb exceeds transcript length") - } + hasPublicInputs := !config.PublicInputs.IsEmpty() + var publicEval *big.Int + if hasPublicInputs { + publicEvalSlice, err := nimue.FillNextScalars(1) + if err != nil { + return fmt.Errorf("public_eval: %w", err) + } + publicEval = publicEvalSlice[0] + } + + // Challenge binding: in dual mode, read challenge_eval from transcript + // (prover_message in Rust). This binds w2's challenge values to the transcript. + // challenge_eval is appended to evaluations_2 and an extra weight is added. + var challengeEval *big.Int + if config.NumChallenges > 0 { + ceSlice, err := nimue.FillNextScalars(1) + if err != nil { + return fmt.Errorf("challenge_eval: %w", err) + } + challengeEval = ceSlice[0] + } - truncated = append(truncated, config.Transcript[start:pointer]...) + // Build the full evaluations vector matching Rust whir_r1cs.rs: + // [public_eval?, Az, Bz, Cz, blinding_eval] + // The blinding_eval comes from the Spartan sumcheck verifier output. + blindingEval := sumcheckData.BlindingEval + var fullEvals1 []*big.Int + if hasPublicInputs { + fullEvals1 = append([]*big.Int{publicEval}, evals1BigInt...) + } else { + fullEvals1 = append([]*big.Int{}, evals1BigInt...) + } + fullEvals1 = append(fullEvals1, blindingEval) + + // --------------------------------------------------------------- + // 6. zkWHIR verify (first commitment) + // weightsLen: 3 (A,B,C) + optional public + 1 blinding + // numPolynomials: 1 (single commitment) + // --------------------------------------------------------------- + zkWhirParams := newZKWhirVerifyParams(1, hasPublicInputs) + zkWhirData1, err := nativeZKWhirVerify(nimue, config, blindedCommitmentWhirConfig, blindingCommitmentWhirConfig, zkWhirParams, blindedCommitment, blindingCommitment, fullEvals1) + if err != nil { + return fmt.Errorf("zkWHIR verify commitment 1: %w", err) + } + + // --------------------------------------------------------------- + // 7. If dual mode: zkWHIR verify (second commitment) + // weights_2 has no public weight and no blinding weight → 3 weights + // --------------------------------------------------------------- + var dualData *DualCommitmentData + if config.NumChallenges > 0 { + // Commitment 2 has 3 base weights (A,B,C) + 1 challenge weight if challenge_eval exists. + evals2WithChallenge := evals2BigInt + weightsLen2 := 3 + if challengeEval != nil { + evals2WithChallenge = append(evals2WithChallenge, challengeEval) + weightsLen2 = 4 } + zkWhirParams2 := ZKWhirVerifyParams{NumPolynomials: 1, WeightsLen: weightsLen2} + zkWhirData2, err := nativeZKWhirVerify(nimue, config, blindedCommitmentWhirConfig, blindingCommitmentWhirConfig, zkWhirParams2, blindedCommitment, blindingCommitment, evals2WithChallenge) + if err != nil { + return fmt.Errorf("zkWHIR verify commitment 2: %w", err) + } + dualData = &DualCommitmentData{ + Evals2BigInt: evals2BigInt, + BlindedMerkleData: *zkWhirData2.BlindedMerkleData, + BlindingMerkleData: *zkWhirData2.BlindingMerkleData, + } + } + // --------------------------------------------------------------- + // 8. Remaining transcript consumed. Log status. + // --------------------------------------------------------------- + remainingHints := nimue.hints.Len() + remainingTranscript := len(nimue.nargString) + fmt.Printf("Native transcript replay complete. Remaining: %d hint bytes, %d transcript bytes\n", remainingHints, remainingTranscript) + + interner, err := ParseInterner(r1cs) + if err != nil { + return fmt.Errorf("parse interner: %w", err) + } + + if err := verifyCircuit(config, pk, vk, r1cs, interner, buildOps, config.PublicInputs, *zkWhirData1.BlindedMerkleData, *zkWhirData1.BlindingMerkleData, dualData); err != nil { + return fmt.Errorf("verify circuit: %w", err) } - config.Transcript = truncated + return nil +} + +// --------------------------------------------------------------------------- +// Native sumcheck verifier (mirrors Rust run_sumcheck_verifier) +// --------------------------------------------------------------------------- - internerBytes, err := hex.DecodeString(r1cs.Interner.Values) +// NativeSumcheckData holds the output of the native sumcheck verifier replay. +type NativeSumcheckData struct { + R []*big.Int // verifier randomness (length m0) + Alpha []*big.Int // folding challenges (length m0) + BlindingEval *big.Int // blinding polynomial evaluation + FAtAlpha *big.Int // f evaluated at alpha +} + +// nativeEvalCubicPoly evaluates poly[0] + x*(poly[1] + x*(poly[2] + x*poly[3])) mod p. +func nativeEvalCubicPoly(poly [4]*big.Int, point *big.Int) *big.Int { + // Horner's method: ((poly[3]*x + poly[2])*x + poly[1])*x + poly[0] + result := new(big.Int).Set(poly[3]) + result.Mul(result, point) + result.Add(result, poly[2]) + result.Mul(result, point) + result.Add(result, poly[1]) + result.Mul(result, point) + result.Add(result, poly[0]) + result.Mod(result, bn254Modulus) + return result +} + +// nativeRunSumcheckVerifier replays the Spartan sumcheck transcript and +// verifies the sumcheck equality assertions natively. +func nativeRunSumcheckVerifier(nimue *NativeNimue, m0 int) (*NativeSumcheckData, error) { + // r = verifier_message_vec(m0) + r, err := nimue.FillChallengeScalars(m0) if err != nil { - return fmt.Errorf("failed to decode interner values: %w", err) + return nil, fmt.Errorf("r: %w", err) } - var interner Interner - _, err = arkSerialize.CanonicalDeserializeWithMode( - bytes.NewReader(internerBytes), &interner, false, false, - ) + // sum_g = prover_message() + sumGSlice, err := nimue.FillNextScalars(1) + if err != nil { + return nil, fmt.Errorf("sum_g: %w", err) + } + sumG := sumGSlice[0] + + // rho = verifier_message() + rhoSlice, err := nimue.FillChallengeScalars(1) if err != nil { - return fmt.Errorf("failed to deserialize interner: %w", err) + return nil, fmt.Errorf("rho: %w", err) } + rho := rhoSlice[0] - hidingSpartanData := consumeWhirData(config.WHIRConfigHidingSpartan, &merklePaths, &stirAnswers) + // saved_val = rho * sum_g + savedVal := new(big.Int).Mul(rho, sumG) + savedVal.Mod(savedVal, bn254Modulus) - // Build witness hints based on mode - var witnessFirstRoundHints []FirstRoundHint - var witnessRoundHints ZKHint + alpha := make([]*big.Int, m0) - if config.NumChallenges > 0 { - // Batch mode: N commitments - // Rust emits: N first-round hints, then NRounds hints for batched polynomial - var numCommitments int - if config.NumChallenges > 0 { - numCommitments = 2 - } else { - numCommitments = 1 + for i := range m0 { + // Read 4 cubic polynomial coefficients + coeffSlice, err := nimue.FillNextScalars(4) + if err != nil { + return nil, fmt.Errorf("hhat coeff round %d: %w", i, err) + } + var hhat [4]*big.Int + hhat[0] = coeffSlice[0] + hhat[1] = coeffSlice[1] + hhat[2] = coeffSlice[2] + hhat[3] = coeffSlice[3] + + // alpha_i = verifier_message() + alphaSlice, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, fmt.Errorf("alpha round %d: %w", i, err) + } + alpha[i] = alphaSlice[0] + + // Sumcheck equality assertion: saved_val == hhat(0) + hhat(1) + hhatAtZero := nativeEvalCubicPoly(hhat, big.NewInt(0)) + hhatAtOne := nativeEvalCubicPoly(hhat, big.NewInt(1)) + sum := new(big.Int).Add(hhatAtZero, hhatAtOne) + sum.Mod(sum, bn254Modulus) + if savedVal.Cmp(sum) != 0 { + return nil, fmt.Errorf("sumcheck equality assertion failed at round %d: %s != %s", i, savedVal.String(), sum.String()) } - // Consume first-round hints for each original commitment - witnessFirstRoundHints = make([]FirstRoundHint, numCommitments) - for i := 0; i < numCommitments; i++ { - witnessFirstRoundHints[i] = consumeFirstRoundOnly(&merklePaths, &stirAnswers) + // saved_val = hhat(alpha_i) + savedVal = nativeEvalCubicPoly(hhat, alpha[i]) + } + + // blinding_eval = prover_message() + blindingSlice, err := nimue.FillNextScalars(1) + if err != nil { + return nil, fmt.Errorf("blinding_eval: %w", err) + } + blindingEval := blindingSlice[0] + + // f_at_alpha = saved_val - rho * blinding_eval + rhoBE := new(big.Int).Mul(rho, blindingEval) + rhoBE.Mod(rhoBE, bn254Modulus) + fAtAlpha := new(big.Int).Sub(savedVal, rhoBE) + fAtAlpha.Mod(fAtAlpha, bn254Modulus) + // Ensure non-negative result (Go's Mod can return negative for negative inputs) + if fAtAlpha.Sign() < 0 { + fAtAlpha.Add(fAtAlpha, bn254Modulus) + } + + return &NativeSumcheckData{ + R: r, + Alpha: alpha, + BlindingEval: blindingEval, + FAtAlpha: fAtAlpha, + }, nil +} + +// --------------------------------------------------------------------------- +// Native zkWHIR verification transcript replay +// --------------------------------------------------------------------------- + +// ZKWhirVerifyParams bundles the config values needed to replay the zkWHIR +// verify transcript. All counts refer to the Rust Config fields. +type ZKWhirVerifyParams struct { + NumPolynomials int // commitment.f_hat.len() (typically 1) + WeightsLen int // number of weight linear forms (includes blinding weight) +} + +// newZKWhirVerifyParams derives the transcript replay parameters from the +// Config and blinded/blinding WHIRParams. The caller only needs to supply +// numPolynomials and weightsLen which depend on the call site. +func newZKWhirVerifyParams(numPolynomials int, hasPublicInputs bool) ZKWhirVerifyParams { + // weightsLen: 3 (A,B,C) + 1 (blinding) = 4 without public inputs + // 3 (A,B,C) + 1 (public) + 1 (blinding) = 5 with public inputs + // The blinding weight is the last one; it is an internal zkWHIR weight used + // to compute numWFoldedEvals from the transcript, but its evaluation is NOT + // in the external evaluations slice passed to NativeWhirVerify. + weightsLen := 4 + if hasPublicInputs { + weightsLen = 5 + } + return ZKWhirVerifyParams{ + NumPolynomials: numPolynomials, + WeightsLen: weightsLen, + } +} + +// NativeZKWhirData holds the transcript values parsed by nativeZKWhirVerify. +type NativeZKWhirData struct { + BlindingChallenge *big.Int + WFoldedBlindingEvals []*big.Int + MaskingChallenge *big.Int + InitialQueryIndices []int + Tau1 *big.Int + Tau2 *big.Int + // Per-gamma, per-polynomial evaluations: [gamma_idx][poly_idx] → (m_eval, g_hat_evals...) + PerGammaEvals [][][]*big.Int + CombinedClaims []*big.Int + BatchedHClaims []*big.Int + // Merkle proof data for each WHIR verification (blinded and blinding). + BlindedMerkleData *whir.WhirMerkleData + BlindingMerkleData *whir.WhirMerkleData + // FinalClaim from the blinded WHIR verification. + BlindedFinalClaim NativeFinalClaim +} + +// nativeIRSCommitVerify replays the initial_committer.verify() transcript +// operations: squeeze in-domain challenge indices, read submatrix hint, read +// Merkle proof hints. +func nativeIRSCommitVerify( + nimue *NativeNimue, + numQueries int, + domainSize int, + foldingFactorPower int, +) ([]int, error) { + // in_domain_challenges: squeeze challenge bytes → query indices + indices, err := nativeGetStirChallenges(nimue, domainSize/foldingFactorPower, numQueries, false) + if err != nil { + return nil, fmt.Errorf("initial in-domain challenges: %w", err) + } + + var submatrix []Fp256 + if err = nimue.ProverHintArk(&submatrix); err != nil { + return nil, fmt.Errorf("initial submatrix: %w", err) + } + + // matrix_commit.verify: read Merkle proof from hints + foldedDomainSize := domainSize / foldingFactorPower + treeHeight := bits.Len(uint(foldedDomainSize)) - 1 + dedupedIndices := make([]int, len(indices)) + copy(dedupedIndices, indices) + sort.Ints(dedupedIndices) + dedupedIndices = dedup(dedupedIndices) + + _, err = consumeMerkleHints(nimue, dedupedIndices, treeHeight) + if err != nil { + return nil, fmt.Errorf("initial merkle: %w", err) + } + + return indices, nil +} + +// nativeZKWhirVerify replays the zkWHIR Config::verify() transcript. +// It parses all transcript messages in the same order as the Rust verifier, +// calling nativeIRSCommitVerify for the initial commitment and NativeWhirVerify +// for the blinded/blinding commitment WHIR verifications. +// +// blindedCommitment and blindingCommitment are the parsed commitments from +// nativeParseBatchedCommitment, converted via NativeCommitmentFromParsed. +// evaluations are the claimed linear form evaluations from the Spartan layer. +func nativeZKWhirVerify( + nimue *NativeNimue, + config Config, + blindedWhirParams WHIRParams, + blindingWhirParams WHIRParams, + params ZKWhirVerifyParams, + blindedCommitment *NativeCommitment, + blindingCommitment *NativeCommitment, + evaluations []*big.Int, +) (*NativeZKWhirData, error) { + data := &NativeZKWhirData{} + + numWitnessVariables := blindedWhirParams.MVParamsNumberOfVariables + interleavingDepth := 1 << blindedWhirParams.FoldingFactorArray[0] + + bc, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, fmt.Errorf("blinding_challenge: %w", err) + } + data.BlindingChallenge = bc[0] + + // w_folded_blinding_evals = prover_messages_vec(num_w_folded_evals) + // num_w_folded_evals = weights.len() * num_polynomials * (μ + 1) + numWFoldedEvals := params.WeightsLen * params.NumPolynomials * (numWitnessVariables + 1) + wfbe, err := nimue.FillNextScalars(numWFoldedEvals) + if err != nil { + return nil, fmt.Errorf("w_folded_blinding_evals: %w", err) + } + data.WFoldedBlindingEvals = wfbe + + mc, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, fmt.Errorf("masking_challenge: %w", err) + } + data.MaskingChallenge = mc[0] + + indices, err := nativeIRSCommitVerify( + nimue, + blindedWhirParams.InitialInDomainSamples, + blindedWhirParams.DomainSize, + interleavingDepth, + ) + if err != nil { + return nil, fmt.Errorf("initial_committer: %w", err) + } + data.InitialQueryIndices = indices + + // h_gammas = all_gammas(initial_in_domain.points) + // Each query point expands to interleavingDepth gamma points. + hGammasCount := len(indices) * interleavingDepth + + tau1Slice, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, fmt.Errorf("tau1: %w", err) + } + data.Tau1 = tau1Slice[0] + + tau2Slice, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, fmt.Errorf("tau2: %w", err) + } + data.Tau2 = tau2Slice[0] + + // Per-gamma evaluation loop + // For each gamma in h_gammas: + // For each polynomial: + // m_eval = prover_message() + // g_hat_evals = prover_message() × num_witness_variables + evalsPerPoly := 1 + numWitnessVariables // m_eval + g_hat_evals + data.PerGammaEvals = make([][][]*big.Int, hGammasCount) + for g := range hGammasCount { + data.PerGammaEvals[g] = make([][]*big.Int, params.NumPolynomials) + for p := range params.NumPolynomials { + vals, err := nimue.FillNextScalars(evalsPerPoly) + if err != nil { + return nil, fmt.Errorf("gamma %d poly %d evals: %w", g, p, err) + } + data.PerGammaEvals[g][p] = vals } + } - // Consume rounds 1+ for the batched polynomial - witnessRoundHints = consumeWhirDataRoundsOnly(config.WHIRConfigWitness, &merklePaths, &stirAnswers) - } else { - // Single mode - witnessData := consumeWhirData(config.WHIRConfigWitness, &merklePaths, &stirAnswers) - witnessFirstRoundHints = []FirstRoundHint{witnessData.firstRoundMerklePaths} - witnessRoundHints = witnessData + // combined_claims = prover_messages_vec(num_polynomials) + // batched_h_claims = prover_messages_vec(num_polynomials) + data.CombinedClaims, err = nimue.FillNextScalars(params.NumPolynomials) + if err != nil { + return nil, fmt.Errorf("combined_claims: %w", err) } - hints := Hints{ - spartanHidingHint: hidingSpartanData, - WitnessFirstRoundHints: witnessFirstRoundHints, - WitnessRoundHints: witnessRoundHints, + data.BatchedHClaims, err = nimue.FillNextScalars(params.NumPolynomials) + if err != nil { + return nil, fmt.Errorf("batched_h_claims: %w", err) } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, publicWeightsEvaluations, r1cs, interner, buildOps, config.PublicInputs) + // blinded_commitment.verify() — full WHIR verification + // Verifies the witness polynomial commitment using NativeWhirVerify. + // numLinearForms excludes the blinding weight (last in WeightsLen) because + // the blinding evaluation is not part of the external evaluations slice. + // + // Build modified_evaluations = evaluations + m_evals, matching + // Rust whir_zk/verifier.rs: modified_evaluations[i] = evaluations[i] + m_evals[i] + // where m_evals[i] is the first element of each (μ+1)-sized block in wFoldedBlindingEvals. + blockSize := numWitnessVariables + 1 + modifiedEvaluations := make([]*big.Int, len(evaluations)) + for i, eval := range evaluations { + mEval := data.WFoldedBlindingEvals[i*blockSize] + modifiedEvaluations[i] = frAdd(eval, mEval) + } + blindedResult, err := NativeWhirVerify( + nimue, + blindedWhirParams, + config.BlindedCommitmentWhirConfig, + []*NativeCommitment{blindedCommitment}, + modifiedEvaluations, + ) if err != nil { - return fmt.Errorf("verification failed: %w", err) + return nil, fmt.Errorf("blinded_commitment verify: %w", err) } - return nil + + // blinding_commitment.verify() — full WHIR verification + // Verifies the blinding polynomial commitment using NativeWhirVerify. + // The evaluations are all_expected_blinding_claims, which is the + // concatenation of: + // - expected_batched_blinding_subproof_claims: accumulated m_claims + // and g_hat_claims from the per-gamma evaluation loop, interleaved + // as [m_0, g_hat_0..., m_1, g_hat_1..., ...] (num_polynomials * + // (1 + num_witness_variables) elements) + // - w_folded_blinding_evals: parsed from transcript at step 2 + + // Accumulate m_claims[p] = Σ_g tau2^g * PerGammaEvals[g][p][0] + // and g_hat_claims[p][j] = Σ_g tau2^g * PerGammaEvals[g][p][j+1] + mClaims := make([]*big.Int, params.NumPolynomials) + gHatClaims := make([][]*big.Int, params.NumPolynomials) + for p := range params.NumPolynomials { + mClaims[p] = new(big.Int) + gHatClaims[p] = make([]*big.Int, numWitnessVariables) + for j := range numWitnessVariables { + gHatClaims[p][j] = new(big.Int) + } + } + tau2Power := big.NewInt(1) + for g := range hGammasCount { + for p := range params.NumPolynomials { + evals := data.PerGammaEvals[g][p] + mClaims[p] = frAdd(mClaims[p], frMul(tau2Power, evals[0])) + for j := range numWitnessVariables { + gHatClaims[p][j] = frAdd(gHatClaims[p][j], frMul(tau2Power, evals[j+1])) + } + } + tau2Power = frMul(tau2Power, data.Tau2) + } + + // Build subproof_claims: [m_0, g_hat_0..., m_1, g_hat_1..., ...] + subproofClaims := make([]*big.Int, 0, params.NumPolynomials*(1+numWitnessVariables)) + for p := range params.NumPolynomials { + subproofClaims = append(subproofClaims, mClaims[p]) + subproofClaims = append(subproofClaims, gHatClaims[p]...) + } + + // all_expected_blinding_claims = subproof_claims ++ w_folded_blinding_evals + blindingEvaluations := append(subproofClaims, data.WFoldedBlindingEvals...) + + blindingResult, err := NativeWhirVerify( + nimue, + blindingWhirParams, + config.BlindingCommitmentWhirConfig, + []*NativeCommitment{blindingCommitment}, + blindingEvaluations, + ) + if err != nil { + return nil, fmt.Errorf("blinding_commitment verify: %w", err) + } + + data.BlindedMerkleData = blindedResult.MerkleData + data.BlindingMerkleData = blindingResult.MerkleData + data.BlindedFinalClaim = blindedResult.FinalClaim + + return data, nil +} + +// --------------------------------------------------------------------------- +// Native protocol replay helpers +// --------------------------------------------------------------------------- +func nativeParseBatchedCommitment(nimue *NativeNimue, whirParams WHIRParams) ( + rootHash *big.Int, + oodPoints []*big.Int, + oodAnswers [][]*big.Int, + err error, +) { + roots, e := nimue.FillNextScalars(1) + if e != nil { + err = e + return + } + rootHash = roots[0] + + oodSamples := whirParams.RoundParametersOODSamples[0] + oodPts, e := nimue.FillChallengeScalars(oodSamples) + if e != nil { + err = e + return + } + oodPoints = oodPts + + oodAnswers = make([][]*big.Int, whirParams.BatchSize*oodSamples) + for i := range whirParams.BatchSize * oodSamples { + ans, e := nimue.FillNextScalars(1) + if e != nil { + err = e + return + } + oodAnswers[i] = ans + } + + return } +// --------------------------------------------------------------------------- +// Key management utilities (unchanged) +// --------------------------------------------------------------------------- + func GetPkAndVkFromPath(pkPath string, vkPath string) (*groth16.ProvingKey, *groth16.VerifyingKey, error) { var pk *groth16.ProvingKey var vk *groth16.VerifyingKey @@ -266,3 +683,93 @@ func GetR1csFromUrl(r1csUrl string) ([]byte, error) { log.Printf("Successfully downloaded") return r1csFile, nil } + +func keysFromFiles(pkPath string, vkPath string) (groth16.ProvingKey, groth16.VerifyingKey, error) { + pkFile, err := os.Open(pkPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to open proving key file: %w", err) + } + defer func() { + if err := pkFile.Close(); err != nil { + log.Printf("failed to close proving key file: %v", err) + } + }() + + pk := groth16.NewProvingKey(ecc.BN254) + _, err = pk.ReadFrom(pkFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to restore proving key: %w", err) + } + + vkFile, err := os.Open(vkPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to open verifying key file: %w", err) + } + defer func() { + if err := vkFile.Close(); err != nil { + log.Printf("failed to close verifying key file: %v", err) + } + }() + + vk := groth16.NewVerifyingKey(ecc.BN254) + _, err = vk.ReadFrom(vkFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to restore verifying key: %w", err) + } + + return pk, vk, nil +} + +func keysFromUrl(pkUrl string, vkUrl string) (groth16.ProvingKey, groth16.VerifyingKey, error) { + vkBytes, err := downloadFromUrl(vkUrl) + if err != nil { + return nil, nil, fmt.Errorf("failed to download verifying key: %w", err) + } + log.Printf("Downloaded VK") + + vk := groth16.NewVerifyingKey(ecc.BN254) + _, err = vk.UnsafeReadFrom(bytes.NewReader(vkBytes)) + if err != nil { + return nil, nil, fmt.Errorf("failed to deserialize verifying key: %w", err) + } + log.Printf("Loaded VK") + + pkBytes, err := downloadFromUrl(pkUrl) + if err != nil { + return nil, nil, fmt.Errorf("failed to download proving key: %v", err) + } + log.Printf("Downloaded PK") + + pk := groth16.NewProvingKey(ecc.BN254) + _, err = pk.UnsafeReadFrom(bytes.NewReader(pkBytes)) + if err != nil { + return nil, nil, fmt.Errorf("failed to deserialize proving key: %w", err) + } + log.Printf("Loaded PK") + + return pk, vk, nil +} + +func downloadFromUrl(url string) ([]byte, error) { + resp, err := http.Get(url) //nolint:gosec + if err != nil { + return nil, fmt.Errorf("failed to download from %s: %w", url, err) + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + log.Printf("Warning: failed to close response body: %v", closeErr) + } + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP error %d when downloading from %s", resp.StatusCode, url) + } + + buffer := &bytes.Buffer{} + _, err = io.Copy(buffer, resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to copy to buffer: %w", err) + } + + return buffer.Bytes(), nil +} diff --git a/recursive-verifier/app/circuit/matrix_evaluation.go b/recursive-verifier/app/circuit/matrix_evaluation.go index edf0aaad3..979fca622 100644 --- a/recursive-verifier/app/circuit/matrix_evaluation.go +++ b/recursive-verifier/app/circuit/matrix_evaluation.go @@ -1,10 +1,15 @@ package circuit import ( + "bytes" + "encoding/hex" "fmt" "math/big" + "reilabs/whir-verifier-circuit/app/typeConverters" + "github.com/consensys/gnark/frontend" + arkSerialize "github.com/reilabs/go-ark-serialize" ) type SparseMatrix struct { @@ -91,6 +96,62 @@ type MatrixCell struct { value *big.Int } +// ParseInterner decodes the hex-encoded interner from R1CS JSON into an Interner. +func ParseInterner(r1cs R1CS) (Interner, error) { + internerBytes, err := hex.DecodeString(r1cs.Interner.Values) + if err != nil { + return Interner{}, fmt.Errorf("decode interner hex: %w", err) + } + var interner Interner + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(internerBytes), &interner, false, false, + ) + if err != nil { + return Interner{}, fmt.Errorf("deserialize interner: %w", err) + } + return interner, nil +} + +// buildSparseMatrixCells converts a SparseMatrix + Interner into a flat []MatrixCell. +func buildSparseMatrixCells(sm SparseMatrix, interner Interner) ([]MatrixCell, error) { + colIndices := sm.DecodeColIndices() + if colIndices == nil { + return nil, fmt.Errorf("failed to decode column indices: inconsistent data") + } + cells := make([]MatrixCell, len(sm.Values)) + for i := range len(sm.RowIndices) { + end := len(sm.Values) - 1 + if i < len(sm.RowIndices)-1 { + end = int(sm.RowIndices[i+1] - 1) + } + for j := int(sm.RowIndices[i]); j <= end; j++ { + cells[j] = MatrixCell{ + row: i, + column: int(colIndices[j]), + value: typeConverters.LimbsToBigIntMod(interner.Values[sm.Values[j]].Limbs), + } + } + } + return cells, nil +} + +// buildR1CSMatrixCells builds MatrixCell slices for all three R1CS matrices. +func buildR1CSMatrixCells(r1cs R1CS, interner Interner) ([]MatrixCell, []MatrixCell, []MatrixCell, error) { + a, err := buildSparseMatrixCells(r1cs.A, interner) + if err != nil { + return nil, nil, nil, fmt.Errorf("matrix A: %w", err) + } + b, err := buildSparseMatrixCells(r1cs.B, interner) + if err != nil { + return nil, nil, nil, fmt.Errorf("matrix B: %w", err) + } + c, err := buildSparseMatrixCells(r1cs.C, interner) + if err != nil { + return nil, nil, nil, fmt.Errorf("matrix C: %w", err) + } + return a, b, c, nil +} + func evaluateR1CSMatrixExtension(api frontend.API, circuit *Circuit, rowRand []frontend.Variable, colRand []frontend.Variable) []frontend.Variable { ansA := frontend.Variable(0) ansB := frontend.Variable(0) @@ -112,53 +173,189 @@ func evaluateR1CSMatrixExtension(api frontend.API, circuit *Circuit, rowRand []f return []frontend.Variable{ansA, ansB, ansC} } -func evaluateR1CSMatrixExtensionBatch( +// evaluateFoldedR1CSMatrixExtension computes the MLE of folded R1CS weight +// covectors at the given evaluation point. Folding wraps column indices modulo +// maskSize = 2^(numBlindingVars+1), producing a covector of that size. The MLE +// of the folded covector at foldedPoint (with numBlindingVars+1 variables) is: +// +// w_folded_mle(p) = Σ_{(row,col)} M[row,col] * eq(alpha, row) * L_{col mod mask}(p) +// +// Returns [A_folded_mle, B_folded_mle, C_folded_mle]. +func evaluateFoldedR1CSMatrixExtension( api frontend.API, circuit *Circuit, rowRand []frontend.Variable, - colRand []frontend.Variable, - w1Size int, + foldedPoint []frontend.Variable, + maskSize int, ) []frontend.Variable { - // Returns [Az1, Bz1, Cz1, Az2, Bz2, Cz2] rowEval := calculateEQOverBooleanHypercube(api, rowRand) - colEval := calculateEQOverBooleanHypercube(api, colRand) + foldedColEval := calculateEQOverBooleanHypercube(api, foldedPoint) - ans := make([]frontend.Variable, 6) - for i := range ans { - ans[i] = frontend.Variable(0) + ansA := frontend.Variable(0) + ansB := frontend.Variable(0) + ansC := frontend.Variable(0) + + for i := range circuit.MatrixA { + cell := circuit.MatrixA[i] + ansA = api.Add(ansA, api.Mul(cell.value, api.Mul(rowEval[cell.row], foldedColEval[cell.column%maskSize]))) } + for i := range circuit.MatrixB { + cell := circuit.MatrixB[i] + ansB = api.Add(ansB, api.Mul(cell.value, api.Mul(rowEval[cell.row], foldedColEval[cell.column%maskSize]))) + } + for i := range circuit.MatrixC { + cell := circuit.MatrixC[i] + ansC = api.Add(ansC, api.Mul(cell.value, api.Mul(rowEval[cell.row], foldedColEval[cell.column%maskSize]))) + } + + return []frontend.Variable{ansA, ansB, ansC} +} + +// evaluateR1CSMatrixExtensionSplit evaluates the R1CS matrix MLEs at (rowRand, colRand) +// but splits contributions by column: columns < w1Size contribute to the first set +// [A1, B1, C1], columns >= w1Size contribute to the second set [A2, B2, C2] with +// column indices shifted by w1Size. +// Returns [A1, B1, C1, A2, B2, C2]. +func evaluateR1CSMatrixExtensionSplit( + api frontend.API, + circuit *Circuit, + rowRand []frontend.Variable, + colRand1 []frontend.Variable, + colRand2 []frontend.Variable, + w1Size int, +) ([]frontend.Variable, []frontend.Variable) { + rowEval := calculateEQOverBooleanHypercube(api, rowRand) + + eval1 := colRand1 != nil + eval2 := colRand2 != nil + + var colEval1, colEval2 []frontend.Variable + if eval1 { + colEval1 = calculateEQOverBooleanHypercube(api, colRand1) + } + if eval2 { + colEval2 = calculateEQOverBooleanHypercube(api, colRand2) + } + + ans1 := []frontend.Variable{frontend.Variable(0), frontend.Variable(0), frontend.Variable(0)} + ans2 := []frontend.Variable{frontend.Variable(0), frontend.Variable(0), frontend.Variable(0)} for i := range circuit.MatrixA { - col := circuit.MatrixA[i].column - row := circuit.MatrixA[i].row - val := circuit.MatrixA[i].value + cell := circuit.MatrixA[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[0] = api.Add(ans1[0], api.Mul(contrib, colEval1[cell.column])) + } + } else { + if eval2 { + ans2[0] = api.Add(ans2[0], api.Mul(contrib, colEval2[cell.column-w1Size])) + } + } + } + + for i := range circuit.MatrixB { + cell := circuit.MatrixB[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[1] = api.Add(ans1[1], api.Mul(contrib, colEval1[cell.column])) + } + } else { + if eval2 { + ans2[1] = api.Add(ans2[1], api.Mul(contrib, colEval2[cell.column-w1Size])) + } + } + } - if col < w1Size { - ans[0] = api.Add(ans[0], api.Mul(val, api.Mul(rowEval[row], colEval[col]))) + for i := range circuit.MatrixC { + cell := circuit.MatrixC[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[2] = api.Add(ans1[2], api.Mul(contrib, colEval1[cell.column])) + } } else { - ans[3] = api.Add(ans[3], api.Mul(val, api.Mul(rowEval[row], colEval[col-w1Size]))) + if eval2 { + ans2[2] = api.Add(ans2[2], api.Mul(contrib, colEval2[cell.column-w1Size])) + } + } + } + + return ans1, ans2 +} + +// evaluateFoldedR1CSMatrixExtensionSplit computes the folded R1CS weight MLE +// split by column, analogous to evaluateR1CSMatrixExtensionSplit but with +// column indices taken modulo maskSize. +func evaluateFoldedR1CSMatrixExtensionSplit( + api frontend.API, + circuit *Circuit, + rowRand []frontend.Variable, + foldedPoint1 []frontend.Variable, + foldedPoint2 []frontend.Variable, + maskSize int, + w1Size int, +) ([]frontend.Variable, []frontend.Variable) { + rowEval := calculateEQOverBooleanHypercube(api, rowRand) + + eval1 := foldedPoint1 != nil + eval2 := foldedPoint2 != nil + + var foldedColEval1, foldedColEval2 []frontend.Variable + if eval1 { + foldedColEval1 = calculateEQOverBooleanHypercube(api, foldedPoint1) + } + if eval2 { + foldedColEval2 = calculateEQOverBooleanHypercube(api, foldedPoint2) + } + + ans1 := []frontend.Variable{frontend.Variable(0), frontend.Variable(0), frontend.Variable(0)} + ans2 := []frontend.Variable{frontend.Variable(0), frontend.Variable(0), frontend.Variable(0)} + + for i := range circuit.MatrixA { + cell := circuit.MatrixA[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[0] = api.Add(ans1[0], api.Mul(contrib, foldedColEval1[cell.column%maskSize])) + } + } else { + if eval2 { + ans2[0] = api.Add(ans2[0], api.Mul(contrib, foldedColEval2[(cell.column-w1Size)%maskSize])) + } } } for i := range circuit.MatrixB { - col := circuit.MatrixB[i].column - if col < w1Size { - ans[1] = api.Add(ans[1], api.Mul(circuit.MatrixB[i].value, api.Mul(rowEval[circuit.MatrixB[i].row], colEval[col]))) + cell := circuit.MatrixB[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[1] = api.Add(ans1[1], api.Mul(contrib, foldedColEval1[cell.column%maskSize])) + } } else { - ans[4] = api.Add(ans[4], api.Mul(circuit.MatrixB[i].value, api.Mul(rowEval[circuit.MatrixB[i].row], colEval[col-w1Size]))) + if eval2 { + ans2[1] = api.Add(ans2[1], api.Mul(contrib, foldedColEval2[(cell.column-w1Size)%maskSize])) + } } } for i := range circuit.MatrixC { - col := circuit.MatrixC[i].column - if col < w1Size { - ans[2] = api.Add(ans[2], api.Mul(circuit.MatrixC[i].value, api.Mul(rowEval[circuit.MatrixC[i].row], colEval[col]))) + cell := circuit.MatrixC[i] + contrib := api.Mul(cell.value, rowEval[cell.row]) + if cell.column < w1Size { + if eval1 { + ans1[2] = api.Add(ans1[2], api.Mul(contrib, foldedColEval1[cell.column%maskSize])) + } } else { - ans[5] = api.Add(ans[5], api.Mul(circuit.MatrixC[i].value, api.Mul(rowEval[circuit.MatrixC[i].row], colEval[col-w1Size]))) + if eval2 { + ans2[2] = api.Add(ans2[2], api.Mul(contrib, foldedColEval2[(cell.column-w1Size)%maskSize])) + } } } - return ans + return ans1, ans2 } func calculateEQOverBooleanHypercube(api frontend.API, r []frontend.Variable) []frontend.Variable { @@ -180,9 +377,42 @@ func calculateEQOverBooleanHypercube(api frontend.API, r []frontend.Variable) [] return ans } +// blindingCovectorMLE computes the MLE evaluation of the blinding OffsetCovector +// at the given evaluation point. This mirrors Rust's OffsetCovector::mle_evaluate +// for the blinding polynomial weights expand_powers::<4>(alpha) at offset w1Size. +// +// weights[4*j + k] = alpha[j]^k for k in 0..4, placed at domain index w1Size + 4*j + k. +func blindingCovectorMLE(api frontend.API, alpha []frontend.Variable, w1Size int, evaluationPoint []frontend.Variable) frontend.Variable { + n := len(evaluationPoint) + result := frontend.Variable(0) + + // expand_powers::<4>(alpha) produces [1, α₀, α₀², α₀³, 1, α₁, α₁², α₁³, ...] + for j := range alpha { + alphaPow := frontend.Variable(1) // alpha[j]^0 + for k := 0; k < 4; k++ { + idx := w1Size + 4*j + k + + // Compute Lagrange basis: Π_bit point[bit] or (1-point[bit]) + basis := frontend.Variable(1) + for b := 0; b < n; b++ { + if (idx>>(n-1-b))&1 == 1 { + basis = api.Mul(basis, evaluationPoint[b]) + } else { + basis = api.Mul(basis, api.Sub(1, evaluationPoint[b])) + } + } + + result = api.Add(result, api.Mul(alphaPow, basis)) + alphaPow = api.Mul(alphaPow, alpha[j]) + } + } + return result +} + // geometricTill evaluates the multilinear extension of the geometric vector // [1, x, x^2, ..., x^{n-1}, 0, ..., 0] at point foldingRandomness. -// This is O(k) constraints where k = len(foldingRandomness) +// This is O(k) constraints where k = len(foldingRandomness). +// Used for the public weight MLE evaluation (make_public_weight in Rust). func geometricTill(api frontend.API, x frontend.Variable, n int, foldingRandomness []frontend.Variable) frontend.Variable { k := len(foldingRandomness) if n <= 0 || n > (1<(alpha) at offset w1Size: +// +// weights[4*j + k] = alpha[j]^k at domain index w1Size + 4*j + k +// +// Folded: result = Σ_{j,k} alpha[j]^k * L_{(w1Size + 4*j + k) mod maskSize}(evalPoint) +// challengeWeightMLE computes the MLE evaluation of the challenge weight covector. +// The challenge weight is a sparse covector with entries (offset_i, x^i). +// MLE at r = Σ_i x^i * eq(offset_i, r) +func challengeWeightMLE( + api frontend.API, + x frontend.Variable, + challengeOffsets []int, + evaluationPoint []frontend.Variable, +) frontend.Variable { + lagrange := calculateEQOverBooleanHypercube(api, evaluationPoint) + result := frontend.Variable(0) + xPow := frontend.Variable(1) + for _, offset := range challengeOffsets { + result = api.Add(result, api.Mul(xPow, lagrange[offset])) + xPow = api.Mul(xPow, x) + } + return result +} + +// foldedChallengeWeightMLE computes the MLE of the challenge weight covector +// folded modulo maskSize, evaluated at evalPoint. +func foldedChallengeWeightMLE( + api frontend.API, + x frontend.Variable, + challengeOffsets []int, + evalPoint []frontend.Variable, + maskSize int, +) frontend.Variable { + lagrange := calculateEQOverBooleanHypercube(api, evalPoint) + result := frontend.Variable(0) + xPow := frontend.Variable(1) + for _, offset := range challengeOffsets { + idx := offset % maskSize + result = api.Add(result, api.Mul(xPow, lagrange[idx])) + xPow = api.Mul(xPow, x) + } + return result +} + +func foldedBlindingCovectorMLE( + api frontend.API, + alpha []frontend.Variable, + w1Size int, + evalPoint []frontend.Variable, + maskSize int, +) frontend.Variable { + lagrange := calculateEQOverBooleanHypercube(api, evalPoint) + + result := frontend.Variable(0) + for j := range alpha { + alphaPow := frontend.Variable(1) + for k := 0; k < 4; k++ { + idx := (w1Size + 4*j + k) % maskSize + result = api.Add(result, api.Mul(alphaPow, lagrange[idx])) + alphaPow = api.Mul(alphaPow, alpha[j]) + } + } + return result +} diff --git a/recursive-verifier/app/circuit/mt.go b/recursive-verifier/app/circuit/mt.go deleted file mode 100644 index cc30a9efd..000000000 --- a/recursive-verifier/app/circuit/mt.go +++ /dev/null @@ -1,91 +0,0 @@ -package circuit - -import ( - "reilabs/whir-verifier-circuit/app/typeConverters" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" -) - -func newMerkle( - hint Hint, - isContainer bool, -) Merkle { - totalAuthPath := make([][][]frontend.Variable, len(hint.merklePaths)) - totalLeaves := make([][][]frontend.Variable, len(hint.merklePaths)) - totalLeafSiblingHashes := make([][]frontend.Variable, len(hint.merklePaths)) - totalLeafIndexes := make([][]uints.U64, len(hint.merklePaths)) - - for i, merkle_path := range hint.merklePaths { - numOfLeavesProved := len(merkle_path.Proofs) - treeHeight := len(merkle_path.Proofs[0].AuthPath) - - totalAuthPath[i] = make([][]frontend.Variable, numOfLeavesProved) - totalLeaves[i] = make([][]frontend.Variable, numOfLeavesProved) - totalLeafSiblingHashes[i] = make([]frontend.Variable, numOfLeavesProved) - - for j := range numOfLeavesProved { - totalAuthPath[i][j] = make([]frontend.Variable, treeHeight) - totalLeaves[i][j] = make([]frontend.Variable, len(hint.stirAnswers[i][j])) - } - - totalLeafIndexes[i] = make([]uints.U64, numOfLeavesProved) - - if !isContainer { - for j := range numOfLeavesProved { - proof := merkle_path.Proofs[j] - - for z := range treeHeight { - totalAuthPath[i][j][z] = typeConverters. - LittleEndianUint8ToBigInt(proof.AuthPath[treeHeight-1-z].KeccakDigest[:]) - } - - totalLeafSiblingHashes[i][j] = typeConverters. - LittleEndianUint8ToBigInt(proof.LeafSiblingHash.KeccakDigest[:]) - totalLeafIndexes[i][j] = uints.NewU64(proof.LeafIndex) - - for k := range hint.stirAnswers[i][j] { - input := hint.stirAnswers[i][j][k] - totalLeaves[i][j][k] = typeConverters.LimbsToBigIntMod(input.Limbs) - } - } - } - } - - return Merkle{ - Leaves: totalLeaves, - LeafIndexes: totalLeafIndexes, - LeafSiblingHashes: totalLeafSiblingHashes, - AuthPaths: totalAuthPath, - } -} - -func oodAnswers( - api frontend.API, - answers [][]frontend.Variable, - randomness frontend.Variable, -) (result []frontend.Variable) { - if len(answers) == 0 { - return nil - } - - multiplier := frontend.Variable(1) - - first := answers[0] - result = make([]frontend.Variable, len(first)) - for j := range first { - result[j] = api.Mul(first[j], multiplier) - } - - for i := 1; i < len(answers); i++ { - multiplier = api.Mul(multiplier, randomness) - - round := answers[i] - for j := range round { - term := api.Mul(round[j], multiplier) - result[j] = api.Add(result[j], term) - } - } - - return result -} diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go deleted file mode 100644 index 5f44e6507..000000000 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ /dev/null @@ -1,140 +0,0 @@ -package circuit - -import ( - "reilabs/whir-verifier-circuit/app/utilities" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" - gnarkNimue "github.com/reilabs/gnark-nimue" - skyscraper "github.com/reilabs/gnark-skyscraper" -) - -func initialSumcheck( - api frontend.API, - arthur gnarkNimue.Arthur, - batchingRandomness frontend.Variable, - initialOODQueries []frontend.Variable, - initialOODAnswers []frontend.Variable, - whirParams WHIRParams, - linearStatementEvaluations [][]frontend.Variable, -) (InitialSumcheckData, frontend.Variable, []frontend.Variable, error) { - - initialCombinationRandomness, err := GenerateCombinationRandomness(api, arthur, len(initialOODAnswers)+len(linearStatementEvaluations[0])) - if err != nil { - return InitialSumcheckData{}, nil, nil, err - } - - combinedLinearStatementEvaluations := make([]frontend.Variable, len(linearStatementEvaluations[0])) //[0, 1, 2] - for evaluationIndex := range len(linearStatementEvaluations[0]) { - sum := frontend.Variable(0) - multiplier := frontend.Variable(1) - for j := range len(linearStatementEvaluations) { - sum = api.Add(sum, api.Mul(linearStatementEvaluations[j][evaluationIndex], multiplier)) - multiplier = api.Mul(multiplier, batchingRandomness) - } - combinedLinearStatementEvaluations[evaluationIndex] = sum - } - OODAnswersAndStatementEvaluations := append(initialOODAnswers, combinedLinearStatementEvaluations...) - lastEval := utilities.DotProduct(api, initialCombinationRandomness, OODAnswersAndStatementEvaluations) - - initialSumcheckFoldingRandomness, lastEval, err := runWhirSumcheckRounds(api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) - if err != nil { - return InitialSumcheckData{}, nil, nil, err - } - - return InitialSumcheckData{ - InitialOODQueries: initialOODQueries, - InitialCombinationRandomness: initialCombinationRandomness, - }, lastEval, initialSumcheckFoldingRandomness, nil -} - -func parseBatchedCommitment(arthur gnarkNimue.Arthur, whir_params WHIRParams) (frontend.Variable, frontend.Variable, []frontend.Variable, [][]frontend.Variable, error) { - rootHash := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(rootHash); err != nil { - return nil, nil, nil, [][]frontend.Variable{}, err - } - oodPoints := make([]frontend.Variable, 1) - oodAnswers := make([][]frontend.Variable, whir_params.BatchSize) - - if err := arthur.FillChallengeScalars(oodPoints); err != nil { - return nil, nil, nil, nil, err - } - for i := range whir_params.BatchSize { - oodAnswer := make([]frontend.Variable, 1) - - if err := arthur.FillNextScalars(oodAnswer); err != nil { - return nil, nil, nil, nil, err - } - oodAnswers[i] = oodAnswer - } - - batchingRandomness := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(batchingRandomness); err != nil { - return nil, 0, nil, nil, err - } - return rootHash[0], batchingRandomness[0], oodPoints, oodAnswers, nil -} - -func generateFinalCoefficientsAndRandomnessPoints(api frontend.API, arthur gnarkNimue.Arthur, whir_params WHIRParams, circuit Merkle, uapi *uints.BinaryField[uints.U64], sc *skyscraper.Skyscraper, domainSize int, expDomainGenerator frontend.Variable) ([]frontend.Variable, []frontend.Variable, error) { - finalCoefficients := make([]frontend.Variable, 1< with byte-level position +// tracking. State is 64 bytes (2 field elements), rate R = 32 bytes. +// --------------------------------------------------------------------------- + +const spongeRate = 32 + +type NativeSponge struct { + state [64]byte + absorbPos int // 0..spongeRate (byte-level) + squeezePos int // 0..spongeRate (byte-level) +} + +func newNativeSponge() *NativeSponge { + return &NativeSponge{ + squeezePos: spongeRate, + } +} + +func (s *NativeSponge) permute() { + left := leBytesToNativeBigInt(s.state[:32]) + right := leBytesToNativeBigInt(s.state[32:]) + st := [2]*big.Int{left, right} + nativePermuteV2(&st) + lBytes := nativeBigIntToLeBytes(st[0]) + rBytes := nativeBigIntToLeBytes(st[1]) + copy(s.state[:32], lBytes[:]) + copy(s.state[32:], rBytes[:]) +} + +// Absorb writes input bytes into the rate portion of the state, permuting +// when the rate is full. Matches DuplexSponge::absorb exactly. +func (s *NativeSponge) Absorb(input []byte) { + s.squeezePos = spongeRate + + for len(input) > 0 { + if s.absorbPos == spongeRate { + s.permute() + s.absorbPos = 0 + } + chunkLen := min(len(input), spongeRate-s.absorbPos) + copy(s.state[s.absorbPos:s.absorbPos+chunkLen], input[:chunkLen]) + s.absorbPos += chunkLen + input = input[chunkLen:] + } +} + +// Squeeze reads output bytes from the rate portion of the state, permuting +// when the rate is exhausted. Matches DuplexSponge::squeeze exactly. +func (s *NativeSponge) Squeeze(output []byte) { + if len(output) == 0 { + return + } + s.absorbPos = 0 + + if s.squeezePos == spongeRate { + s.squeezePos = 0 + s.permute() + } + chunkLen := min(len(output), spongeRate-s.squeezePos) + copy(output[:chunkLen], s.state[s.squeezePos:s.squeezePos+chunkLen]) + s.squeezePos += chunkLen + s.Squeeze(output[chunkLen:]) +} + +// nativeCompress computes the Skyscraper compression function: +// permute(l, r), then add back the initial l (Davies-Meyer feed-forward). +// This matches skyscraper::reference::compress on the Rust side. +func nativeCompress(l, r *big.Int) *big.Int { + t := new(big.Int).Set(l) + state := [2]*big.Int{new(big.Int).Set(l), new(big.Int).Set(r)} + nativePermuteV2(&state) + result := new(big.Int).Add(state[0], t) + result.Mod(result, bn254Modulus) + return result +} + +// nativePublicInputsHashBytes computes the public inputs hash as 32 LE bytes, +// matching PublicInputs::hash_bytes() on the Rust side. +func nativePublicInputsHashBytes(publicInputs []*big.Int) [32]byte { + var hash *big.Int + switch len(publicInputs) { + case 0: + hash = big.NewInt(0) + case 1: + hash = nativeCompress(publicInputs[0], big.NewInt(0)) + default: + hash = new(big.Int).Set(publicInputs[0]) + for i := 1; i < len(publicInputs); i++ { + hash = nativeCompress(hash, publicInputs[i]) + } + } + return nativeBigIntToLeBytes(hash) +} + +// InitFromProtocolID initializes the sponge by absorbing the 64-byte protocol_id, +// the 32-byte session_id, and the 32-byte instance as raw bytes. This matches +// spongefish's DomainSeparator initialization which absorbs raw bytes via public_message. +func (s *NativeSponge) InitFromProtocolID(protocolID [64]byte, sessionID []byte, instance [32]byte) { + s.state = [64]byte{} + s.absorbPos = 0 + s.squeezePos = spongeRate + + // Absorb protocol ID as raw bytes (64 bytes) + s.Absorb(protocolID[:]) + // Absorb session ID as raw bytes (32 bytes, zero-padded if needed) + var sessionBuf [32]byte + if len(sessionID) >= 32 { + copy(sessionBuf[:], sessionID[:32]) + } + s.Absorb(sessionBuf[:]) + // Absorb instance as raw bytes (32 bytes) + s.Absorb(instance[:]) +} + +// AbsorbFr absorbs a field element as 32 LE bytes. +func (s *NativeSponge) AbsorbFr(val *big.Int) { + leBytes := nativeBigIntToLeBytes(val) + s.Absorb(leBytes[:]) +} + +// SqueezeFr squeezes 32 bytes and interprets them as a LE field element. +func (s *NativeSponge) SqueezeFr() *big.Int { + var buf [32]byte + s.Squeeze(buf[:]) + return leBytesToNativeBigInt(buf[:]) +} + +// leBytesToBigIntUnreduced interprets b as a little-endian integer without reducing mod p. +func leBytesToBigIntUnreduced(b []byte) *big.Int { + val := new(big.Int) + for i := len(b) - 1; i >= 0; i-- { + val.Lsh(val, 8) + val.Or(val, big.NewInt(int64(b[i]))) + } + return val +} + +func leBytesToNativeBigInt(b []byte) *big.Int { + val := leBytesToBigIntUnreduced(b) + val.Mod(val, bn254Modulus) + return val +} + +func nativeBigIntToLeBytes(v *big.Int) [32]byte { + var buf [32]byte + vv := new(big.Int).Set(v) + vv.Mod(vv, bn254Modulus) + for i := 0; i < 32; i++ { + var m big.Int + vv.DivMod(vv, big.NewInt(256), &m) + buf[i] = byte(m.Int64()) + } + return buf +} + +// --------------------------------------------------------------------------- +// NativeNimue: native transcript reader mirroring the in-circuit Nimue. +// Reads scalars from nargString (prover messages), squeezes challenges +// from the sponge, and reads hints from a separate buffer. +// --------------------------------------------------------------------------- + +type NativeNimue struct { + sponge *NativeSponge + nargString []byte + hints *bytes.Reader +} + +func NewNativeNimue(protocolID [64]byte, sessionID []byte, instance [32]byte, nargString []byte, hints []byte) *NativeNimue { + sponge := newNativeSponge() + sponge.InitFromProtocolID(protocolID, sessionID, instance) + return &NativeNimue{ + sponge: sponge, + nargString: nargString, + hints: bytes.NewReader(hints), + } +} + +// FillNextScalars reads n field elements (32 bytes each, LE) from the +// transcript and absorbs them into the sponge. +func (a *NativeNimue) FillNextScalars(n int) ([]*big.Int, error) { + out := make([]*big.Int, n) + for i := range n { + if len(a.nargString) < 32 { + return nil, fmt.Errorf("FillNextScalars: need 32 bytes, have %d", len(a.nargString)) + } + out[i] = leBytesToNativeBigInt(a.nargString[:32]) + out[i].Mod(out[i], bn254Modulus) + a.nargString = a.nargString[32:] + } + for _, v := range out { + a.sponge.AbsorbFr(v) + } + return out, nil +} + +// FillChallengeScalars squeezes n field elements from the sponge. +// Each challenge requires 64 bytes to match spongefish's DecodingFieldBuffer +// which uses (MODULUS_BIT_SIZE.div_ceil(8) + 32) = 64 bytes per field element +// for statistical uniformity, then reduces mod p once over the full 64-byte LE integer. +func (a *NativeNimue) FillChallengeScalars(n int) ([]*big.Int, error) { + out := make([]*big.Int, n) + for i := range n { + // Squeeze 64 raw bytes and interpret as a single LE integer, then reduce mod p. + var buf [64]byte + a.sponge.Squeeze(buf[:]) + out[i] = leBytesToBigIntUnreduced(buf[:]) + out[i].Mod(out[i], bn254Modulus) + } + return out, nil +} + +// FillNextBytes reads n bytes from the transcript and absorbs them as raw +// bytes into the sponge rate block. Partial writes leave the remaining rate +// bytes unchanged, matching Rust spongefish's EncodingByteBuffer behavior. +func (a *NativeNimue) FillNextBytes(n int) ([]byte, error) { + if len(a.nargString) < n { + return nil, fmt.Errorf("FillNextBytes: need %d bytes, have %d", n, len(a.nargString)) + } + raw := make([]byte, n) + copy(raw, a.nargString[:n]) + a.nargString = a.nargString[n:] + a.sponge.Absorb(raw) + return raw, nil +} + +// FillChallengeBytes squeezes n bytes directly from the sponge. +// Uses byte-level squeeze tracking, matching Rust DuplexSponge exactly. +func (a *NativeNimue) FillChallengeBytes(n int) ([]byte, error) { + out := make([]byte, n) + a.sponge.Squeeze(out) + return out, nil +} + +// ProverHint reads exactly n raw bytes from the hints buffer (NargDeserialize). +func (a *NativeNimue) ProverHint(n int) ([]byte, error) { + buf := make([]byte, n) + _, err := io.ReadFull(a.hints, buf) + if err != nil { + return nil, fmt.Errorf("ProverHint: %w", err) + } + return buf, nil +} + +// ProverHintArk reads an Arkworks compressed-serialized value from the hints buffer. +func (a *NativeNimue) ProverHintArk(target interface{}) error { + _, err := arkSerialize.CanonicalDeserializeWithMode(a.hints, target, false, false) + if err != nil { + return fmt.Errorf("ProverHintArk: %w", err) + } + return nil +} + +// --------------------------------------------------------------------------- +// Native challenge index derivation (mirrors Rust challenge_indices) +// --------------------------------------------------------------------------- + +func nativeGetStirChallenges( + nimue *NativeNimue, + numLeaves int, + count int, + deduplicate bool, +) ([]int, error) { + if count == 0 { + return []int{}, nil + } + if numLeaves == 1 { + if deduplicate { + return []int{0}, nil + } + return make([]int, count), nil + } + + sizeBytes := (bits.Len(uint(numLeaves)) - 1 + 7) / 8 + + entropy, err := nimue.FillChallengeBytes(count * sizeBytes) + if err != nil { + return nil, err + } + + indices := make([]int, count) + for i := range count { + chunk := entropy[i*sizeBytes : (i+1)*sizeBytes] + value := 0 + for _, b := range chunk { + value = (value << 8) | int(b) + } + indices[i] = value % numLeaves + } + + if deduplicate { + sort.Ints(indices) + indices = dedup(indices) + } + + return indices, nil +} + +// --------------------------------------------------------------------------- +// Merkle tree hint consumption +// Reads sibling hashes from the hints buffer following the same traversal +// order as whir's merkle_tree::verify. +// --------------------------------------------------------------------------- + +// countMerkleHints determines the number of 32-byte sibling hashes in the +// hints buffer for a Merkle multi-opening at the given leaf indices. +// It also returns the FullMultiPath reconstructed from the hints. +func consumeMerkleHints(nimue *NativeNimue, indices []int, treeHeight int) (FullMultiPath[Digest], error) { + if len(indices) == 0 { + return FullMultiPath[Digest]{}, nil + } + + sorted := make([]int, len(indices)) + copy(sorted, indices) + sort.Ints(sorted) + sorted = dedup(sorted) + + proofs := make(map[int]*Path[Digest]) + for _, idx := range sorted { + proofs[idx] = &Path[Digest]{ + LeafIndex: uint64(idx), + AuthPath: make([]Digest, 0, treeHeight), + } + } + + currentIndices := sorted + for level := 0; level < treeHeight; level++ { + var nextIndices []int + i := 0 + for i < len(currentIndices) { + a := currentIndices[i] + if i+1 < len(currentIndices) && currentIndices[i+1] == a^1 { + // Sibling pair in the query set — no hint needed + nextIndices = append(nextIndices, a>>1) + i += 2 + } else { + // Need sibling hash from hints + siblingHash, err := nimue.ProverHint(32) + if err != nil { + return FullMultiPath[Digest]{}, fmt.Errorf("merkle level %d, index %d: %w", level, a, err) + } + var digest Digest + copy(digest.Digest[:], siblingHash) + + sibling := a ^ 1 + if level == 0 { + for _, idx := range sorted { + switch idx { + case a: + proofs[idx].LeafSiblingHash = digest + case sibling: + proofs[idx].LeafSiblingHash = digest + } + } + } + + // Store sibling hash in auth path for all original indices + // that trace through this node + for _, origIdx := range sorted { + ancestorIdx := origIdx >> uint(level) + if ancestorIdx == a { + if level > 0 { + proofs[origIdx].AuthPath = append(proofs[origIdx].AuthPath, digest) + } + } + } + + nextIndices = append(nextIndices, a>>1) + i++ + } + } + sort.Ints(nextIndices) + nextIndices = dedup(nextIndices) + currentIndices = nextIndices + } + + // Build the FullMultiPath from collected proofs (in original index order) + paths := make([]Path[Digest], 0, len(sorted)) + for _, idx := range sorted { + paths = append(paths, *proofs[idx]) + } + return FullMultiPath[Digest]{Proofs: paths}, nil +} + +func dedup(sorted []int) []int { + if len(sorted) <= 1 { + return sorted + } + result := sorted[:1] + for _, v := range sorted[1:] { + if v != result[len(result)-1] { + result = append(result, v) + } + } + return result +} diff --git a/recursive-verifier/app/circuit/native_whir_verify.go b/recursive-verifier/app/circuit/native_whir_verify.go new file mode 100644 index 000000000..a84ca1159 --- /dev/null +++ b/recursive-verifier/app/circuit/native_whir_verify.go @@ -0,0 +1,1000 @@ +package circuit + +import ( + "fmt" + "math" + "math/big" + "math/bits" + "sort" + + "github.com/consensys/gnark/frontend" + + "reilabs/whir-verifier-circuit/app/typeConverters" + "reilabs/whir-verifier-circuit/app/whir" +) + +// --------------------------------------------------------------------------- +// Native types mirroring Rust whir types for verification +// --------------------------------------------------------------------------- + +// NativeEvaluations mirrors Rust Evaluations: OOD evaluation points and +// the matrix of evaluations (row-major, one row per OOD point). +type NativeEvaluations struct { + Points []*big.Int // OOD evaluation points + Matrix []*big.Int // flattened row-major: [point0_col0, point0_col1, ..., point1_col0, ...] +} + +func (e *NativeEvaluations) NumPoints() int { + return len(e.Points) +} + +func (e *NativeEvaluations) NumColumns() int { + np := e.NumPoints() + if np == 0 { + return 0 + } + return len(e.Matrix) / np +} + +func (e *NativeEvaluations) Rows() [][]*big.Int { + cols := e.NumColumns() + rows := make([][]*big.Int, e.NumPoints()) + for i := range rows { + rows[i] = e.Matrix[i*cols : (i+1)*cols] + } + return rows +} + +// NativeCommitment mirrors Rust irs_commit::Commitment. +type NativeCommitment struct { + OutOfDomain NativeEvaluations +} + +func (c *NativeCommitment) NumVectors() int { + return c.OutOfDomain.NumColumns() +} + +// NativeFinalClaim mirrors Rust FinalClaim. +type NativeFinalClaim struct { + EvaluationPoint []*big.Int + RLCCoefficients []*big.Int + LinearFormRLC *big.Int +} + +// NativeRoundConstraint holds the RLC coefficients and the OOD/in-domain +// evaluator points for one round's constraints. +type NativeRoundConstraint struct { + RLCCoeffs []*big.Int // random linear combination coefficients + EvaluatorInfos []evaluatorInfo // one per constraint (OOD point + size) +} + +// evaluatorInfo stores the data needed to compute mle_evaluate for a +// UnivariateEvaluation linear form (OOD or in-domain evaluator). +type evaluatorInfo struct { + point *big.Int // evaluation point (OOD point or domain point) + size int // polynomial size (= 2^num_variables for that round) +} + +// --------------------------------------------------------------------------- +// Fp256 conversion helpers +// --------------------------------------------------------------------------- + +// fp256ToBigInt converts an Fp256 (4 x uint64 limbs, little-endian) to *big.Int. +func fp256ToBigInt(f Fp256) *big.Int { + r := new(big.Int) + for i := 3; i >= 0; i-- { + r.Lsh(r, 64) + r.Or(r, new(big.Int).SetUint64(f.Limbs[i])) + } + r.Mod(r, bn254Modulus) + return r +} + +// fp256SliceToBigInt converts a slice of Fp256 to []*big.Int. +func fp256SliceToBigInt(fs []Fp256) []*big.Int { + result := make([]*big.Int, len(fs)) + for i, f := range fs { + result[i] = fp256ToBigInt(f) + } + return result +} + +// --------------------------------------------------------------------------- +// Field arithmetic helpers (all mod BN254) +// --------------------------------------------------------------------------- + +func frAdd(a, b *big.Int) *big.Int { + r := new(big.Int).Add(a, b) + r.Mod(r, bn254Modulus) + return r +} + +func frSub(a, b *big.Int) *big.Int { + r := new(big.Int).Sub(a, b) + r.Mod(r, bn254Modulus) + if r.Sign() < 0 { + r.Add(r, bn254Modulus) + } + return r +} + +func frMul(a, b *big.Int) *big.Int { + r := new(big.Int).Mul(a, b) + r.Mod(r, bn254Modulus) + return r +} + +func frInv(a *big.Int) *big.Int { + r := new(big.Int).ModInverse(a, bn254Modulus) + return r +} + +func frDiv(a, b *big.Int) *big.Int { + return frMul(a, frInv(b)) +} + +// --------------------------------------------------------------------------- +// Algebraic helpers mirroring Rust whir algebra +// --------------------------------------------------------------------------- + +// nativeGeometricSequence returns [1, x, x^2, ..., x^(count-1)] mod p. +func nativeGeometricSequence(x *big.Int, count int) []*big.Int { + result := make([]*big.Int, count) + result[0] = big.NewInt(1) + for i := 1; i < count; i++ { + result[i] = frMul(result[i-1], x) + } + return result +} + +// nativeGeometricChallenge mirrors Rust geometric_challenge: squeeze a single +// challenge and expand to a geometric sequence of the given length. +func nativeGeometricChallenge(nimue *NativeNimue, count int) ([]*big.Int, error) { + switch count { + case 0: + return []*big.Int{}, nil + case 1: + return []*big.Int{big.NewInt(1)}, nil + default: + x, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, err + } + return nativeGeometricSequence(x[0], count), nil + } +} + +// nativeDotBigInt computes the dot product of two big.Int slices mod p. +func nativeDotBigInt(a, b []*big.Int) *big.Int { + result := big.NewInt(0) + for i := range a { + result = frAdd(result, frMul(a[i], b[i])) + } + return result +} + +// nativeTensorProduct computes the tensor (Kronecker) product of two vectors. +func nativeTensorProduct(a, b []*big.Int) []*big.Int { + result := make([]*big.Int, len(a)*len(b)) + for i, x := range a { + for j, y := range b { + result[i*len(b)+j] = frMul(x, y) + } + } + return result +} + +// nativeEqWeights computes eq polynomial weights for a multilinear point. +// Returns a vector of size 2^n where result[i] = eq(point, binary(i)). +// Matches the circuit's calculateEQOverBooleanHypercube: iterates in reverse +// so that point[0] controls bit 0 (LSB) of the index. +func nativeEqWeights(point []*big.Int) []*big.Int { + result := []*big.Int{big.NewInt(1)} + for i := len(point) - 1; i >= 0; i-- { + x := point[i] + oneMinusX := frSub(big.NewInt(1), x) + length := len(result) + left := make([]*big.Int, length) + right := make([]*big.Int, length) + for j := 0; j < length; j++ { + left[j] = frMul(result[j], oneMinusX) + right[j] = frMul(result[j], x) + } + result = append(left, right...) + } + return result +} + +// nativeMultilinearEval evaluates the multilinear extension of `values` at `point`. +// This is: Σ_i values[i] * eq(i, point) +func nativeMultilinearEval(point []*big.Int, values []*big.Int) *big.Int { + eqW := nativeEqWeights(point) + return nativeDotBigInt(eqW, values) +} + +// nativeUnivariateEvalMLE computes UnivariateEvaluation{point, size}.mle_evaluate(mlPoint). +// +// This is the MLE of the linear form that evaluates a polynomial at `point`, +// given its values on a domain of the given `size`. Specifically: +// +// mle(x_1,...,x_n) = Π_{i=0}^{n-1} ((1 - x_i) + x_i * point^(2^i)) +// +// where n = log2(size) and the x_i are taken from mlPoint. +func nativeUnivariateEvalMLE(point *big.Int, size int, mlPoint []*big.Int) *big.Int { + n := len(mlPoint) + _ = size // size is implicit: 2^n + result := big.NewInt(1) + // power tracks point^(2^i) + power := new(big.Int).Set(point) + // Iterate in reverse to match Rust point.iter().rev() and + // circuit UnivarMleEvaluate which iterates i = n-1..0. + for i := n - 1; i >= 0; i-- { + r := mlPoint[i] + oneMinusR := frSub(big.NewInt(1), r) + factor := frAdd(oneMinusR, frMul(r, power)) + result = frMul(result, factor) + power = frMul(power, power) + } + return result +} + +// --------------------------------------------------------------------------- +// Native WHIR sumcheck verification (quadratic polynomial, 2 coefficients) +// --------------------------------------------------------------------------- + +// nativeWhirSumcheckVerify runs the WHIR-style sumcheck verification. +// Each round reads 2 monomial coefficients (c0, c2) of the quadratic polynomial +// p(x) = c0 + c1*x + c2*x^2. The linear coefficient c1 is derived from the +// sumcheck constraint p(0) + p(1) = sum, giving c1 = sum - 2*c0 - c2. +// After squeezing a folding randomness challenge r, the sum is updated to p(r). +// Returns the folding randomness points and the final sum. +func nativeWhirSumcheckVerify( + nimue *NativeNimue, + sum *big.Int, + numRounds int, + powThreshold uint64, +) ([]*big.Int, *big.Int, error) { + foldingRandomness := make([]*big.Int, numRounds) + currentSum := new(big.Int).Set(sum) + + for i := 0; i < numRounds; i++ { + // Read c0 and c2 (monomial coefficients: constant and quadratic) + evals, err := nimue.FillNextScalars(2) + if err != nil { + return nil, nil, fmt.Errorf("sumcheck round %d coeffs: %w", i, err) + } + c0 := evals[0] + c2 := evals[1] + // Derive c1 from p(0) + p(1) = sum: + // p(0) = c0, p(1) = c0 + c1 + c2, so 2*c0 + c1 + c2 = sum + c1 := frSub(frSub(currentSum, frAdd(c0, c0)), c2) + + // PoW check (matching Rust sumcheck round_pow.verify) + if err := nativePoWVerify(nimue, powThreshold); err != nil { + return nil, nil, fmt.Errorf("sumcheck round %d pow: %w", i, err) + } + + // Squeeze folding randomness + rSlice, err := nimue.FillChallengeScalars(1) + if err != nil { + return nil, nil, fmt.Errorf("sumcheck round %d challenge: %w", i, err) + } + r := rSlice[0] + foldingRandomness[i] = r + + // Update sum: p(r) = (c2*r + c1)*r + c0 + currentSum = frAdd(frMul(frAdd(frMul(c2, r), c1), r), c0) + } + + return foldingRandomness, currentSum, nil +} + +// --------------------------------------------------------------------------- +// Native IRS commit verify (returns in-domain evaluation points) +// --------------------------------------------------------------------------- + +// nativeIRSCommitVerifyWithPoints replays the IRS commit verification and +// returns the in-domain query indices plus the parsed Merkle round data. +func nativeIRSCommitVerifyWithPoints( + nimue *NativeNimue, + numQueries int, + domainSize int, + foldingFactorPower int, +) ([]int, *whir.RoundMerkleEntry, error) { + _ = int(nimue.hints.Size()) - nimue.hints.Len() + // Squeeze challenge indices + indices, err := nativeGetStirChallenges(nimue, domainSize/foldingFactorPower, numQueries, false) + if err != nil { + return nil, nil, fmt.Errorf("stir challenges: %w", err) + } + + // Read submatrix hint + var submatrix []Fp256 + if err = nimue.ProverHintArk(&submatrix); err != nil { + return nil, nil, fmt.Errorf("submatrix: %w", err) + } + + // Read Merkle proof hints + foldedDomainSize := domainSize / foldingFactorPower + treeHeight := bits.Len(uint(foldedDomainSize)) - 1 + dedupedIndices := make([]int, len(indices)) + copy(dedupedIndices, indices) + sort.Ints(dedupedIndices) + dedupedIndices = dedup(dedupedIndices) + + // The leaf fold size (num_cols) may differ from foldingFactorPower when + // batch_size > 1 (initial commitment). Derive it from the submatrix length. + leafFoldSize := foldingFactorPower + if len(indices) > 0 { + leafFoldSize = len(submatrix) / len(indices) + } + entry, err := consumeHintsAndBuildMerkleEntry(nimue, indices, dedupedIndices, submatrix, leafFoldSize, treeHeight) + if err != nil { + return nil, nil, fmt.Errorf("merkle: %w", err) + } + return indices, entry, nil +} + +// IndexPair identifies a node in the Merkle tree by its depth and index. +type IndexPair struct { + Depth uint64 + Index uint64 +} + +// extractFullAuthPath extracts a complete authentication path for a given leaf +// index from the reconstructed tree map. +func ExtractFullAuthPath( + tree map[IndexPair]Digest, + leafIdx uint64, + depth int, +) (siblingHash Digest, authPath []Digest) { + siblingHash = tree[IndexPair{Depth: uint64(depth), Index: leafIdx ^ 1}] + authPath = make([]Digest, depth-1) + currentIdx := leafIdx / 2 + for level := depth - 1; level >= 1; level-- { + authPath[depth-1-level] = tree[IndexPair{Depth: uint64(level), Index: currentIdx ^ 1}] + currentIdx /= 2 + } + return siblingHash, authPath +} + +// consumeHintsAndBuildMerkleEntry reads Merkle proof hints from nimue and +// builds the RoundMerkleEntry in a single bottom-up pass. This combines what +// was previously two separate functions (consumeMerkleHints + buildRoundMerkleEntry) +// that each independently replayed the same level-by-level tree traversal. +// +// indices: original query indices (may contain duplicates) +// dedupSorted: sorted, deduplicated indices (pre-computed by caller) +// submatrix: leaf data laid out in order of indices, foldSize elements per index +func consumeHintsAndBuildMerkleEntry( + name *NativeNimue, + indices []int, + dedupSorted []int, + submatrix []Fp256, + foldSize int, + treeHeight int, +) (*whir.RoundMerkleEntry, error) { + if len(indices) == 0 { + return &whir.RoundMerkleEntry{}, nil + } + + // Build index → submatrix row lookup. The submatrix is laid out in the + // order of the ORIGINAL indices (unsorted, with duplicates). Each index + // contributes foldSize elements. + leafFp256ByIdx := make(map[int][]Fp256) + leafVarByIdx := make(map[int][]frontend.Variable) + for i, idx := range indices { + if _, exists := leafFp256ByIdx[idx]; exists { + continue // already mapped (duplicate index) + } + fp256Row := make([]Fp256, foldSize) + varRow := make([]frontend.Variable, foldSize) + for j := range foldSize { + elemIdx := i*foldSize + j + if elemIdx < len(submatrix) { + fp256Row[j] = submatrix[elemIdx] + varRow[j] = typeConverters.LimbsToBigIntMod(submatrix[elemIdx].Limbs) + } + } + leafFp256ByIdx[idx] = fp256Row + leafVarByIdx[idx] = varRow + } + + // Build the Merkle tree bottom-up while consuming hints from nimue. + // For each level, sibling pairs (both in the query set) don't need a hint; + // single indices get a sibling hash from the hint stream. + tree := make(map[IndexPair]Digest) + + // Compute and store all leaf hashes. + for _, idx := range dedupSorted { + leafHash := HashLeafData(leafFp256ByIdx[idx]) + tree[IndexPair{Depth: uint64(treeHeight), Index: uint64(idx)}] = leafHash + } + + // Bottom-up level-by-level: consume hints and compute parent hashes. + currentIndices := make([]int, len(dedupSorted)) + copy(currentIndices, dedupSorted) + + for level := 0; level < treeHeight; level++ { + currentDepth := uint64(treeHeight - level) + parentDepth := currentDepth - 1 + var nextIndices []int + i := 0 + for i < len(currentIndices) { + a := currentIndices[i] + if i+1 < len(currentIndices) && currentIndices[i+1] == a^1 { + // Sibling pair: both hashes already in tree. + aHash := tree[IndexPair{Depth: currentDepth, Index: uint64(a)}] + bHash := tree[IndexPair{Depth: currentDepth, Index: uint64(a ^ 1)}] + var left, right Digest + if a%2 == 0 { + left, right = aHash, bHash + } else { + left, right = bHash, aHash + } + tree[IndexPair{Depth: parentDepth, Index: uint64(a >> 1)}] = HashTwoDigests(left, right) + nextIndices = append(nextIndices, a>>1) + i += 2 + } else { + // Single index: read sibling hash from hints. + siblingHash, err := name.ProverHint(32) + if err != nil { + return nil, fmt.Errorf("merkle level %d, index %d: %w", level, a, err) + } + var digest Digest + copy(digest.Digest[:], siblingHash) + + sibKey := IndexPair{Depth: currentDepth, Index: uint64(a ^ 1)} + if _, exists := tree[sibKey]; !exists { + tree[sibKey] = digest + } + + // Compute parent hash. + nodeHash := tree[IndexPair{Depth: currentDepth, Index: uint64(a)}] + sibNodeHash := tree[sibKey] + var left, right Digest + if a%2 == 0 { + left, right = nodeHash, sibNodeHash + } else { + left, right = sibNodeHash, nodeHash + } + tree[IndexPair{Depth: parentDepth, Index: uint64(a >> 1)}] = HashTwoDigests(left, right) + nextIndices = append(nextIndices, a>>1) + i++ + } + } + sort.Ints(nextIndices) + nextIndices = dedup(nextIndices) + currentIndices = nextIndices + } + + // Build the per-query RoundMerkleEntry with complete auth paths from the tree. + nq := len(indices) + entry := &whir.RoundMerkleEntry{ + Leaves: make([][]frontend.Variable, nq), + SiblingHashes: make([]frontend.Variable, nq), + AuthPaths: make([][]frontend.Variable, nq), + LeafIndexes: make([]frontend.Variable, nq), + } + + for q, idx := range indices { + entry.LeafIndexes[q] = big.NewInt(int64(idx)) + entry.Leaves[q] = leafVarByIdx[idx] + + siblingHash, authPath := ExtractFullAuthPath(tree, uint64(idx), treeHeight) + entry.SiblingHashes[q] = DigestToFieldElement(siblingHash) + entry.AuthPaths[q] = make([]frontend.Variable, len(authPath)) + for lvl, digest := range authPath { + entry.AuthPaths[q][lvl] = DigestToFieldElement(digest) + } + } + + return entry, nil +} + +// --------------------------------------------------------------------------- +// NativeCommitmentFromParsed builds a NativeCommitment from the output of +// nativeParseBatchedCommitment. +// --------------------------------------------------------------------------- + +func NativeCommitmentFromParsed(oodPoints []*big.Int, oodAnswers [][]*big.Int) *NativeCommitment { + // Flatten oodAnswers [][]*big.Int (each is a 1-element slice from FillNextScalars) + // into a row-major matrix: rows = OOD points, columns = batch vectors. + var matrix []*big.Int + for _, ans := range oodAnswers { + matrix = append(matrix, ans...) + } + return &NativeCommitment{ + OutOfDomain: NativeEvaluations{ + Points: oodPoints, + Matrix: matrix, + }, + } +} + +// --------------------------------------------------------------------------- +// nativeReceiveCommitment reads a round commitment from the transcript: +// root hash + OOD points + OOD answers. +// --------------------------------------------------------------------------- + +func nativeReceiveCommitment( + name *NativeNimue, + oodSamples int, +) (*NativeCommitment, error) { + // Root hash (prover_message) — absorbed as raw bytes, NOT as a field element, + // because the Merkle root hash can be >= BN254 modulus. + _, err := name.FillNextBytes(32) + if err != nil { + return nil, fmt.Errorf("root hash: %w", err) + } + // OOD points (verifier challenges) + oodPoints := make([]*big.Int, 0) + oodAnswers := make([]*big.Int, 0) + if oodSamples > 0 { + pts, err := name.FillChallengeScalars(oodSamples) + if err != nil { + return nil, fmt.Errorf("ood points: %w", err) + } + oodPoints = pts + + ans, err := name.FillNextScalars(oodSamples) + if err != nil { + return nil, fmt.Errorf("ood answers: %w", err) + } + oodAnswers = ans + } + + return &NativeCommitment{ + OutOfDomain: NativeEvaluations{ + Points: oodPoints, + Matrix: oodAnswers, // single-vector: matrix is 1 column + }, + }, nil +} + +// --------------------------------------------------------------------------- +// Native PoW verification (transcript replay only — actual check is in circuit) +// --------------------------------------------------------------------------- + +func nativePoWVerify(nimue *NativeNimue, threshold uint64) error { + if threshold < math.MaxUint64 { + challengeBytes, err := nimue.FillChallengeBytes(32) + if err != nil { + return fmt.Errorf("pow challenge: %w", err) + } + nonceBytes, err := nimue.FillNextBytes(8) + if err != nil { + return fmt.Errorf("pow nonce: %w", err) + } + + // Convert challenge bytes (LE) to [4]uint64 limbs + var challengeLimbs [4]uint64 + for i := 0; i < 4; i++ { + for j := 0; j < 8; j++ { + challengeLimbs[i] |= uint64(challengeBytes[i*8+j]) << (8 * j) + } + } + + // Convert nonce bytes (LE) to [4]uint64 — nonce is u64 in low limb, rest zero + var nonceLimbs [4]uint64 + for j := 0; j < 8; j++ { + nonceLimbs[0] |= uint64(nonceBytes[j]) << (8 * j) + } + + // Skyscraper compress + hash := SkyscraperCompress(challengeLimbs, nonceLimbs) + + // Rust PoW check: first 8 bytes of hash as u64 LE <= threshold + value := hash[0] + if value > threshold { + return fmt.Errorf("PoW check failed: hash not below threshold (value=%d, threshold=%d)", value, threshold) + } + } + return nil +} + +// --------------------------------------------------------------------------- +// NativeWhirVerify: full WHIR batched verification +// Mirrors Rust whir::Config::verify() exactly. +// --------------------------------------------------------------------------- + +// NativeWhirVerifyResult bundles the verify output with the ZKHint data +// needed for circuit construction. +type NativeWhirVerifyResult struct { + FinalClaim NativeFinalClaim + Hint ZKHint + MerkleData *whir.WhirMerkleData +} + +// NativeWhirVerify replays and verifies the full WHIR protocol transcript. +// This mirrors the Rust `Config::verify()` method for batched commitments. +// +// Parameters: +// - nimue: transcript reader +// - whirParams: WHIR protocol parameters +// - whirConfig: WHIR configuration (for ZKHint construction) +// - commitments: N parsed commitments (from parseBatchedCommitment) +// - evaluations: constraint evaluation values (flattened) +// - numLinearForms: number of external linear form constraints +func NativeWhirVerify( + nimue *NativeNimue, + whirParams WHIRParams, + whirConfig WHIRConfig, + commitments []*NativeCommitment, + evaluations []*big.Int, +) (*NativeWhirVerifyResult, error) { + var allMerklePaths []FullMultiPath[Digest] + var allStirAnswers [][][]Fp256 + + numVectors := 0 + for _, c := range commitments { + numVectors += c.NumVectors() + } + if len(evaluations) > 0 && numVectors > 0 && len(evaluations)%numVectors != 0 { + return nil, fmt.Errorf("evaluations length %d not multiple of num_vectors %d", len(evaluations), numVectors) + } + if numVectors == 0 { + return &NativeWhirVerifyResult{ + FinalClaim: NativeFinalClaim{ + LinearFormRLC: big.NewInt(0), + }, + }, nil + } + numLinearForms := len(evaluations) / numVectors + // Complete OOD evaluation matrix with cross-terms + var oodsEvalInfos []evaluatorInfo // evaluator info per OOD constraint + var oodsMatrix []*big.Int // flattened: [ood0_vec0, ood0_vec1, ..., ood1_vec0, ...] + + vectorOffset := 0 + for _, commitment := range commitments { + ood := &commitment.OutOfDomain + for rowIdx, row := range ood.Rows() { + for j := 0; j < numVectors; j++ { + if j >= vectorOffset && j < len(row)+vectorOffset { + oodsMatrix = append(oodsMatrix, row[j-vectorOffset]) + } else { + // Cross-term: read from transcript + vals, err := nimue.FillNextScalars(1) + if err != nil { + return nil, fmt.Errorf("ood cross-term: %w", err) + } + oodsMatrix = append(oodsMatrix, vals[0]) + } + } + _ = rowIdx + // Each OOD row creates one evaluator + } + // Add evaluator infos for this commitment's OOD points + initialSize := whirParams.DomainSize / (1 << whirConfig.Rate) + for _, pt := range ood.Points { + oodsEvalInfos = append(oodsEvalInfos, evaluatorInfo{point: pt, size: initialSize}) + } + vectorOffset += commitment.NumVectors() + } + + vectorRLCCoeffs, err := nativeGeometricChallenge(nimue, numVectors) + if err != nil { + return nil, fmt.Errorf("vector_rlc: %w", err) + } + + totalConstraints := len(oodsEvalInfos) + numLinearForms + constraintRLCCoeffs, err := nativeGeometricChallenge(nimue, totalConstraints) + if err != nil { + return nil, fmt.Errorf("constraint_rlc: %w", err) + } + + initialFormRLCCoeffs := constraintRLCCoeffs[:numLinearForms] + oodsRLCCoeffs := constraintRLCCoeffs[numLinearForms:] + + theSum := big.NewInt(0) + + // Contribution from external linear forms + for i, coeff := range initialFormRLCCoeffs { + row := evaluations[i*numVectors : (i+1)*numVectors] + theSum = frAdd(theSum, frMul(coeff, nativeDotBigInt(vectorRLCCoeffs, row))) + } + + // Contribution from OOD constraints + for i, coeff := range oodsRLCCoeffs { + row := oodsMatrix[i*numVectors : (i+1)*numVectors] + theSum = frAdd(theSum, frMul(coeff, nativeDotBigInt(vectorRLCCoeffs, row))) + } + + // Track round constraints for final MLE subtraction + roundConstraints := []NativeRoundConstraint{ + {RLCCoeffs: oodsRLCCoeffs, EvaluatorInfos: oodsEvalInfos}, + } + + var allFoldingRandomness [][]*big.Int + + // Initial sumcheck + if len(constraintRLCCoeffs) == 0 { + // No constraints: skip sumcheck, just squeeze folding randomness + if theSum.Cmp(big.NewInt(0)) != 0 { + return nil, fmt.Errorf("the_sum should be zero but got %s", theSum.String()) + } + ff0 := whirParams.FoldingFactorArray[0] + foldRandomness, err := nimue.FillChallengeScalars(ff0) + if err != nil { + return nil, fmt.Errorf("initial skip folding: %w", err) + } + // initial_skip_pow + if err := nativePoWVerify(nimue, whirConfig.InitialSkipPowThreshold); err != nil { + return nil, fmt.Errorf("initial skip pow: %w", err) + } + allFoldingRandomness = append(allFoldingRandomness, foldRandomness) + } else { + ff0 := whirParams.FoldingFactorArray[0] + foldRandomness, newSum, err := nativeWhirSumcheckVerify(nimue, theSum, ff0, whirConfig.InitialSumcheckPowThreshold) + if err != nil { + return nil, fmt.Errorf("initial sumcheck: %w", err) + } + theSum = newSum + allFoldingRandomness = append(allFoldingRandomness, foldRandomness) + } + + // Main WHIR rounds + domainSize := whirParams.DomainSize + nRounds := whirParams.ParamNRounds + var merkleRounds []whir.RoundMerkleEntry + + // Track which commitment type was previous (for opening) + type prevType int + const ( + prevInitial prevType = iota + prevRound + ) + prev := prevInitial + // polyRLC for initial is vectorRLCCoeffs; for round is [1] + currentPolyRLC := vectorRLCCoeffs + + // expDomainGen is the generator of the folded domain (codeword rows). + // It starts as startingDomainGen^(1 << foldingFactor[0]) = generator of + // codeword_length = domainSize / interleavingDepth. + startingDomainGen, _ := new(big.Int).SetString(whirConfig.DomainGenerator, 10) + expDomainGen := new(big.Int).Exp(startingDomainGen, big.NewInt(int64(1< 1; v >>= 1 { + numBitsR++ + } + for qi, idx := range inDomainIndices { + reversedIdx := whir.BitReverseInt(idx, numBitsR) + domainPoint := new(big.Int).Exp(expDomainGen, big.NewInt(int64(reversedIdx)), bn254Modulus) + constraintEvalInfos = append(constraintEvalInfos, evaluatorInfo{ + point: domainPoint, + size: roundSize, + }) + + // Compute in-domain value from leaf data: dot(tensorWeights, leaf) + if qi < len(roundMerkle.Leaves) { + leafBigInts := make([]*big.Int, len(roundMerkle.Leaves[qi])) + for j, v := range roundMerkle.Leaves[qi] { + leafBigInts[j] = v.(*big.Int) + } + constraintValues = append(constraintValues, nativeDotBigInt(tensorWeights, leafBigInts)) + } else { + constraintValues = append(constraintValues, big.NewInt(0)) + } + } + + constraintRLC, err := nativeGeometricChallenge(nimue, len(constraintValues)) + if err != nil { + return nil, fmt.Errorf("round %d combination randomness: %w", r, err) + } + inDomainContrib := nativeDotBigInt(constraintRLC, constraintValues) + theSum = frAdd(theSum, inDomainContrib) + + roundConstraints = append(roundConstraints, NativeRoundConstraint{ + RLCCoeffs: constraintRLC, + EvaluatorInfos: constraintEvalInfos, + }) + + // Sumcheck for this round + ff := whirParams.FoldingFactorArray[r] + if r+1 < len(whirParams.FoldingFactorArray) { + ff = whirParams.FoldingFactorArray[r+1] + } + foldRandomness, newSum, err := nativeWhirSumcheckVerify(nimue, theSum, ff, whirConfig.SumcheckPowThresholds[r]) + if err != nil { + return nil, fmt.Errorf("round %d sumcheck: %w", r, err) + } + theSum = newSum + allFoldingRandomness = append(allFoldingRandomness, foldRandomness) + + prev = prevRound + currentPolyRLC = []*big.Int{big.NewInt(1)} + domainSize /= 2 + + // Update the folded domain generator for the next round. + // Mirrors the circuit code: numSquarings = 1 + ff[r+1] - ff[r] + nextFF := whirParams.FoldingFactorArray[r] + if r+1 < len(whirParams.FoldingFactorArray) { + nextFF = whirParams.FoldingFactorArray[r+1] + } + numSquarings := 1 + nextFF - whirParams.FoldingFactorArray[r] + for k := 0; k < numSquarings; k++ { + expDomainGen = frMul(expDomainGen, expDomainGen) + } + } + + // Final round: receive full vector + finalSize := 1 << whirParams.FinalSumcheckRounds + finalVector, err := nimue.FillNextScalars(finalSize) + + if err != nil { + return nil, fmt.Errorf("final vector: %w", err) + } + + // Final PoW + if err := nativePoWVerify(nimue, whirConfig.FinalPowThreshold); err != nil { + return nil, fmt.Errorf("final pow: %w", err) + } + + // Open previous commitment (final IRS verify) + finalFoldingFactorPower := 1 << whirParams.FoldingFactorArray[nRounds] + finalIndices, err := nativeGetStirChallenges( + nimue, + domainSize/finalFoldingFactorPower, + whirParams.FinalQueries, + false, + ) + if err != nil { + return nil, fmt.Errorf("final stir challenges: %w", err) + } + + var finalSubmatrix []Fp256 + if err = nimue.ProverHintArk(&finalSubmatrix); err != nil { + return nil, fmt.Errorf("final submatrix: %w", err) + } + + allStirAnswers = append(allStirAnswers, [][]Fp256{}) + + foldedDomainSize := domainSize / finalFoldingFactorPower + treeHeight := bits.Len(uint(foldedDomainSize)) - 1 + dedupedFinal := make([]int, len(finalIndices)) + copy(dedupedFinal, finalIndices) + sort.Ints(dedupedFinal) + dedupedFinal = dedup(dedupedFinal) + + finalMerkleEntry, err := consumeHintsAndBuildMerkleEntry(nimue, finalIndices, dedupedFinal, finalSubmatrix, finalFoldingFactorPower, treeHeight) + if err != nil { + return nil, fmt.Errorf("final merkle: %w", err) + } + allMerklePaths = append(allMerklePaths, FullMultiPath[Digest]{}) + merkleRounds = append(merkleRounds, *finalMerkleEntry) + + // Final sumcheck + finalSumcheckRandomness, newSum, err := nativeWhirSumcheckVerify(nimue, theSum, whirParams.FinalSumcheckRounds, whirConfig.FinalFoldingPowThreshold) + if err != nil { + return nil, fmt.Errorf("final sumcheck: %w", err) + } + theSum = newSum + allFoldingRandomness = append(allFoldingRandomness, finalSumcheckRandomness) + + // Final folding PoW + if err := nativePoWVerify(nimue, whirConfig.FinalFoldingPowThreshold); err != nil { + return nil, fmt.Errorf("final folding pow: %w", err) + } + + // Compute evaluation point (all folding randomness concatenated) + var evaluationPoint []*big.Int + for _, fr := range allFoldingRandomness { + evaluationPoint = append(evaluationPoint, fr...) + } + + // Compute linear_form_rlc from the sumcheck invariant + // poly_eval = MLE(final_sumcheck_randomness).evaluate(Identity, final_vector) + polyEval := nativeMultilinearEval(finalSumcheckRandomness, finalVector) + + // linear_form_rlc = the_sum / poly_eval + linearFormRLC := frDiv(theSum, polyEval) + + // Subtract all internal linear forms. + // roundConstraints[0] = initial OODs (uses initial_num_variables = MVParams) + // roundConstraints[r+1] = main round r (uses round_configs[r].initial_num_variables) + // Mirrors Rust: round.checked_sub(1).map_or(initial_num_vars, |p| round_configs[p].initial_num_vars) + for round, rc := range roundConstraints { + numVariables := whirParams.MVParamsNumberOfVariables + for k := 0; k < round; k++ { + numVariables -= whirParams.FoldingFactorArray[k] + } + start := len(evaluationPoint) - numVariables + if start < 0 { + start = 0 + } + subPoint := evaluationPoint[start:] + for i, coeff := range rc.RLCCoeffs { + if i < len(rc.EvaluatorInfos) { + info := rc.EvaluatorInfos[i] + mleVal := nativeUnivariateEvalMLE(info.point, info.size, subPoint) + linearFormRLC = frSub(linearFormRLC, frMul(coeff, mleVal)) + } + } + } + + // Build ZKHint from parsed Merkle data + zkHint := consumeWhirData(whirConfig, &allMerklePaths, &allStirAnswers) + + return &NativeWhirVerifyResult{ + FinalClaim: NativeFinalClaim{ + EvaluationPoint: evaluationPoint, + RLCCoefficients: initialFormRLCCoeffs, + LinearFormRLC: linearFormRLC, + }, + Hint: zkHint, + MerkleData: &whir.WhirMerkleData{ + Rounds: merkleRounds, + }, + }, nil +} diff --git a/recursive-verifier/app/circuit/skyscraper2.go b/recursive-verifier/app/circuit/skyscraper2.go new file mode 100644 index 000000000..29494cf9c --- /dev/null +++ b/recursive-verifier/app/circuit/skyscraper2.go @@ -0,0 +1,214 @@ +package circuit + +import ( + "math/big" +) + +// SIGMA_INV constant for Skyscraper +var sigmaInv *big.Int + +// Round constants for Skyscraper +var skyscraperRC [18]*big.Int + +func init() { + sigmaInv, _ = new(big.Int).SetString("9915499612839321149637521777990102151350674507940716049588462388200839649614", 10) + + rcHex := [][4]uint64{ + {0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000}, + {0x903c4324270bd744, 0x873125f708a7d269, 0x081dd27906c83855, 0x276b1823ea6d7667}, + {0x7ac8edbb4b378d71, 0xe29d79f3d99e2cb7, 0x751417914c1a5a18, 0x0cf02bd758a484a6}, + {0xfa7adc6769e5bc36, 0x1c3f8e297cca387d, 0x0eb7730d63481db0, 0x25b0e03f18ede544}, + {0x57847e652f03cfb7, 0x33440b9668873404, 0x955a32e849af80bc, 0x002882fcbe14ae70}, + {0x979231396257d4d7, 0x29989c3e1b37d3c1, 0x12ef02b47f1277ba, 0x039ad8571e2b7a9c}, + {0xb5b48465abbb7887, 0xa72a6bc5e6ba2d2b, 0x4cd48043712f7b29, 0x1142d5410fc1fc1a}, + {0x7ab2c156059075d3, 0x17cb3594047999b2, 0x44f2c93598f289f7, 0x1d78439f69bc0bec}, + {0x05d7a965138b8edb, 0x36ef35a3d55c48b1, 0x8ddfb8a1ac6f1628, 0x258588a508f4ff82}, + {0x1596fb9afccb49e9, 0x9a7367d69a09a95b, 0x9bc43f6984e4c157, 0x13087879d2f514fe}, + {0x295ccd233b4109fa, 0xe1d72f89ed868012, 0x2e9e1eea4bc88a8e, 0x17dadee898c45232}, + {0x9a8590b4aa1f486f, 0xb75834b430e9130e, 0xb8e90b1034d5de31, 0x295c6d1546e7f4a6}, + {0x850adcb74c6eb892, 0x07699ef305b92fc3, 0x4ef96a2ba1720f2d, 0x1288ca0e1d3ed446}, + {0x01960f9349d1b5ee, 0x8ccad30769371c69, 0xe5c81e8991c98662, 0x17563b4d1ae023f3}, + {0x6ba01e9476b32917, 0xa1cb0a3add977bc9, 0x86815a945815f030, 0x2869043be91a1eea}, + {0x81776c885511d976, 0x7475d34f47f414e7, 0x5d090056095d96cf, 0x14941f0aff59e79a}, + {0xbc40b4fd8fc8c034, 0xbb7142c3cce4fd48, 0x318356758a39005a, 0x1ce337a190f4379f}, + {0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000}, + } + + for i := range rcHex { + skyscraperRC[i] = skLimbsToBigInt(rcHex[i]) + } +} + +func skLimbsToBigInt(limbs [4]uint64) *big.Int { + result := new(big.Int) + result.SetUint64(limbs[0]) + temp := new(big.Int) + temp.SetUint64(limbs[1]) + temp.Lsh(temp, 64) + result.Add(result, temp) + temp.SetUint64(limbs[2]) + temp.Lsh(temp, 128) + result.Add(result, temp) + temp.SetUint64(limbs[3]) + temp.Lsh(temp, 192) + result.Add(result, temp) + return result.Mod(result, bn254Modulus) +} + +func skBigIntToLimbs(val *big.Int) [4]uint64 { + v := new(big.Int).Set(val) + v.Mod(v, bn254Modulus) + limbs := [4]uint64{} + mask := new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF) + for i := 0; i < 4; i++ { + t := new(big.Int).And(v, mask) + limbs[i] = t.Uint64() + v.Rsh(v, 64) + } + return limbs +} + +func skFieldAdd(a, b *big.Int) *big.Int { + return new(big.Int).Mod(new(big.Int).Add(a, b), bn254Modulus) +} + +func skFieldMul(a, b *big.Int) *big.Int { + return new(big.Int).Mod(new(big.Int).Mul(a, b), bn254Modulus) +} + +func skFieldSquare(a *big.Int) *big.Int { + return skFieldMul(a, a) +} + +func sbox(v byte) byte { + notV := ^v + rotLeft1 := (notV << 1) | (notV >> 7) + rotLeft2 := (v << 2) | (v >> 6) + rotLeft3 := (v << 3) | (v >> 5) + xor := v ^ (rotLeft1 & rotLeft2 & rotLeft3) + return (xor << 1) | (xor >> 7) +} + +func skBar(x *big.Int) *big.Int { + bytes := make([]byte, 32) + xBytes := x.Bytes() + for i := 0; i < len(xBytes) && i < 32; i++ { + bytes[i] = xBytes[len(xBytes)-1-i] + } + rotated := make([]byte, 32) + copy(rotated[:16], bytes[16:]) + copy(rotated[16:], bytes[:16]) + for i := range rotated { + rotated[i] = sbox(rotated[i]) + } + result := new(big.Int) + for i := 31; i >= 0; i-- { + result.Lsh(result, 8) + result.Or(result, new(big.Int).SetUint64(uint64(rotated[i]))) + } + return result.Mod(result, bn254Modulus) +} + +func skSS(round int, l, r *big.Int) (*big.Int, *big.Int) { + lSq := skFieldSquare(l) + term := skFieldAdd(skFieldMul(lSq, sigmaInv), skyscraperRC[round]) + r = skFieldAdd(r, term) + l, r = r, l + + lSq = skFieldSquare(l) + term = skFieldAdd(skFieldMul(lSq, sigmaInv), skyscraperRC[round+1]) + r = skFieldAdd(r, term) + l, r = r, l + return l, r +} + +func skBB(round int, l, r *big.Int) (*big.Int, *big.Int) { + r = skFieldAdd(r, skFieldAdd(skBar(l), skyscraperRC[round])) + l, r = r, l + r = skFieldAdd(r, skFieldAdd(skBar(l), skyscraperRC[round+1])) + l, r = r, l + return l, r +} + +func skPermute(l, r *big.Int) (*big.Int, *big.Int) { + l, r = skSS(0, l, r) + l, r = skSS(2, l, r) + l, r = skSS(4, l, r) + l, r = skBB(6, l, r) + l, r = skSS(8, l, r) + l, r = skBB(10, l, r) + l, r = skSS(12, l, r) + l, r = skSS(14, l, r) + l, r = skSS(16, l, r) + return l, r +} + +// SkyscraperCompress compresses two field elements using the Skyscraper permutation. +func SkyscraperCompress(left, right [4]uint64) [4]uint64 { + l := skLimbsToBigInt(left) + r := skLimbsToBigInt(right) + t := new(big.Int).Set(l) + l, _ = skPermute(l, r) + return skBigIntToLimbs(skFieldAdd(l, t)) +} + +// HashLeafData hashes multiple Fp256 elements into a single KeccakDigest +// using iterative Skyscraper compression. +func HashLeafData(leafData []Fp256) Digest { + if len(leafData) == 0 { + return Digest{} + } + currentLimbs := leafData[0].Limbs + for i := 1; i < len(leafData); i++ { + currentLimbs = SkyscraperCompress(currentLimbs, leafData[i].Limbs) + } + var digest Digest + for i := 0; i < 4; i++ { + limb := currentLimbs[i] + for j := 0; j < 8; j++ { + digest.Digest[i*8+j] = byte(limb & 0xFF) + limb >>= 8 + } + } + return digest +} + +// HashTwoDigests hashes two KeccakDigest values together using Skyscraper compression. +func HashTwoDigests(left, right Digest) Digest { + leftLimbs := [4]uint64{} + rightLimbs := [4]uint64{} + for i := 0; i < 4; i++ { + var limb uint64 + for j := 0; j < 8; j++ { + limb |= uint64(left.Digest[i*8+j]) << (8 * j) + } + leftLimbs[i] = limb + } + for i := 0; i < 4; i++ { + var limb uint64 + for j := 0; j < 8; j++ { + limb |= uint64(right.Digest[i*8+j]) << (8 * j) + } + rightLimbs[i] = limb + } + resultLimbs := SkyscraperCompress(leftLimbs, rightLimbs) + var digest Digest + for i := 0; i < 4; i++ { + limb := resultLimbs[i] + for j := 0; j < 8; j++ { + digest.Digest[i*8+j] = byte(limb & 0xFF) + limb >>= 8 + } + } + return digest +} + +// DigestToFieldElement converts a KeccakDigest to a *big.Int field element. +func DigestToFieldElement(d Digest) *big.Int { + result := new(big.Int) + for i := 31; i >= 0; i-- { + result.Lsh(result, 8) + result.Add(result, big.NewInt(int64(d.Digest[i]))) + } + return result.Mod(result, bn254Modulus) +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index bc748b432..17e39fefa 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -2,14 +2,15 @@ package circuit import ( "reilabs/whir-verifier-circuit/app/utilities" + "reilabs/whir-verifier-circuit/app/whir" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) // Common types -type KeccakDigest struct { - KeccakDigest [32]uint8 +type Digest struct { + Digest [32]uint8 } type Fp256 struct { @@ -28,36 +29,29 @@ type FullMultiPath[Digest any] struct { // WHIR specific types type WHIRConfig struct { - NRounds int `json:"n_rounds"` - Rate int `json:"rate"` - NVars int `json:"n_vars"` - FoldingFactor []int `json:"folding_factor"` - OODSamples []int `json:"ood_samples"` - NumQueries []int `json:"num_queries"` - PowBits []int `json:"pow_bits"` - FinalQueries int `json:"final_queries"` - FinalPowBits int `json:"final_pow_bits"` - FinalFoldingPowBits int `json:"final_folding_pow_bits"` - DomainGenerator string `json:"domain_generator"` - BatchSize int `json:"batch_size"` -} - -type WHIRParams struct { - ParamNRounds int - FoldingFactorArray []int - RoundParametersOODSamples []int - RoundParametersNumOfQueries []int - PowBits []int - FinalQueries int - FinalPowBits int - FinalFoldingPowBits int - StartingDomainBackingDomainGenerator frontend.Variable - DomainSize int - CommitmentOODSamples int - FinalSumcheckRounds int - MVParamsNumberOfVariables int - BatchSize int -} + NRounds int `json:"n_rounds"` + Rate int `json:"rate"` + NVars int `json:"n_vars"` + FoldingFactor []int `json:"folding_factor"` + OODSamples []int `json:"ood_samples"` + NumQueries []int `json:"num_queries"` + PowBits []int `json:"pow_bits"` + PowThresholds []uint64 `json:"pow_thresholds"` + SumcheckPowThresholds []uint64 `json:"sumcheck_pow_thresholds"` + InitialSumcheckPowThreshold uint64 `json:"initial_sumcheck_pow_threshold"` + InitialSkipPowThreshold uint64 `json:"initial_skip_pow_threshold"` + FinalQueries int `json:"final_queries"` + FinalPowBits int `json:"final_pow_bits"` + FinalPowThreshold uint64 `json:"final_pow_threshold"` + FinalFoldingPowBits int `json:"final_folding_pow_bits"` + FinalFoldingPowThreshold uint64 `json:"final_folding_pow_threshold"` + DomainGenerator string `json:"domain_generator"` + BatchSize int `json:"batch_size"` + InitialInDomainSamples int `json:"initial_in_domain_samples"` // initial_committer.in_domain_samples (num queries for zkWHIR in-domain verification) +} + +// WHIRParams is an alias for whir.WHIRParams to avoid duplicate struct definitions. +type WHIRParams = whir.WHIRParams type MainRoundData struct { OODPoints [][]frontend.Variable @@ -71,13 +65,6 @@ type InitialSumcheckData struct { } // Merkle specific types -type MerklePaths struct { - Leaves [][][]frontend.Variable - LeafIndexes [][]uints.U64 - LeafSiblingHashes [][][]uints.U8 - AuthPaths [][][][]uints.U8 -} - type Merkle struct { Leaves [][][]frontend.Variable LeafIndexes [][]uints.U64 @@ -85,41 +72,29 @@ type Merkle struct { AuthPaths [][][]frontend.Variable } -// Other types -type ProofObject struct { - StatementValuesAtRandomPoint []Fp256 `json:"statement_values_at_random_point"` -} - +// Config matches the Rust GnarkConfig struct. +// narg_string + hints are the spongefish proof buffers. +// protocol_id is SHA3-512(CBOR(WhirR1CSScheme)); session_id is optional (default zero) for domain separation. type Config struct { - WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` - WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` + BlindedCommitmentWhirConfig WHIRConfig `json:"blinded_commitment_whir_config"` + BlindingCommitmentWhirConfig WHIRConfig `json:"blinding_commitment_whir_config"` LogNumConstraints int `json:"log_num_constraints"` LogNumVariables int `json:"log_num_variables"` LogANumTerms int `json:"log_a_num_terms"` - IOPattern string `json:"io_pattern"` - Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` - WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` - BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` + NargString []byte `json:"narg_string"` + NargStringLen int `json:"narg_string_len"` + Hints []byte `json:"hints"` + HintsLen int `json:"hints_len"` + ProtocolID []byte `json:"protocol_id"` + SessionID []byte `json:"session_id"` NumChallenges int `json:"num_challenges"` + ChallengeOffsets []int `json:"challenge_offsets"` W1Size int `json:"w1_size"` PublicInputs PublicInputs `json:"public_inputs"` } -// Update Hints to support batch mode -type Hints struct { - spartanHidingHint ZKHint - - // Witness hints (length 1 for single mode, N for batch mode) - WitnessFirstRoundHints []FirstRoundHint - - // Single mode: rounds 1+ for the one commitment - // Batch mode: rounds 1+ for batched polynomial - WitnessRoundHints ZKHint -} - type Hint struct { - merklePaths []FullMultiPath[KeccakDigest] + merklePaths []FullMultiPath[Digest] stirAnswers [][][]Fp256 } @@ -133,16 +108,6 @@ type ZKHint struct { roundHints Hint } -type ClaimedEvaluations struct { - FSums []Fp256 - GSums []Fp256 -} - -type DualClaimedEvaluations struct { - First ClaimedEvaluations - Second ClaimedEvaluations -} - type PublicInputs struct { Values []frontend.Variable } diff --git a/recursive-verifier/app/circuit/utilities.go b/recursive-verifier/app/circuit/utilities.go index c4bcbf8f8..9a3806e5c 100644 --- a/recursive-verifier/app/circuit/utilities.go +++ b/recursive-verifier/app/circuit/utilities.go @@ -1,17 +1,7 @@ package circuit import ( - "bytes" - "fmt" - "io" - "log" - "net/http" - "os" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/groth16" - - "reilabs/whir-verifier-circuit/app/utilities" + "reilabs/whir-verifier-circuit/app/whir" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" @@ -19,17 +9,14 @@ import ( skyscraper "github.com/reilabs/gnark-skyscraper" ) -func calculateEQ(api frontend.API, alphas []frontend.Variable, r []frontend.Variable) frontend.Variable { - ans := frontend.Variable(1) - for i, alpha := range alphas { - ans = api.Mul(ans, api.Add(api.Mul(alpha, r[i]), api.Mul(api.Sub(frontend.Variable(1), alpha), api.Sub(frontend.Variable(1), r[i])))) - } - return ans -} - -func initializeComponents(api frontend.API, circuit *Circuit) (*skyscraper.Skyscraper, gnarkNimue.Arthur, *uints.BinaryField[uints.U64], error) { +func initializeComponents(api frontend.API, circuit *Circuit) (*skyscraper.Skyscraper, gnarkNimue.Nimue, *uints.BinaryField[uints.U64], error) { sc := skyscraper.NewSkyscraper(api, 2) - arthur, err := gnarkNimue.NewSkyscraperArthur(api, sc, circuit.IO, circuit.Transcript[:], true) + + // Compute InstanceID in-circuit from public inputs, matching Rust's PublicInputs::hash_bytes(). + initData := circuit.InitializationData + initData.InstanceID = publicInputsHash(sc, circuit.PublicInputs) + + nimue, err := gnarkNimue.NewSkyscraperNimue(api, sc, initData, circuit.Transcript[:]) if err != nil { return nil, nil, nil, err } @@ -37,105 +24,12 @@ func initializeComponents(api frontend.API, circuit *Circuit) (*skyscraper.Skysc if err != nil { return nil, nil, nil, err } - return sc, arthur, uapi, nil -} - -func keysFromFiles(pkPath string, vkPath string) (groth16.ProvingKey, groth16.VerifyingKey, error) { - pkFile, err := os.Open(pkPath) - if err != nil { - return nil, nil, fmt.Errorf("failed to open proving key file: %w", err) - } - defer func(pkFile *os.File) { - err := pkFile.Close() - if err != nil { - log.Printf("failed to close proving key file: %v", err) - } - }(pkFile) - - pk := groth16.NewProvingKey(ecc.BN254) - _, err = pk.ReadFrom(pkFile) - if err != nil { - return nil, nil, fmt.Errorf("failed to restore proving key: %w", err) - } - - vkFile, err := os.Open(vkPath) - if err != nil { - return nil, nil, fmt.Errorf("failed to open verifying key file: %w", err) - } - defer func(vkFile *os.File) { - err := vkFile.Close() - if err != nil { - log.Printf("failed to close verifying key file: %v", err) - } - }(vkFile) - - vk := groth16.NewVerifyingKey(ecc.BN254) - _, err = vk.ReadFrom(vkFile) - if err != nil { - return nil, nil, fmt.Errorf("failed to restore verifying key: %w", err) - } - - return pk, vk, nil -} - -func keysFromUrl(pkUrl string, vkUrl string) (groth16.ProvingKey, groth16.VerifyingKey, error) { - vkBytes, err := downloadFromUrl(vkUrl) - if err != nil { - return nil, nil, fmt.Errorf("failed to download verifying key: %w", err) - } - log.Printf("Downloaded VK") - - vk := groth16.NewVerifyingKey(ecc.BN254) - _, err = vk.UnsafeReadFrom(bytes.NewReader(vkBytes)) - if err != nil { - return nil, nil, fmt.Errorf("failed to deserialize verifying key: %w", err) - } - log.Printf("Loaded VK") - - pkBytes, err := downloadFromUrl(pkUrl) - if err != nil { - return nil, nil, fmt.Errorf("failed to download proving key: %v", err) - } - log.Printf("Downloaded PK") - - pk := groth16.NewProvingKey(ecc.BN254) - _, err = pk.UnsafeReadFrom(bytes.NewReader(pkBytes)) - if err != nil { - return nil, nil, fmt.Errorf("failed to deserialize proving key: %w", err) - } - log.Printf("Loaded PK") - - return pk, vk, nil -} - -func downloadFromUrl(url string) ([]byte, error) { - resp, err := http.Get(url) - if err != nil { - return nil, fmt.Errorf("failed to download from %s: %w", url, err) - } - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - log.Printf("Warning: failed to close response body: %v", closeErr) - } - }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP error %d when downloading from %s", resp.StatusCode, url) - } - - buffer := &bytes.Buffer{} - - _, err = io.Copy(buffer, resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to copy to buffer: %w", err) - } - - return buffer.Bytes(), nil + return sc, nimue, uapi, nil } func runSumcheck( api frontend.API, - arthur gnarkNimue.Arthur, + nimue gnarkNimue.Nimue, lastEval frontend.Variable, foldingFactor int, polynomialDegree int, @@ -145,70 +39,103 @@ func runSumcheck( foldingRandomnessTemp := make([]frontend.Variable, 1) for i := range foldingFactor { - if err := arthur.FillNextScalars(sumcheckPolynomial); err != nil { + if err := nimue.FillNextScalars(sumcheckPolynomial); err != nil { return nil, nil, err } - if err := arthur.FillChallengeScalars(foldingRandomnessTemp); err != nil { + if err := nimue.FillChallengeScalars(foldingRandomnessTemp); err != nil { return nil, nil, err } foldingRandomness[i] = foldingRandomnessTemp[0] sumcheckVal := api.Add( - utilities.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{0})[0], - utilities.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{1})[0], + whir.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{0})[0], + whir.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{1})[0], ) api.AssertIsEqual(sumcheckVal, lastEval) - lastEval = utilities.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{foldingRandomness[i]})[0] + lastEval = whir.UnivarPoly(api, sumcheckPolynomial, []frontend.Variable{foldingRandomness[i]})[0] } return foldingRandomness, lastEval, nil } +// runZKSumcheck replays the ZK Spartan sumcheck transcript. +// Returns (tRand, alpha, fAtAlpha, blindingEval, error) where: +// - tRand: verifier randomness r (length m0) +// - alpha: folding challenges from the sumcheck rounds (length m0) +// - fAtAlpha: the unblinded final evaluation f(alpha) +// - blindingEval: evaluation of the blinding polynomial at alpha func runZKSumcheck( api frontend.API, sc *skyscraper.Skyscraper, uapi *uints.BinaryField[uints.U64], circuit *Circuit, - arthur gnarkNimue.Arthur, + nimue gnarkNimue.Nimue, lastEval frontend.Variable, foldingFactor int, polynomialDegree int, - whirParams WHIRParams, -) ([]frontend.Variable, frontend.Variable, error) { - rootHash, batchingRandomness, initialOODQueries, initialOODAnswers, err := parseBatchedCommitment(arthur, whirParams) +) ([]frontend.Variable, []frontend.Variable, frontend.Variable, frontend.Variable, error) { + tRand := make([]frontend.Variable, circuit.LogNumConstraints) + err := nimue.FillChallengeScalars(tRand) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } - sumOfG, rhoRandomness, err := getZKSumcheckInitialValue(arthur) + sumOfG, rhoRandomness, err := getZKSumcheckInitialValue(nimue) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } lastEval = api.Add(lastEval, api.Mul(sumOfG, rhoRandomness)) - foldingRandomness, lastEval, err := runSumcheck(api, arthur, lastEval, foldingFactor, polynomialDegree) + foldingRandomness, lastEval, err := runSumcheck(api, nimue, lastEval, foldingFactor, polynomialDegree) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } - lastEval, polynomialSums := unblindLastEval(api, arthur, lastEval, rhoRandomness) + lastEval, polynomialSums := unblindLastEval(api, nimue, lastEval, rhoRandomness) - _, err = RunZKWhir(api, arthur, uapi, sc, circuit.HidingSpartanMerkle, circuit.HidingSpartanFirstRound, whirParams, [][]frontend.Variable{{polynomialSums[0]}, {polynomialSums[1]}}, circuit.HidingSpartanLinearStatementEvaluations, batchingRandomness, initialOODQueries, initialOODAnswers, rootHash) - if err != nil { - return nil, nil, err + return tRand, foldingRandomness, lastEval, polynomialSums[0], nil +} + +func publicInputsHash(sc *skyscraper.Skyscraper, publicInputs PublicInputs) frontend.Variable { + var expectedHash frontend.Variable + switch len(publicInputs.Values) { + case 0: + expectedHash = frontend.Variable(0) + case 1: + expectedHash = sc.CompressV2(publicInputs.Values[0], frontend.Variable(0)) + default: + expectedHash = publicInputs.Values[0] + for i := 1; i < len(publicInputs.Values); i++ { + expectedHash = sc.CompressV2(expectedHash, publicInputs.Values[i]) + } } + return expectedHash +} - return foldingRandomness, lastEval, nil +func publicInputsHashCheck( + api frontend.API, + sc *skyscraper.Skyscraper, + nimue gnarkNimue.Nimue, + publicInputs PublicInputs, +) error { + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := nimue.FillNextScalars(publicInputsHashBuf); err != nil { + return err + } + + expectedHash := publicInputsHash(sc, publicInputs) + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + return nil } func getZKSumcheckInitialValue( - arthur gnarkNimue.Arthur, + nimue gnarkNimue.Nimue, ) (frontend.Variable, frontend.Variable, error) { sumOfG := make([]frontend.Variable, 1) rhoRandomness := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(sumOfG); err != nil { + if err := nimue.FillNextScalars(sumOfG); err != nil { return nil, nil, err } - if err := arthur.FillChallengeScalars(rhoRandomness); err != nil { + if err := nimue.FillChallengeScalars(rhoRandomness); err != nil { return nil, nil, err } return sumOfG[0], rhoRandomness[0], nil @@ -216,12 +143,12 @@ func getZKSumcheckInitialValue( func unblindLastEval( api frontend.API, - arthur gnarkNimue.Arthur, + nimue gnarkNimue.Nimue, lastEval frontend.Variable, rhoRandomness frontend.Variable, ) (frontend.Variable, []frontend.Variable) { - polynomialSums := make([]frontend.Variable, 2) - if err := arthur.FillNextScalars(polynomialSums); err != nil { + polynomialSums := make([]frontend.Variable, 1) + if err := nimue.FillNextScalars(polynomialSums); err != nil { return 0, nil } @@ -239,7 +166,7 @@ func consumeFront[T any](slice *[]T) T { return head } -func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[KeccakDigest], stir_answers *[][][]Fp256) ZKHint { +func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Digest], stir_answers *[][][]Fp256) ZKHint { var zkHint ZKHint if len(*merkle_paths) > 0 && len(*stir_answers) > 0 { @@ -248,7 +175,7 @@ func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Keccak zkHint.firstRoundMerklePaths = FirstRoundHint{ path: Hint{ - merklePaths: []FullMultiPath[KeccakDigest]{firstRoundMerklePath}, + merklePaths: []FullMultiPath[Digest]{firstRoundMerklePath}, stirAnswers: [][][]Fp256{firstRoundStirAnswers}, }, expectedStirAnswers: firstRoundStirAnswers, @@ -257,7 +184,7 @@ func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Keccak expectedRounds := whirConfig.NRounds - var remainingMerklePaths []FullMultiPath[KeccakDigest] + var remainingMerklePaths []FullMultiPath[Digest] var remainingStirAnswers [][][]Fp256 for i := 0; i < expectedRounds && len(*merkle_paths) > 0 && len(*stir_answers) > 0; i++ { @@ -272,47 +199,3 @@ func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Keccak return zkHint } - -// consumeFirstRoundOnly consumes only the first round hint (no subsequent rounds) -// Used for batch mode where each original commitment has its own first round -func consumeFirstRoundOnly(merklePaths *[]FullMultiPath[KeccakDigest], stirAnswers *[][][]Fp256) FirstRoundHint { - var hint FirstRoundHint - - if len(*merklePaths) > 0 && len(*stirAnswers) > 0 { - firstRoundMerklePath := consumeFront(merklePaths) - firstRoundStirAnswers := consumeFront(stirAnswers) - - hint = FirstRoundHint{ - path: Hint{ - merklePaths: []FullMultiPath[KeccakDigest]{firstRoundMerklePath}, - stirAnswers: [][][]Fp256{firstRoundStirAnswers}, - }, - expectedStirAnswers: firstRoundStirAnswers, - } - } - - return hint -} - -// consumeWhirDataRoundsOnly consumes only the round hints (not first round) -// Used for batched polynomial in batch mode -func consumeWhirDataRoundsOnly(whirConfig WHIRConfig, merklePaths *[]FullMultiPath[KeccakDigest], stirAnswers *[][][]Fp256) ZKHint { - var zkHint ZKHint - - expectedRounds := whirConfig.NRounds - - var remainingMerklePaths []FullMultiPath[KeccakDigest] - var remainingStirAnswers [][][]Fp256 - - for i := 0; i < expectedRounds && len(*merklePaths) > 0 && len(*stirAnswers) > 0; i++ { - remainingMerklePaths = append(remainingMerklePaths, consumeFront(merklePaths)) - remainingStirAnswers = append(remainingStirAnswers, consumeFront(stirAnswers)) - } - - zkHint.roundHints = Hint{ - merklePaths: remainingMerklePaths, - stirAnswers: remainingStirAnswers, - } - - return zkHint -} diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index e6d4a1d70..afb892b8a 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -1,19 +1,31 @@ package circuit import ( - "fmt" "math/big" - - "reilabs/whir-verifier-circuit/app/utilities" + "math/bits" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" gnarkNimue "github.com/reilabs/gnark-nimue" - skyscraper "github.com/reilabs/gnark-skyscraper" ) // NewWhirParams creates a new WHIRParams instance from the given configuration. // It processes the folding factors and calculates domain sizes based on the provided config. +// bn254TwoAdicRootOfUnity is the primitive 2^28-th root of unity for BN254 Fr. +// This matches arkworks bn254::Fr::TWO_ADIC_ROOT_OF_UNITY. +var bn254TwoAdicRootOfUnity, _ = new(big.Int).SetString( + "19103219067921713944291392827692070036145651957329286315305642004821462161904", 10) + +const bn254TwoAdicity = 28 + +// computeNTTGenerator computes the primitive root of unity of order `domainSize` +// for the BN254 scalar field. domainSize must be a power of two <= 2^28. +func computeNTTGenerator(domainSize int) *big.Int { + // generator(N) = two_adic_root ^ (2^28 / N) + exp := new(big.Int).SetUint64(uint64(1 << bn254TwoAdicity / domainSize)) + return new(big.Int).Exp(bn254TwoAdicRootOfUnity, exp, bn254Modulus) +} + func NewWhirParams(cfg WHIRConfig) WHIRParams { startingDomainGen, _ := new(big.Int).SetString(cfg.DomainGenerator, 10) mvParamsNumberOfVariables := cfg.NVars @@ -28,6 +40,13 @@ func NewWhirParams(cfg WHIRConfig) WHIRParams { finalSumcheckRounds = mvParamsNumberOfVariables % 4 } domainSize := (2 << mvParamsNumberOfVariables) * (1 << cfg.Rate) / 2 + interleavingDepth := 1 << foldingFactor[0] + + // Compute omega_full (generator of full domain of size DomainSize) and + // zeta (interleaving coset generator = omega_full^codeword_length). + omegaFull := computeNTTGenerator(domainSize) + codewordLength := domainSize / interleavingDepth + zeta := new(big.Int).Exp(omegaFull, big.NewInt(int64(codewordLength)), bn254Modulus) return WHIRParams{ ParamNRounds: cfg.NRounds, @@ -44,735 +63,39 @@ func NewWhirParams(cfg WHIRConfig) WHIRParams { FinalSumcheckRounds: finalSumcheckRounds, MVParamsNumberOfVariables: mvParamsNumberOfVariables, BatchSize: cfg.BatchSize, + InitialInDomainSamples: cfg.InitialInDomainSamples, + OmegaFull: *omegaFull, + Zeta: *zeta, } } -// RunZKWhir executes the zero-knowledge WHIR protocol for proof verification. -// It processes multiple rounds of sumcheck protocols and merkle tree verifications -// to verify the given circuit proof against the provided parameters. -func RunZKWhir( +func getStirChallenges( api frontend.API, - arthur gnarkNimue.Arthur, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - firstRound Merkle, - whirParams WHIRParams, - linearStatementEvaluations [][]frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, - batchingRandomness frontend.Variable, - initialOODQueries []frontend.Variable, - initialOODAnswers [][]frontend.Variable, - rootHashes frontend.Variable, -) (totalFoldingRandomness []frontend.Variable, err error) { - initialOODs := oodAnswers(api, initialOODAnswers, batchingRandomness) - // batchSizeLen := whirParams.BatchSize - - initialSumcheckData, lastEval, initialSumcheckFoldingRandomness, err := initialSumcheck(api, arthur, batchingRandomness, initialOODQueries, initialOODs, whirParams, linearStatementEvaluations) - if err != nil { - return - } - - copyOfFirstLeaves := make([][][]frontend.Variable, len(firstRound.Leaves)) - for i := range len(firstRound.Leaves) { - copyOfFirstLeaves[i] = make([][]frontend.Variable, len(firstRound.Leaves[i])) - for j := range len(firstRound.Leaves[i]) { - copyOfFirstLeaves[i][j] = make([]frontend.Variable, len(firstRound.Leaves[i][j])) - for k := range len(firstRound.Leaves[i][j]) { - copyOfFirstLeaves[i][j][k] = firstRound.Leaves[i][j][k] - } - } - } - - roundAnswers := make([][][]frontend.Variable, len(circuit.Leaves)+1) - - foldSize := 1 << whirParams.FoldingFactorArray[0] - collapsed := rlcBatchedLeaves(api, firstRound.Leaves[0], foldSize, whirParams.BatchSize, batchingRandomness) - roundAnswers[0] = collapsed - - for i := range len(circuit.Leaves) { - roundAnswers[i+1] = circuit.Leaves[i] - } - - computedFold := computeFold(collapsed, initialSumcheckFoldingRandomness, api) - - mainRoundData := generateEmptyMainRoundData(whirParams) - expDomainGenerator := utilities.Exponent(api, uapi, whirParams.StartingDomainBackingDomainGenerator, uints.NewU64(uint64(1< 0 { - _, _, err = utilities.PoW(api, sc, arthur, whirParams.FinalFoldingPowBits) - if err != nil { - return - } - } - - totalFoldingRandomness = utilities.Reverse(totalFoldingRandomness) - - evaluationOfWPoly := computeWPoly( - api, - whirParams, - initialSumcheckData, - mainRoundData, - totalFoldingRandomness, - linearStatementValuesAtPoints, - ) - - api.AssertIsEqual( - lastEval, - api.Mul(evaluationOfWPoly, utilities.MultivarPoly(finalCoefficients, finalSumcheckRandomness, api)), - ) - - return totalFoldingRandomness, nil -} - -// RunZKWhirBatch executes batch WHIR verification for N commitments. -// This is used when num_challenges > 0 (batch commitment mode). -func RunZKWhirBatch( - api frontend.API, - arthur gnarkNimue.Arthur, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - // N commitments data for round 0 - firstRounds []Merkle, - batchingRandomnesses []frontend.Variable, - initialOODQueries [][]frontend.Variable, - initialOODAnswers [][][]frontend.Variable, - rootHashes []frontend.Variable, - // Batched polynomial merkle for rounds 1+ - batchedMerkle Merkle, - // Statement evaluations per commitment - linearStatementEvals [][][]frontend.Variable, // [commitment_idx][f/g][eval_idx] - // Common parameters - whirParams WHIRParams, - linearStatementValuesAtPoints []frontend.Variable, - publicInputs PublicInputs, -) (totalFoldingRandomness []frontend.Variable, err error) { - numPolynomials := len(firstRounds) - if numPolynomials == 0 { - return nil, fmt.Errorf("RunZKWhirBatch: need at least one commitment") - } - - // Step 1: Reduce OOD answers for each commitment - initialOODs := make([][]frontend.Variable, numPolynomials) - for i := 0; i < numPolynomials; i++ { - initialOODs[i] = oodAnswers(api, initialOODAnswers[i], batchingRandomnesses[i]) - } - - // Step 2: Count total constraints (OOD + statement per commitment) - numOOD := 0 - for i := 0; i < numPolynomials; i++ { - numOOD += len(initialOODQueries[i]) - } - - numStatementConstraints := 0 - - // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints - if !publicInputs.IsEmpty() { - numStatementConstraints = 4 + 3*(numPolynomials-1) - } else { - numStatementConstraints = numPolynomials * 3 - } - numConstraints := numOOD + numStatementConstraints - - // Step 3: Read N×M evaluation matrix from transcript - evalMatrix := make([][]frontend.Variable, numPolynomials) - for i := 0; i < numPolynomials; i++ { - evalMatrix[i] = make([]frontend.Variable, numConstraints) - if err = arthur.FillNextScalars(evalMatrix[i]); err != nil { - return nil, err - } - } - - // Step 4: Squeeze batching randomness γ - gamma := make([]frontend.Variable, 1) - if err = arthur.FillChallengeScalars(gamma); err != nil { - return nil, err - } - batchGamma := gamma[0] - - // Precompute powers of gamma: [1, γ, γ^2, ..., γ^(numPolynomials-1)] - gammaPowers := make([]frontend.Variable, numPolynomials) - gammaPowers[0] = frontend.Variable(1) - for i := 1; i < numPolynomials; i++ { - gammaPowers[i] = api.Mul(gammaPowers[i-1], batchGamma) - } - - // Step 5: RLC-combine constraint evaluations: combined[j] = Σᵢ γⁱ·eval[i][j] - combinedEvals := make([]frontend.Variable, numConstraints) - for j := 0; j < numConstraints; j++ { - combined := frontend.Variable(0) - for i := 0; i < numPolynomials; i++ { - combined = api.Add(combined, api.Mul(gammaPowers[i], evalMatrix[i][j])) - } - combinedEvals[j] = combined - } - - // Step 6: Initial combination randomness and claimed sum - initialCombRandomness, err := GenerateCombinationRandomness(api, arthur, numConstraints) - if err != nil { - return nil, err - } - lastEval := utilities.DotProduct(api, initialCombRandomness, combinedEvals) - - // Step 7: Initial sumcheck - initialFoldingRandomness, lastEval, err := runWhirSumcheckRounds( - api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) - if err != nil { - return nil, err - } - totalFoldingRandomness = initialFoldingRandomness - - // ======================================== - // ROUND 0: Batch-specific verification - // Verify STIR queries against ALL N original trees - // ======================================== - - // Read commitment to batched folded polynomial (used in rounds 1+) - batchedRootHash := make([]frontend.Variable, 1) - if err = arthur.FillNextScalars(batchedRootHash); err != nil { - return nil, err - } - - // Read OOD for batched polynomial - round0OODPoints, round0OODAnswers, err := fillInOODPointsAndAnswers( - whirParams.RoundParametersOODSamples[0], arthur) - if err != nil { - return nil, err - } - - // PoW for round 0 - if err = RunPoW(api, sc, arthur, whirParams.PowBits[0]); err != nil { - return nil, err - } - - // Get STIR challenge indices - domainSize := whirParams.DomainSize - foldSize := 1 << whirParams.FoldingFactorArray[0] - stirChallengeIndices, err := getStirChallenges( - api, arthur, whirParams.RoundParametersNumOfQueries[0], domainSize, foldSize) - if err != nil { - return nil, err - } - - expDomainGenerator := utilities.Exponent(api, uapi, - whirParams.StartingDomainBackingDomainGenerator, - uints.NewU64(uint64(foldSize))) - - // Verify Merkle proofs in ALL N original trees - collapsedAnswers := make([][][]frontend.Variable, numPolynomials) - for i := 0; i < numPolynomials; i++ { - // Check indices match - err = utilities.IsEqual(api, uapi, stirChallengeIndices, firstRounds[i].LeafIndexes[0]) - if err != nil { - return nil, err - } - // Verify Merkle proofs - err = verifyMerkleTreeProofs(api, uapi, sc, - firstRounds[i].LeafIndexes[0], - firstRounds[i].Leaves[0], - firstRounds[i].LeafSiblingHashes[0], - firstRounds[i].AuthPaths[0], - rootHashes[i]) - if err != nil { - return nil, err - } - // Collapse batched leaves using commitment's batching randomness - collapsedAnswers[i] = rlcBatchedLeaves(api, - firstRounds[i].Leaves[0], foldSize, whirParams.BatchSize, batchingRandomnesses[i]) - } - - // RLC-combine answers across N polynomials: combined[q][f] = Σᵢ γⁱ·collapsed[i][q][f] - numQueries := len(collapsedAnswers[0]) - combinedAnswers := make([][]frontend.Variable, numQueries) - for q := 0; q < numQueries; q++ { - combinedAnswers[q] = make([]frontend.Variable, foldSize) - for f := 0; f < foldSize; f++ { - combined := frontend.Variable(0) - for i := 0; i < numPolynomials; i++ { - combined = api.Add(combined, api.Mul(gammaPowers[i], collapsedAnswers[i][q][f])) - } - combinedAnswers[q][f] = combined - } - } - - // Compute fold evaluations from combined answers - computedFold := computeFold(combinedAnswers, initialFoldingRandomness, api) - - // Convert STIR indices to domain points - stirChallengePoints := make([]frontend.Variable, numQueries) - for idx := range firstRounds[0].LeafIndexes[0] { - stirChallengePoints[idx] = utilities.Exponent(api, uapi, expDomainGenerator, firstRounds[0].LeafIndexes[0][idx]) - } - - // Combination randomness for round 0 constraints (OOD + STIR) - round0CombRandomness, err := GenerateCombinationRandomness(api, arthur, - len(round0OODPoints)+len(computedFold)) - if err != nil { - return nil, err - } - - // Update claimed sum with OOD and STIR constraints - lastEval = api.Add(lastEval, calculateShiftValue(round0OODAnswers, round0CombRandomness, computedFold, api)) - - // Sumcheck for round 0 - round0FoldingRandomness, lastEval, err := runWhirSumcheckRounds( - api, lastEval, arthur, whirParams.FoldingFactorArray[1], 3) - if err != nil { - return nil, err - } - totalFoldingRandomness = append(totalFoldingRandomness, round0FoldingRandomness...) - - // Update domain - domainSize /= 2 - expDomainGenerator = api.Mul(expDomainGenerator, expDomainGenerator) - - // Prepare for rounds 1+ - mainRoundData := generateEmptyMainRoundData(whirParams) - mainRoundData.OODPoints[0] = round0OODPoints - mainRoundData.StirChallengesPoints[0] = stirChallengePoints - mainRoundData.CombinationRandomness[0] = round0CombRandomness - - rootHashList := make([]frontend.Variable, whirParams.ParamNRounds) - rootHashList[0] = batchedRootHash[0] - - // Update computedFold for next round - if len(batchedMerkle.Leaves) > 0 { - computedFold = computeFold(batchedMerkle.Leaves[0], round0FoldingRandomness, api) - } - - // ======================================== - // ROUNDS 1+: Standard WHIR on batched polynomial - // ======================================== - for r := 1; r < whirParams.ParamNRounds; r++ { - rootHash := make([]frontend.Variable, 1) - if err = arthur.FillNextScalars(rootHash); err != nil { - return nil, err - } - rootHashList[r] = rootHash[0] - - var roundOODAnswers []frontend.Variable - mainRoundData.OODPoints[r], roundOODAnswers, err = fillInOODPointsAndAnswers( - whirParams.RoundParametersOODSamples[r], arthur) - if err != nil { - return nil, err - } - - if err = RunPoW(api, sc, arthur, whirParams.PowBits[r]); err != nil { - return nil, err - } - - // Get STIR challenges and verify against batched merkle - mainRoundData.StirChallengesPoints[r], err = getStirChallenges( - api, arthur, whirParams.RoundParametersNumOfQueries[r], domainSize, 1< 0 { - _, _, err = utilities.PoW(api, sc, arthur, whirParams.FinalFoldingPowBits) - if err != nil { - return nil, err - } - } - - // Reverse randomness for W-poly evaluation - totalFoldingRandomness = utilities.Reverse(totalFoldingRandomness) - - // Build combined initial sumcheck data - allOODQueries := make([]frontend.Variable, 0) - for i := 0; i < numPolynomials; i++ { - allOODQueries = append(allOODQueries, initialOODQueries[i]...) - } - initialSumcheckData := InitialSumcheckData{ - InitialOODQueries: allOODQueries, - InitialCombinationRandomness: initialCombRandomness, - } - - // Compute W-poly evaluation - evaluationOfWPoly := computeWPoly( - api, - whirParams, - initialSumcheckData, - mainRoundData, - totalFoldingRandomness, - linearStatementValuesAtPoints, - ) - - // Final check - api.AssertIsEqual( - lastEval, - api.Mul(evaluationOfWPoly, utilities.MultivarPoly(finalCoefficients, finalSumcheckRandomness, api)), - ) - - return totalFoldingRandomness, nil -} - -//nolint:unused -func runWhir( - api frontend.API, - arthur gnarkNimue.Arthur, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - whirParams WHIRParams, - linearStatementEvaluations []frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, -) (totalFoldingRandomness []frontend.Variable, err error) { - if err = fillInAndVerifyRootHash(0, api, uapi, sc, circuit, arthur); err != nil { - return - } - - initialOODQueries, initialOODAnswers, tempErr := fillInOODPointsAndAnswers(whirParams.CommitmentOODSamples, arthur) - if tempErr != nil { - err = tempErr - return - } - - initialCombinationRandomness, tempErr := GenerateCombinationRandomness(api, arthur, whirParams.CommitmentOODSamples+len(linearStatementEvaluations)) - if tempErr != nil { - err = tempErr - return - } - - OODAnswersAndStatementEvaluations := append(initialOODAnswers, linearStatementEvaluations...) - lastEval := utilities.DotProduct(api, initialCombinationRandomness, OODAnswersAndStatementEvaluations) - - initialSumcheckFoldingRandomness, lastEval, tempErr := runWhirSumcheckRounds(api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) - if tempErr != nil { - err = tempErr - return - } - - initialData := InitialSumcheckData{ - InitialOODQueries: initialOODQueries, - InitialCombinationRandomness: initialCombinationRandomness, - } - - computedFold := computeFold(circuit.Leaves[0], initialSumcheckFoldingRandomness, api) - - mainRoundData := generateEmptyMainRoundData(whirParams) - - expDomainGenerator := utilities.Exponent(api, uapi, whirParams.StartingDomainBackingDomainGenerator, uints.NewU64(uint64(1< 0 { - _, _, err := utilities.PoW(api, sc, arthur, difficulty) - if err != nil { - return err - } - } - return nil -} - -// GenerateStirChallengePoints generates the stir challenge points for the given parameters. -// It calculates the folding factor power and generates the stir challenges for the given leaf indexes. -func GenerateStirChallengePoints( - api frontend.API, - arthur gnarkNimue.Arthur, - NQueries int, - leafIndexes []uints.U64, + nimue gnarkNimue.Nimue, + numQueries int, domainSize int, - uapi *uints.BinaryField[uints.U64], - expDomainGenerator frontend.Variable, - foldingFactor int, + foldingFactorPower int, ) ([]frontend.Variable, error) { - foldingFactorPower := 1 << foldingFactor - finalIndexes, err := getStirChallenges(api, arthur, NQueries, domainSize, foldingFactorPower) - if err != nil { - return nil, err - } + foldedDomainSize := domainSize / foldingFactorPower + domainSizeBytes := (bits.Len(uint(foldedDomainSize*2-1)) - 1 + 7) / 8 - err = utilities.IsEqual(api, uapi, finalIndexes, leafIndexes) - if err != nil { + stirQueries := make([]uints.U8, domainSizeBytes*numQueries) + if err := nimue.FillChallengeBytes(stirQueries); err != nil { return nil, err } - finalRandomnessPoints := make([]frontend.Variable, len(leafIndexes)) - - for index := range leafIndexes { - finalRandomnessPoints[index] = utilities.Exponent(api, uapi, expDomainGenerator, leafIndexes[index]) - } + bitLength := bits.Len(uint(foldedDomainSize)) - 1 - return finalRandomnessPoints, nil -} + indexes := make([]frontend.Variable, numQueries) + for i := range numQueries { + var value frontend.Variable = 0 + for j := range domainSizeBytes { + value = api.Add(stirQueries[j+i*domainSizeBytes].Val, api.Mul(value, 256)) + } -// GenerateCombinationRandomness generates the combination randomness for the given parameters. -// It generates a random scalar and expands it to the required length. -func GenerateCombinationRandomness(api frontend.API, arthur gnarkNimue.Arthur, randomnessLength int) ([]frontend.Variable, error) { - combRandomnessGen := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(combRandomnessGen); err != nil { - return nil, err + bitsOfValue := api.ToBinary(value) + indexes[i] = api.FromBinary(bitsOfValue[:bitLength]...) } - combinationRandomness := utilities.ExpandRandomness(api, combRandomnessGen[0], randomnessLength) - return combinationRandomness, nil + return indexes, nil } diff --git a/recursive-verifier/app/circuit/whir_utilities.go b/recursive-verifier/app/circuit/whir_utilities.go deleted file mode 100644 index 2fd30f91c..000000000 --- a/recursive-verifier/app/circuit/whir_utilities.go +++ /dev/null @@ -1,192 +0,0 @@ -package circuit - -import ( - "math/bits" - - "reilabs/whir-verifier-circuit/app/utilities" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" - gnarkNimue "github.com/reilabs/gnark-nimue" - skyscraper "github.com/reilabs/gnark-skyscraper" -) - -func verifyMerkleTreeProofs(api frontend.API, uapi *uints.BinaryField[uints.U64], sc *skyscraper.Skyscraper, leafIndexes []uints.U64, leaves [][]frontend.Variable, leafSiblingHashes []frontend.Variable, authPaths [][]frontend.Variable, rootHash frontend.Variable) error { - numOfLeavesProved := len(leaves) - for i := range numOfLeavesProved { - treeHeight := len(authPaths[i]) + 1 - leafIndexBits := api.ToBinary(uapi.ToValue(leafIndexes[i]), treeHeight) - leafSiblingHash := leafSiblingHashes[i] - - claimedLeafHash := sc.CompressV2(leaves[i][0], leaves[i][1]) - for x := range len(leaves[i]) - 2 { - claimedLeafHash = sc.CompressV2(claimedLeafHash, leaves[i][x+2]) - } - - dir := leafIndexBits[0] - - xLeftChild := api.Select(dir, leafSiblingHash, claimedLeafHash) - xRightChild := api.Select(dir, claimedLeafHash, leafSiblingHash) - - currentHash := sc.CompressV2(xLeftChild, xRightChild) - - for level := 1; level < treeHeight; level++ { - indexBit := leafIndexBits[level] - - siblingHash := authPaths[i][level-1] - - dir := api.And(indexBit, 1) - left := api.Select(dir, siblingHash, currentHash) - right := api.Select(dir, currentHash, siblingHash) - - currentHash = sc.CompressV2(left, right) - } - api.AssertIsEqual(currentHash, rootHash) - } - return nil -} - -func getStirChallenges( - api frontend.API, - arthur gnarkNimue.Arthur, - numQueries int, - domainSize int, - foldingFactorPower int, -) ([]frontend.Variable, error) { - foldedDomainSize := domainSize / foldingFactorPower - domainSizeBytes := (bits.Len(uint(foldedDomainSize*2-1)) - 1 + 7) / 8 - - stirQueries := make([]uints.U8, domainSizeBytes*numQueries) - if err := arthur.FillChallengeBytes(stirQueries); err != nil { - return nil, err - } - - bitLength := bits.Len(uint(foldedDomainSize)) - 1 - - indexes := make([]frontend.Variable, numQueries) - for i := range numQueries { - var value frontend.Variable = 0 - for j := range domainSizeBytes { - value = api.Add(stirQueries[j+i*domainSizeBytes].Val, api.Mul(value, 256)) - } - - bitsOfValue := api.ToBinary(value) - indexes[i] = api.FromBinary(bitsOfValue[:bitLength]...) - } - - return indexes, nil -} - -func generateEmptyMainRoundData(circuit WHIRParams) MainRoundData { - return MainRoundData{ - OODPoints: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), - StirChallengesPoints: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), - CombinationRandomness: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), - } -} - -func fillInOODPointsAndAnswers(numberOfOODPoints int, arthur gnarkNimue.Arthur) ([]frontend.Variable, []frontend.Variable, error) { - oodPoints := make([]frontend.Variable, numberOfOODPoints) - oodAnswers := make([]frontend.Variable, numberOfOODPoints) - - if err := arthur.FillChallengeScalars(oodPoints); err != nil { - return nil, nil, err - } - - if err := arthur.FillNextScalars(oodAnswers); err != nil { - return nil, nil, err - } - - return oodPoints, oodAnswers, nil -} - -func runWhirSumcheckRounds( - api frontend.API, - lastEval frontend.Variable, - arthur gnarkNimue.Arthur, - foldingFactor int, - polynomialDegree int, -) ([]frontend.Variable, frontend.Variable, error) { - sumcheckPolynomial := make([]frontend.Variable, polynomialDegree) - foldingRandomness := make([]frontend.Variable, foldingFactor) - foldingRandomnessTemp := make([]frontend.Variable, 1) - - for i := range foldingFactor { - if err := arthur.FillNextScalars(sumcheckPolynomial); err != nil { - return nil, nil, err - } - if err := arthur.FillChallengeScalars(foldingRandomnessTemp); err != nil { - return nil, nil, err - } - foldingRandomness[i] = foldingRandomnessTemp[0] - utilities.CheckSumOverBool(api, lastEval, sumcheckPolynomial) - lastEval = utilities.EvaluateQuadraticPolynomialFromEvaluationList(api, sumcheckPolynomial, foldingRandomness[i]) - } - return foldingRandomness, lastEval, nil -} - -func computeWPoly( - api frontend.API, - circuit WHIRParams, - initialData InitialSumcheckData, - mainRoundData MainRoundData, - totalFoldingRandomness []frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, -) frontend.Variable { - numberVars := circuit.MVParamsNumberOfVariables - - value := frontend.Variable(0) - for j := range initialData.InitialOODQueries { - value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[j], utilities.EqPolyOutside(api, utilities.ExpandFromUnivariate(api, initialData.InitialOODQueries[j], numberVars), totalFoldingRandomness))) - } - - // Values are directly used as all linearStatements are deferred and hints were given. Checking of hints will be done later on. - for j, linearStatementValueAtPoint := range linearStatementValuesAtPoints { - value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[len(initialData.InitialOODQueries)+j], linearStatementValueAtPoint)) - } - for r := range mainRoundData.OODPoints { - numberVars -= circuit.FoldingFactorArray[r] - newTmpArr := append(mainRoundData.OODPoints[r], mainRoundData.StirChallengesPoints[r]...) - - sumOfClaims := frontend.Variable(0) - for i := range newTmpArr { - point := utilities.ExpandFromUnivariate(api, newTmpArr[i], numberVars) - sumOfClaims = api.Add(sumOfClaims, api.Mul(utilities.EqPolyOutside(api, point, totalFoldingRandomness[0:numberVars]), mainRoundData.CombinationRandomness[r][i])) - } - value = api.Add(value, sumOfClaims) - } - - return value -} - -//nolint:unused -func fillInAndVerifyRootHash( - roundNum int, - api frontend.API, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - arthur gnarkNimue.Arthur, -) error { - rootHash := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(rootHash); err != nil { - return err - } - err := verifyMerkleTreeProofs(api, uapi, sc, circuit.LeafIndexes[roundNum], circuit.Leaves[roundNum], circuit.LeafSiblingHashes[roundNum], circuit.AuthPaths[roundNum], rootHash[0]) - if err != nil { - return err - } - return nil -} - -func computeFold(leaves [][]frontend.Variable, foldingRandomness []frontend.Variable, api frontend.API) []frontend.Variable { - computedFold := make([]frontend.Variable, len(leaves)) - for j := range leaves { - computedFold[j] = utilities.MultivarPoly(leaves[j], foldingRandomness, api) - } - return computedFold -} - -func calculateShiftValue(oodAnswers []frontend.Variable, combinationRandomness []frontend.Variable, computedFold []frontend.Variable, api frontend.API) frontend.Variable { - return utilities.DotProduct(api, append(oodAnswers, computedFold...), combinationRandomness) -} diff --git a/recursive-verifier/app/circuit/zk_whir_verify.go b/recursive-verifier/app/circuit/zk_whir_verify.go new file mode 100644 index 000000000..2b4e4139c --- /dev/null +++ b/recursive-verifier/app/circuit/zk_whir_verify.go @@ -0,0 +1,436 @@ +package circuit + +import ( + "fmt" + + "reilabs/whir-verifier-circuit/app/whir" + + "github.com/consensys/gnark/frontend" + gnarkNimue "github.com/reilabs/gnark-nimue" + skyscraper "github.com/reilabs/gnark-skyscraper" +) + +// --------------------------------------------------------------------------- +// ParsedCommitmentNimue is the circuit-level parsed commitment. +// --------------------------------------------------------------------------- + +// ParsedCommitmentNimue mirrors the gnark whir ParsedCommitment but lives +// in the circuit package to avoid importing the whir package. +type ParsedCommitmentNimue struct { + Root frontend.Variable + OodPoints []frontend.Variable + OodAnswers []frontend.Variable // flat: outDomainSamples * batchSize +} + +// WhirStatement represents an external linear form constraint for WHIR. +type WhirStatement struct { + Evaluation frontend.Variable // claimed evaluation value +} + +// --------------------------------------------------------------------------- +// ZKWhirVerifyNimue is the circuit-level ZK-WHIR verification wrapper. +// Mirrors nativeZKWhirVerify but uses gnark constraints. +// --------------------------------------------------------------------------- + +// CommitmentMode controls which columns of the R1CS matrices contribute to +// weight MLE evaluations in the FinalClaim checks. +type CommitmentMode int + +const ( + // SingleCommitment uses all columns (single-commitment path). + SingleCommitment CommitmentMode = iota + // DualCommitment1 uses only columns < W1Size and includes public + blinding weights. + DualCommitment1 + // DualCommitment2 uses only columns >= W1Size (shifted by W1Size). No public or blinding weights. + DualCommitment2 +) + +// R1CSWeightParams bundles the circuit data needed to compute weight MLE +// evaluations for both the blinded and blinding FinalClaim checks. +type R1CSWeightParams struct { + Circuit *Circuit + Alpha []frontend.Variable + PublicWeightsChallenge frontend.Variable + HasPublicInputs bool + ChallengeOffsets []int + Mode CommitmentMode +} + +func ZKWhirVerify( + api frontend.API, + sc *skyscraper.Skyscraper, + nimue gnarkNimue.Nimue, + blindedCommitment ParsedCommitmentNimue, + blindingCommitment ParsedCommitmentNimue, + blindedParams WHIRParams, + blindingParams WHIRParams, + evaluations []frontend.Variable, // claimed linear form evaluations [pub?, az, bz, cz, blinding_eval] + weightsLen int, // 4 (no public) or 5 (with public) + numPolynomials int, + blindedMerkleData *whir.WhirMerkleData, + blindingMerkleData *whir.WhirMerkleData, + r1csWeights R1CSWeightParams, +) error { + numWitnessVariables := blindedParams.MVParamsNumberOfVariables + interleavingDepth := 1 << blindedParams.FoldingFactorArray[0] + + blindingChallenge := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(blindingChallenge); err != nil { + return fmt.Errorf("blinding_challenge: %w", err) + } + + numWFoldedEvals := weightsLen * numPolynomials * (numWitnessVariables + 1) + wFoldedBlindingEvals := make([]frontend.Variable, numWFoldedEvals) + if err := nimue.FillNextScalars(wFoldedBlindingEvals); err != nil { + return fmt.Errorf("w_folded_blinding_evals: %w", err) + } + + maskingChallenge := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(maskingChallenge); err != nil { + return fmt.Errorf("masking_challenge: %w", err) + } + + numInitialQueries := blindedParams.InitialInDomainSamples + initialStirIndexes, err := getStirChallenges(api, nimue, numInitialQueries, blindedParams.DomainSize, interleavingDepth) + if err != nil { + return fmt.Errorf("initial_committer stir: %w", err) + } + + // h_gammas count + hGammasCount := numInitialQueries * interleavingDepth + + tau1 := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(tau1); err != nil { + return fmt.Errorf("tau1: %w", err) + } + tau2 := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(tau2); err != nil { + return fmt.Errorf("tau2: %w", err) + } + + evalsPerPoly := 1 + numWitnessVariables + perGammaEvals := make([][][]frontend.Variable, hGammasCount) + for g := range hGammasCount { + perGammaEvals[g] = make([][]frontend.Variable, numPolynomials) + for p := range numPolynomials { + vals := make([]frontend.Variable, evalsPerPoly) + if err := nimue.FillNextScalars(vals); err != nil { + return fmt.Errorf("gamma %d poly %d: %w", g, p, err) + } + perGammaEvals[g][p] = vals + } + } + + combinedClaims := make([]frontend.Variable, numPolynomials) + if err := nimue.FillNextScalars(combinedClaims); err != nil { + return fmt.Errorf("combined_claims: %w", err) + } + batchedHClaims := make([]frontend.Variable, numPolynomials) + if err := nimue.FillNextScalars(batchedHClaims); err != nil { + return fmt.Errorf("batched_h_claims: %w", err) + } + + // Verify batched_h_claims (Rust: verify!(batched_h_claims == expected_batched_h_claims)) + // Compute gamma values: for each query index i, for k = 0..interleavingDepth-1: + // gamma_{i,k} = omega_full^(index_i) * zeta^k + numBitsIdx := 0 + for v := blindedParams.DomainSize / interleavingDepth; v > 1; v >>= 1 { + numBitsIdx++ + } + // Precompute zeta powers: [1, zeta, zeta^2, ..., zeta^(interleavingDepth-1)] + zetaPowers := make([]frontend.Variable, interleavingDepth) + zetaPowers[0] = frontend.Variable(1) + for k := 1; k < interleavingDepth; k++ { + zetaPowers[k] = api.Mul(zetaPowers[k-1], blindedParams.Zeta) + } + + gammas := make([]frontend.Variable, hGammasCount) + for qi := range numInitialQueries { + // The provekit NTT (RSFr) uses bit-reversed evaluation order. + // evaluation_points returns generator^(bit_reverse(index, numBitsIdx)), + // so all_gammas computes coset_offset = omega_full^(bit_reverse(index)). + cosetOffset := whir.BitReversedExponentVar(api, blindedParams.OmegaFull, initialStirIndexes[qi], numBitsIdx) + for k := range interleavingDepth { + gammas[qi*interleavingDepth+k] = api.Mul(cosetOffset, zetaPowers[k]) + } + } + + // Compute expected_batched_h_claims from per-gamma evaluations. + // Mirrors Rust whir_zk/verifier.rs lines 95-113. + expectedBatchedHClaims := make([]frontend.Variable, numPolynomials) + for p := range numPolynomials { + expectedBatchedHClaims[p] = frontend.Variable(0) + } + tau2PowerH := frontend.Variable(1) + for g := range hGammasCount { + gamma := gammas[g] + for p := range numPolynomials { + evals := perGammaEvals[g][p] + mEval := evals[0] + hValue := mEval + blindingPower := blindingChallenge[0] + gammaPower := gamma + for j := range numWitnessVariables { + gHatEval := evals[j+1] + hValue = api.Add(hValue, api.Mul(api.Mul(blindingPower, gammaPower), gHatEval)) + blindingPower = api.Mul(blindingPower, blindingChallenge[0]) + gammaPower = api.Mul(gammaPower, gammaPower) + } + expectedBatchedHClaims[p] = api.Add(expectedBatchedHClaims[p], api.Mul(tau2PowerH, hValue)) + } + tau2PowerH = api.Mul(tau2PowerH, tau2[0]) + } + for p := range numPolynomials { + api.AssertIsEqual(batchedHClaims[p], expectedBatchedHClaims[p]) + } + + // Blinded commitment WHIR verify + // Mirrors Rust whir_zk/verifier.rs lines 118-125: + // modified_evaluations[i] = evaluations[i] + m_evals[i] + // where m_evals[i] is the first element of each (μ+1)-sized block + // in wFoldedBlindingEvals. + blockSize := numWitnessVariables + 1 + modifiedEvaluations := make([]frontend.Variable, len(evaluations)) + for i, eval := range evaluations { + mEval := wFoldedBlindingEvals[i*blockSize] // first element of each block + modifiedEvaluations[i] = api.Add(eval, mEval) + } + blindedWhirCommitment := toWhirCommitment(blindedCommitment) + blindedWhirStatements := toWhirStatements(modifiedEvaluations, blindedParams.BatchSize) + blindedWhirParams := blindedParams + + blindedResult, err := whir.VerifyWhir(api, sc, nimue, blindedWhirCommitment, blindedWhirStatements, blindedWhirParams, blindedMerkleData) + if err != nil { + return fmt.Errorf("blinded WHIR verify: %w", err) + } + + // blinding commitment WHIR verify + // Accumulate m_claims and g_hat_claims using tau2 powers + mClaims := make([]frontend.Variable, numPolynomials) + gHatClaims := make([][]frontend.Variable, numPolynomials) + for p := range numPolynomials { + mClaims[p] = frontend.Variable(0) + gHatClaims[p] = make([]frontend.Variable, numWitnessVariables) + for j := range numWitnessVariables { + gHatClaims[p][j] = frontend.Variable(0) + } + } + + tau2Power := frontend.Variable(1) + for g := range hGammasCount { + for p := range numPolynomials { + evals := perGammaEvals[g][p] + mClaims[p] = api.Add(mClaims[p], api.Mul(tau2Power, evals[0])) + for j := range numWitnessVariables { + gHatClaims[p][j] = api.Add(gHatClaims[p][j], api.Mul(tau2Power, evals[j+1])) + } + } + tau2Power = api.Mul(tau2Power, tau2[0]) + } + + // verify combined_claims (Rust: verify!(combined_claims == expected_combined_claims)) + // combined_claims[p] = m_claims[p] + 2 * tau1 * univariate_evaluate(g_hat_claims[p], tau1) + for p := range numPolynomials { + // Horner evaluation of g_hat_claims[p] at tau1 + gHatEval := frontend.Variable(0) + for j := len(gHatClaims[p]) - 1; j >= 0; j-- { + gHatEval = api.Add(gHatClaims[p][j], api.Mul(gHatEval, tau1[0])) + } + // expected = m_claims[p] + 2 * tau1 * gHatEval + expectedCombined := api.Add(mClaims[p], api.Mul(frontend.Variable(2), api.Mul(tau1[0], gHatEval))) + api.AssertIsEqual(combinedClaims[p], expectedCombined) + } + + // Build subproof_claims: [m_0, g_hat_0..., m_1, g_hat_1..., ...] + var subproofClaims []frontend.Variable + for p := range numPolynomials { + subproofClaims = append(subproofClaims, mClaims[p]) + subproofClaims = append(subproofClaims, gHatClaims[p]...) + } + + // all_expected_blinding_claims = subproof_claims ++ w_folded_blinding_evals + blindingEvaluations := append(subproofClaims, wFoldedBlindingEvals...) + + blindingWhirCommitment := toWhirCommitment(blindingCommitment) + blindingWhirStatements := toWhirStatements(blindingEvaluations, blindingParams.BatchSize) + blindingWhirParams := blindingParams + + blindingResult, err := whir.VerifyWhir(api, sc, nimue, blindingWhirCommitment, blindingWhirStatements, blindingWhirParams, blindingMerkleData) + if err != nil { + return fmt.Errorf("blinding WHIR verify: %w", err) + } + + // verify WHIR-committed polynomial matches R1CS weight linear forms. + w := r1csWeights + fc := blindedResult.FinalClaim + foldingRandomness := fc.EvaluationPoint + + var weightMLEs []frontend.Variable + switch w.Mode { + case SingleCommitment: + // All columns, public + blinding weights. + matrixExtensionEvals := evaluateR1CSMatrixExtension(api, w.Circuit, w.Alpha, foldingRandomness) + if w.HasPublicInputs { + // n = num_public_inputs + 1 (the +1 accounts for the constant-1 witness at position 0) + weightMLEs = append(weightMLEs, geometricTill(api, w.PublicWeightsChallenge, len(w.Circuit.PublicInputs.Values)+1, foldingRandomness)) + } + weightMLEs = append(weightMLEs, matrixExtensionEvals[0], matrixExtensionEvals[1], matrixExtensionEvals[2]) + weightMLEs = append(weightMLEs, blindingCovectorMLE(api, w.Alpha, w.Circuit.W1Size, foldingRandomness)) + case DualCommitment1: + // Columns < W1Size only, with public + blinding weights. + matrixEvals1, _ := evaluateR1CSMatrixExtensionSplit(api, w.Circuit, w.Alpha, foldingRandomness, nil, w.Circuit.W1Size) + if w.HasPublicInputs { + // n = num_public_inputs + 1 (the +1 accounts for the constant-1 witness at position 0) + weightMLEs = append(weightMLEs, geometricTill(api, w.PublicWeightsChallenge, len(w.Circuit.PublicInputs.Values)+1, foldingRandomness)) + } + weightMLEs = append(weightMLEs, matrixEvals1[0], matrixEvals1[1], matrixEvals1[2]) + weightMLEs = append(weightMLEs, blindingCovectorMLE(api, w.Alpha, w.Circuit.W1Size, foldingRandomness)) + case DualCommitment2: + // Columns >= W1Size only, no public, no blinding. + _, matrixEvals2 := evaluateR1CSMatrixExtensionSplit(api, w.Circuit, w.Alpha, nil, foldingRandomness, w.Circuit.W1Size) + weightMLEs = append(weightMLEs, matrixEvals2[0], matrixEvals2[1], matrixEvals2[2]) + // Challenge weight MLE if challenge_offsets are present. + if len(w.ChallengeOffsets) > 0 { + weightMLEs = append(weightMLEs, challengeWeightMLE(api, w.PublicWeightsChallenge, w.ChallengeOffsets, foldingRandomness)) + } + } + fc.VerifyClaim(api, weightMLEs) + + // verify blinding polynomial matches beq_weights and folded R1CS weights. + w = r1csWeights + fc = blindingResult.FinalClaim + evalPoint := fc.EvaluationPoint + numBlindingVars := blindingParams.MVParamsNumberOfVariables - 1 + maskSize := 1 << (numBlindingVars + 1) + + beqMLE := batchedBeqMLE(api, gammas, maskingChallenge[0], tau2[0], numBlindingVars, evalPoint) + weightMLEs = []frontend.Variable{beqMLE} + + switch w.Mode { + case SingleCommitment: + foldedMatrixEvals := evaluateFoldedR1CSMatrixExtension(api, w.Circuit, w.Alpha, evalPoint, maskSize) + if w.HasPublicInputs { + weightMLEs = append(weightMLEs, foldedGeometricTill(api, w.PublicWeightsChallenge, len(w.Circuit.PublicInputs.Values)+1, evalPoint, maskSize)) + } + weightMLEs = append(weightMLEs, foldedMatrixEvals[0], foldedMatrixEvals[1], foldedMatrixEvals[2]) + weightMLEs = append(weightMLEs, foldedBlindingCovectorMLE(api, w.Alpha, w.Circuit.W1Size, evalPoint, maskSize)) + case DualCommitment1: + foldedEvals1, _ := evaluateFoldedR1CSMatrixExtensionSplit(api, w.Circuit, w.Alpha, evalPoint, nil, maskSize, w.Circuit.W1Size) + if w.HasPublicInputs { + weightMLEs = append(weightMLEs, foldedGeometricTill(api, w.PublicWeightsChallenge, len(w.Circuit.PublicInputs.Values)+1, evalPoint, maskSize)) + } + weightMLEs = append(weightMLEs, foldedEvals1[0], foldedEvals1[1], foldedEvals1[2]) + weightMLEs = append(weightMLEs, foldedBlindingCovectorMLE(api, w.Alpha, w.Circuit.W1Size, evalPoint, maskSize)) + case DualCommitment2: + _, foldedEvals2 := evaluateFoldedR1CSMatrixExtensionSplit(api, w.Circuit, w.Alpha, nil, evalPoint, maskSize, w.Circuit.W1Size) + weightMLEs = append(weightMLEs, foldedEvals2[0], foldedEvals2[1], foldedEvals2[2]) + // Folded challenge weight MLE if challenge_offsets are present. + if len(w.ChallengeOffsets) > 0 { + weightMLEs = append(weightMLEs, foldedChallengeWeightMLE(api, w.PublicWeightsChallenge, w.ChallengeOffsets, evalPoint, maskSize)) + } + } + fc.VerifyClaim(api, weightMLEs) + + return nil +} + +// --------------------------------------------------------------------------- +// Type conversion helpers for calling whir.VerifyWhir from the circuit package. +// --------------------------------------------------------------------------- + +func toWhirCommitment(c ParsedCommitmentNimue) whir.ParsedCommitment { + return whir.ParsedCommitment{ + Root: c.Root, + OodPoints: c.OodPoints, + OodAnswers: c.OodAnswers, + } +} + +// toWhirStatements converts a flat slice of evaluation values into +// whir.Statement objects, grouping every numVectors evaluations into a +// single statement with that many constraints. When numVectors=1 each +// evaluation becomes its own statement (the common case for the blinded WHIR). +func toWhirStatements(evaluations []frontend.Variable, numVectors int) []whir.Statement { + numStatements := len(evaluations) / numVectors + statements := make([]whir.Statement, numStatements) + for i := range numStatements { + constraints := make([]whir.MLConstraint, numVectors) + for j := range numVectors { + constraints[j] = whir.MLConstraint{Evaluation: evaluations[i*numVectors+j]} + } + statements[i] = whir.Statement{Constraints: constraints, NVars: 0} + } + return statements +} + +// --------------------------------------------------------------------------- +// calculateEqCircuit computes eq(a, b) = Π_i (a_i*b_i + (1-a_i)*(1-b_i)) +// --------------------------------------------------------------------------- + +func calculateEqCircuit(api frontend.API, a, b []frontend.Variable) frontend.Variable { + result := frontend.Variable(1) + for i := range a { + ab := api.Mul(a[i], b[i]) + oneMinusA := api.Sub(frontend.Variable(1), a[i]) + oneMinusB := api.Sub(frontend.Variable(1), b[i]) + prod := api.Mul(oneMinusA, oneMinusB) + term := api.Add(ab, prod) + result = api.Mul(result, term) + } + return result +} + +// batchedBeqMLE computes the MLE of the batched beq_weights at evalPoint. +// +// beq_weights = Σ_g tau2^g * beq((pow(gamma_g), -maskingChallenge), ·) +// +// The MLE of a covector c at point p equals c(p) = Σ_j c[j] * L_j(p). +// Since beq is itself multilinear, its MLE at p is just beq(target, p). +// So: beq_mle(p) = Σ_g tau2^g * beq((pow(gamma_g), -maskingChallenge), p) +// +// where beq(a, b) = Π_i (a_i*b_i + (1-a_i)*(1-b_i)) over ell+1 variables: +// - variables 0..ell-1: squaring ladder (gamma, gamma^2, gamma^4, ...) +// - variable ell: -maskingChallenge +func batchedBeqMLE( + api frontend.API, + gammas []frontend.Variable, + maskingChallenge frontend.Variable, + tau2 frontend.Variable, + numBlindingVars int, + evalPoint []frontend.Variable, // ell+1 variables +) frontend.Variable { + negRho := api.Neg(maskingChallenge) + + result := frontend.Variable(0) + tau2Power := frontend.Variable(1) + + for _, gamma := range gammas { + // Compute beq((pow(gamma), -rho), evalPoint) + beqVal := frontend.Variable(1) + + // Variables 0..ell-1: squaring ladder of gamma + gammaPower := gamma + for i := range numBlindingVars { + ab := api.Mul(gammaPower, evalPoint[i]) + oneMinusA := api.Sub(frontend.Variable(1), gammaPower) + oneMinusB := api.Sub(frontend.Variable(1), evalPoint[i]) + term := api.Add(ab, api.Mul(oneMinusA, oneMinusB)) + beqVal = api.Mul(beqVal, term) + gammaPower = api.Mul(gammaPower, gammaPower) + } + + // Variable ell: -maskingChallenge + { + ab := api.Mul(negRho, evalPoint[numBlindingVars]) + oneMinusA := api.Sub(frontend.Variable(1), negRho) + oneMinusB := api.Sub(frontend.Variable(1), evalPoint[numBlindingVars]) + term := api.Add(ab, api.Mul(oneMinusA, oneMinusB)) + beqVal = api.Mul(beqVal, term) + } + + result = api.Add(result, api.Mul(tau2Power, beqVal)) + tau2Power = api.Mul(tau2Power, tau2) + } + + return result +} diff --git a/recursive-verifier/app/keccakSponge/keccakSponge.go b/recursive-verifier/app/keccakSponge/keccakSponge.go deleted file mode 100644 index 84b1f910b..000000000 --- a/recursive-verifier/app/keccakSponge/keccakSponge.go +++ /dev/null @@ -1,82 +0,0 @@ -package keccakSponge - -import ( - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" - "github.com/consensys/gnark/std/permutation/keccakf" -) - -type Digest struct { - api frontend.API - uapi *uints.BinaryField[uints.U64] - state [25]uints.U64 - absorb_pos int - squeeze_pos int -} - -func NewKeccak(api frontend.API) (*Digest, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &Digest{ - api: api, - uapi: uapi, - state: newState(), - absorb_pos: 0, - squeeze_pos: 136, - }, nil -} - -func NewKeccakWithTag(api frontend.API, tag []frontend.Variable) (*Digest, error) { - d, _ := NewKeccak(api) - for i := 136; i < 136+len(tag); i++ { - d.state[i/8][i%8].Val = tag[i-136] - } - - return d, nil -} - -func (d *Digest) Absorb(in []frontend.Variable) { - u8Arr := make([]uints.U8, len(in)) - for i := range in { - u8Arr[i].Val = in[i] - } - - for _, inputByte := range u8Arr { - if d.absorb_pos == 136 { - d.state = keccakf.Permute(d.uapi, d.state) - d.absorb_pos = 0 - } - d.state[d.absorb_pos/8][d.absorb_pos%8] = inputByte - d.absorb_pos++ - } - - d.squeeze_pos = 136 -} - -func (d *Digest) AbsorbQuadraticPolynomial(in [][]frontend.Variable) { - for i := range in { - d.Absorb(in[i]) - } -} - -func (d *Digest) Squeeze(len int) (result []frontend.Variable) { - for i := 0; i < len; i++ { - if d.squeeze_pos == 136 { - d.squeeze_pos = 0 - d.absorb_pos = 0 - d.state = keccakf.Permute(d.uapi, d.state) - } - result = append(result, d.state[d.squeeze_pos/8][d.squeeze_pos%8].Val) - d.squeeze_pos++ - } - return result -} - -func newState() (state [25]uints.U64) { - for i := range state { - state[i] = uints.NewU64(0) - } - return -} diff --git a/recursive-verifier/app/typeConverters/typeConverters.go b/recursive-verifier/app/typeConverters/typeConverters.go index e4786c29f..12556ad17 100644 --- a/recursive-verifier/app/typeConverters/typeConverters.go +++ b/recursive-verifier/app/typeConverters/typeConverters.go @@ -2,27 +2,8 @@ package typeConverters import ( "math/big" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" ) -func BigEndian(api frontend.API, varArr []frontend.Variable) frontend.Variable { - frontendVar := frontend.Variable(0) - for i := range varArr { - frontendVar = api.Add(api.Mul(256, frontendVar), varArr[i]) - } - return frontendVar -} - -func LittleEndian(api frontend.API, varArr []frontend.Variable) frontend.Variable { - frontendVar := frontend.Variable(0) - for i := range varArr { - frontendVar = api.Add(api.Mul(256, frontendVar), varArr[len(varArr)-1-i]) - } - return frontendVar -} - func LimbsToBigIntMod(limbs [4]uint64) *big.Int { modulus := new(big.Int) modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) @@ -42,52 +23,3 @@ func LimbsToBigIntMod(limbs [4]uint64) *big.Int { return result } - -func LittleEndianFromUints(api frontend.API, varArr []uints.U8) frontend.Variable { - frontendVar := frontend.Variable(0) - for i := range varArr { - frontendVar = api.Add(api.Mul(256, frontendVar), varArr[len(varArr)-1-i].Val) - } - return frontendVar -} - -func BigEndianFromUints(api frontend.API, varArr []uints.U8) frontend.Variable { - frontendVar := frontend.Variable(0) - for i := 0; i < len(varArr); i++ { - frontendVar = api.Mul(frontendVar, 256) - frontendVar = api.Add(frontendVar, varArr[i].Val) - } - return frontendVar -} - -func LittleEndianArr(api frontend.API, arrVarArr [][]frontend.Variable) []frontend.Variable { - frontendArr := make([]frontend.Variable, len(arrVarArr)) - - for j := range arrVarArr { - frontendVar := frontend.Variable(0) - for i := range arrVarArr[j] { - frontendVar = api.Add(api.Mul(256, frontendVar), arrVarArr[j][len(arrVarArr[j])-1-i]) - } - frontendArr[j] = frontendVar - } - return frontendArr -} - -func ByteArrToVarArr(uint8Arr []uint8) []frontend.Variable { - frontendArr := make([]frontend.Variable, len(uint8Arr)) - for i := range frontendArr { - frontendArr[i] = frontend.Variable(uint8Arr[i]) - } - return frontendArr -} - -func LittleEndianUint8ToBigInt(bytes []uint8) *big.Int { - reversed := make([]byte, len(bytes)) - for i, b := range bytes { - reversed[len(bytes)-1-i] = b - } - - result := new(big.Int) - result.SetBytes(reversed) - return result -} diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 691af27dd..dd1533b99 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -5,39 +5,11 @@ import ( "encoding/json" "fmt" "math/big" - "reilabs/whir-verifier-circuit/app/typeConverters" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" - gnarkNimue "github.com/reilabs/gnark-nimue" - skyscraper "github.com/reilabs/gnark-skyscraper" ) -func MultivarPoly(coefs []frontend.Variable, vars []frontend.Variable, api frontend.API) frontend.Variable { - if len(vars) == 0 { - return coefs[0] - } - deg_zero := MultivarPoly(coefs[:len(coefs)/2], vars[:len(vars)-1], api) - deg_one := api.Mul(vars[len(vars)-1], MultivarPoly(coefs[len(coefs)/2:], vars[:len(vars)-1], api)) - return api.Add(deg_zero, deg_one) -} - -func UnivarPoly(api frontend.API, coefficients []frontend.Variable, points []frontend.Variable) []frontend.Variable { - if len(points) == 0 { - return coefficients - } - - results := make([]frontend.Variable, len(points)) - for j := range points { - ans := frontend.Variable(0) - for i := range coefficients { - ans = api.Add(api.Mul(ans, points[j]), coefficients[len(coefficients)-1-i]) - } - results[j] = ans - } - return results -} - func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(outputs) != 1 { return fmt.Errorf("expecting one output") @@ -60,92 +32,6 @@ func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } -func Reverse[T any](s []T) []T { - res := make([]T, len(s)) - copy(res, s) - for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { - res[i], res[j] = s[j], s[i] - } - return res -} - -func PrefixDecodePath[T any](prevPath []T, prefixLen uint64, suffix []T) []T { - if prefixLen == 0 { - res := make([]T, len(suffix)) - copy(res, suffix) - return res - } else { - res := make([]T, prefixLen+uint64(len(suffix))) - copy(res, prevPath[:prefixLen]) - copy(res[prefixLen:], suffix) - return res - } -} - -func PoW(api frontend.API, sc *skyscraper.Skyscraper, arthur gnarkNimue.Arthur, difficulty int) ([]uints.U8, []uints.U8, error) { - challenge := make([]uints.U8, 32) - if err := arthur.FillChallengeBytes(challenge); err != nil { - return nil, nil, err - } - nonce := make([]uints.U8, 8) - - if err := arthur.FillNextBytes(nonce); err != nil { - return nil, nil, err - } - challengeFieldElement := typeConverters.LittleEndianFromUints(api, challenge) - nonceFieldElement := typeConverters.BigEndianFromUints(api, nonce) - err := CheckPoW(api, sc, challengeFieldElement, nonceFieldElement, difficulty) - if err != nil { - return nil, nil, err - } - return challenge, nonce, nil -} - -func CheckPoW(api frontend.API, sc *skyscraper.Skyscraper, challenge frontend.Variable, nonce frontend.Variable, difficulty int) error { - hash := sc.CompressV2(challenge, nonce) - - d0, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - d1, _ := new(big.Int).SetString("10944121435919637611123202872628637544274182200208017171849102093287904247808", 10) - d2, _ := new(big.Int).SetString("5472060717959818805561601436314318772137091100104008585924551046643952123904", 10) - d3, _ := new(big.Int).SetString("2736030358979909402780800718157159386068545550052004292962275523321976061952", 10) - d4, _ := new(big.Int).SetString("1368015179489954701390400359078579693034272775026002146481137761660988030976", 10) - d5, _ := new(big.Int).SetString("684007589744977350695200179539289846517136387513001073240568880830494015488", 10) - d6, _ := new(big.Int).SetString("342003794872488675347600089769644923258568193756500536620284440415247007744", 10) - d7, _ := new(big.Int).SetString("171001897436244337673800044884822461629284096878250268310142220207623503872", 10) - d8, _ := new(big.Int).SetString("85500948718122168836900022442411230814642048439125134155071110103811751936", 10) - d9, _ := new(big.Int).SetString("42750474359061084418450011221205615407321024219562567077535555051905875968", 10) - d10, _ := new(big.Int).SetString("21375237179530542209225005610602807703660512109781283538767777525952937984", 10) - d11, _ := new(big.Int).SetString("10687618589765271104612502805301403851830256054890641769383888762976468992", 10) - d12, _ := new(big.Int).SetString("5343809294882635552306251402650701925915128027445320884691944381488234496", 10) - d13, _ := new(big.Int).SetString("2671904647441317776153125701325350962957564013722660442345972190744117248", 10) - d14, _ := new(big.Int).SetString("1335952323720658888076562850662675481478782006861330221172986095372058624", 10) - d15, _ := new(big.Int).SetString("667976161860329444038281425331337740739391003430665110586493047686029312", 10) - d16, _ := new(big.Int).SetString("333988080930164722019140712665668870369695501715332555293246523843014656", 10) - d17, _ := new(big.Int).SetString("166994040465082361009570356332834435184847750857666277646623261921507328", 10) - d18, _ := new(big.Int).SetString("83497020232541180504785178166417217592423875428833138823311630960753664", 10) - d19, _ := new(big.Int).SetString("41748510116270590252392589083208608796211937714416569411655815480376832", 10) - d20, _ := new(big.Int).SetString("20874255058135295126196294541604304398105968857208284705827907740188416", 10) - d21, _ := new(big.Int).SetString("10437127529067647563098147270802152199052984428604142352913953870094208", 10) - d22, _ := new(big.Int).SetString("5218563764533823781549073635401076099526492214302071176456976935047104", 10) - d23, _ := new(big.Int).SetString("2609281882266911890774536817700538049763246107151035588228488467523552", 10) - d24, _ := new(big.Int).SetString("1304640941133455945387268408850269024881623053575517794114244233761776", 10) - d25, _ := new(big.Int).SetString("652320470566727972693634204425134512440811526787758897057122116880888", 10) - d26, _ := new(big.Int).SetString("326160235283363986346817102212567256220405763393879448528561058440444", 10) - d27, _ := new(big.Int).SetString("163080117641681993173408551106283628110202881696939724264280529220222", 10) - - var arr = [28]*big.Int{d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, d16, d17, d18, d19, d20, d21, d22, d23, d24, d25, d26, d27} - api.AssertIsLessOrEqual(hash, arr[difficulty]) - return nil -} - -func EqPolyOutside(api frontend.API, coords []frontend.Variable, point []frontend.Variable) frontend.Variable { - acc := frontend.Variable(1) - for i := range coords { - acc = api.Mul(acc, api.Add(api.Mul(coords[i], point[i]), api.Mul(api.Sub(frontend.Variable(1), coords[i]), api.Sub(frontend.Variable(1), point[i])))) - } - return acc -} - func EvaluateQuadraticPolynomialFromEvaluationList(api frontend.API, evaluations []frontend.Variable, point frontend.Variable) (ans frontend.Variable) { inv2 := api.Inverse(2) b0 := evaluations[0] @@ -165,11 +51,6 @@ func Exponent(api frontend.API, uapi *uints.BinaryField[uints.U64], X frontend.V return output } -func CheckSumOverBool(api frontend.API, value frontend.Variable, polyEvals []frontend.Variable) { - sumOverBools := api.Add(polyEvals[0], polyEvals[1]) - api.AssertIsEqual(value, sumOverBools) -} - func ExpandRandomness(api frontend.API, base frontend.Variable, len int) []frontend.Variable { res := make([]frontend.Variable, len) acc := frontend.Variable(1) @@ -190,21 +71,6 @@ func ExpandFromUnivariate(api frontend.API, base frontend.Variable, len int) []f return res } -func IsEqual(api frontend.API, uapi *uints.BinaryField[uints.U64], indexes []frontend.Variable, merkleIndexes []uints.U64) error { - api.AssertIsEqual(len(indexes), len(merkleIndexes)) - - merkleVars := make([]frontend.Variable, len(merkleIndexes)) - for i, index := range merkleIndexes { - merkleVars[i] = uapi.ToValue(index) - } - - for i := range indexes { - api.AssertIsEqual(indexes[i], merkleVars[i]) - } - - return nil -} - func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) frontend.Variable { var acc = frontend.Variable(0) for i := range a { diff --git a/recursive-verifier/app/whir/mtUtilities.go b/recursive-verifier/app/whir/mtUtilities.go new file mode 100644 index 000000000..d95b65c22 --- /dev/null +++ b/recursive-verifier/app/whir/mtUtilities.go @@ -0,0 +1,218 @@ +package whir + +import ( + "fmt" + "math/big" + "math/bits" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/uints" + gnarkNimue "github.com/reilabs/gnark-nimue" + skyscraper "github.com/reilabs/gnark-skyscraper" +) + +// initialSumcheck mirrors the Rust WHIR verifier's initial sumcheck phase. +// The sum has already been computed by the caller (theSum). This function +// runs the sumcheck rounds to reduce the claim, and stores the OOD queries +// and RLC coefficients for the final W polynomial verification. +func initialSumcheck( + api frontend.API, + nimue gnarkNimue.Nimue, + theSum frontend.Variable, + oodPoints []frontend.Variable, + oodsRlcCoeffs []frontend.Variable, + initialFormRlcCoeffs []frontend.Variable, + whirParams WHIRParams, +) (InitialSumcheckData, frontend.Variable, []frontend.Variable, error) { + + initialSumcheckFoldingRandomness, lastEval, err := runWhirSumcheckRounds(api, theSum, nimue, whirParams.FoldingFactorArray[0]) + if err != nil { + return InitialSumcheckData{}, nil, nil, err + } + + combinedRlcCoeffs := make([]frontend.Variable, len(initialFormRlcCoeffs)+len(oodsRlcCoeffs)) + copy(combinedRlcCoeffs, initialFormRlcCoeffs) + copy(combinedRlcCoeffs[len(initialFormRlcCoeffs):], oodsRlcCoeffs) + + return InitialSumcheckData{ + InitialOODQueries: oodPoints, + InitialCombinationRandomness: combinedRlcCoeffs, + }, lastEval, initialSumcheckFoldingRandomness, nil +} + +// runWhirSumcheckRounds mirrors the Rust WHIR quadratic sumcheck verifier +// (whir/src/protocols/sumcheck.rs Config::verify). +// +// Each round the prover sends two coefficients (c0, c2) of a quadratic +// polynomial P(x) = c0 + c1·x + c2·x². The third coefficient c1 is derived +// from the sum constraint P(0) + P(1) = sum, giving c1 = sum − 2·c0 − c2. +// After squeezing a folding challenge r, the sum is updated to P(r). +func runWhirSumcheckRounds( + api frontend.API, + sum frontend.Variable, + nimue gnarkNimue.Nimue, + numRounds int, +) ([]frontend.Variable, frontend.Variable, error) { + foldingRandomness := make([]frontend.Variable, numRounds) + + for i := range numRounds { + coeffs := make([]frontend.Variable, 2) + if err := nimue.FillNextScalars(coeffs); err != nil { + return nil, nil, fmt.Errorf("sumcheck round %d: %w", i, err) + } + c0 := coeffs[0] + c2 := coeffs[1] + + c1 := api.Sub(sum, api.Add(api.Add(c0, c0), c2)) + + rBuf := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(rBuf); err != nil { + return nil, nil, fmt.Errorf("sumcheck round %d challenge: %w", i, err) + } + foldingRandomness[i] = rBuf[0] + + r := foldingRandomness[i] + sum = api.Add(api.Mul(api.Add(api.Mul(c2, r), c1), r), c0) + } + return foldingRandomness, sum, nil +} + +func getStirChallenges( + api frontend.API, + nimue gnarkNimue.Nimue, + numQueries int, + domainSize int, + foldingFactorPower int, +) ([]frontend.Variable, error) { + foldedDomainSize := domainSize / foldingFactorPower + domainSizeBytes := (bits.Len(uint(foldedDomainSize*2-1)) - 1 + 7) / 8 + + stirQueries := make([]uints.U8, domainSizeBytes*numQueries) + if err := nimue.FillChallengeBytes(stirQueries); err != nil { + return nil, err + } + bitLength := bits.Len(uint(foldedDomainSize)) - 1 + + indexes := make([]frontend.Variable, numQueries) + for i := range numQueries { + var value frontend.Variable = 0 + for j := range domainSizeBytes { + value = api.Add(stirQueries[j+i*domainSizeBytes].Val, api.Mul(value, 256)) + } + + bitsOfValue := api.ToBinary(value) + indexes[i] = api.FromBinary(bitsOfValue[:bitLength]...) + } + + return indexes, nil +} + +func generateEmptyMainRoundData(circuit WHIRParams) MainRoundData { + return MainRoundData{ + OODPoints: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), + StirChallengesPoints: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), + CombinationRandomness: make([][]frontend.Variable, len(circuit.RoundParametersOODSamples)), + } +} + +// ExponentVar computes base^exp using square-and-multiply with a field element exponent. +// numBits determines how many bits of exp to consider. +func ExponentVar(api frontend.API, base frontend.Variable, exp frontend.Variable, numBits int) frontend.Variable { + expBits := api.ToBinary(exp, numBits) + output := frontend.Variable(1) + multiply := base + for i := range expBits { + output = api.Select(expBits[i], api.Mul(output, multiply), output) + multiply = api.Mul(multiply, multiply) + } + return output +} + +// BitReversedExponentVar computes base^(bit_reverse(exp, numBits)). +// The provekit NTT (RSFr) uses bit-reversed evaluation order, so domain +// evaluation points are generator^(bit_reverse(index)) rather than +// generator^index. This function applies the reversal in-circuit. +func BitReversedExponentVar(api frontend.API, base frontend.Variable, exp frontend.Variable, numBits int) frontend.Variable { + expBits := api.ToBinary(exp, numBits) + // Reverse the bit order: expBits is LSB-first, so reversing gives + // the bit-reversed index in LSB-first order. + reversed := make([]frontend.Variable, numBits) + for i := range numBits { + reversed[i] = expBits[numBits-1-i] + } + reversedExp := api.FromBinary(reversed...) + return ExponentVar(api, base, reversedExp, numBits) +} + +// BitReverseInt reverses the lowest numBits bits of v. +func BitReverseInt(v int, numBits int) int { + result := 0 + for b := range numBits { + if v&(1< 0 { + _, _, err := PoW(api, sc, nimue, difficulty) + if err != nil { + return err + } + } + return nil +} + +// PoW performs a proof-of-work verification using nimue transcript and Skyscraper hash. +func PoW(api frontend.API, sc *skyscraper.Skyscraper, nimue gnarkNimue.Nimue, difficulty int) ([]uints.U8, []uints.U8, error) { + challenge := make([]uints.U8, 32) + if err := nimue.FillChallengeBytes(challenge); err != nil { + return nil, nil, err + } + nonce := make([]uints.U8, 8) + if err := nimue.FillNextBytes(nonce); err != nil { + return nil, nil, err + } + challengeFieldElement := LittleEndianFromUints(api, challenge) + nonceFieldElement := LittleEndianFromUints(api, nonce) + err := CheckPoW(api, sc, challengeFieldElement, nonceFieldElement, difficulty) + if err != nil { + return nil, nil, err + } + return challenge, nonce, nil +} + +func LittleEndianFromUints(api frontend.API, varArr []uints.U8) frontend.Variable { + frontendVar := frontend.Variable(0) + for i := range varArr { + frontendVar = api.Add(api.Mul(256, frontendVar), varArr[len(varArr)-1-i].Val) + } + return frontendVar +} + +// CheckPoW verifies a proof-of-work using Skyscraper hash. +// Compares only the first limb (low 64 bits) of the hash against the first +// limb of the threshold (modulus >> difficulty). +func CheckPoW(api frontend.API, sc *skyscraper.Skyscraper, challenge frontend.Variable, nonce frontend.Variable, difficulty int) error { + maxUint64, _ := new(big.Int).SetString("18446744073709551615", 10) + api.AssertIsLessOrEqual(nonce, maxUint64) + + hash := sc.CompressV2(challenge, nonce) + + // Decompose hash into 254 bits (BN254 field element size) + hashBits := api.ToBinary(hash, 254) + + // Reconstruct the first limb (low 64 bits) from bits + firstLimb := api.FromBinary(hashBits[:64]...) + + // Compute threshold first limb: (modulus >> difficulty) & 0xFFFFFFFFFFFFFFFF + modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + threshold := new(big.Int).Rsh(modulus, uint(difficulty)) + threshold.And(threshold, maxUint64) + + api.AssertIsLessOrEqual(firstLimb, threshold) + return nil +} diff --git a/recursive-verifier/app/whir/structs.go b/recursive-verifier/app/whir/structs.go new file mode 100644 index 000000000..7a5586b68 --- /dev/null +++ b/recursive-verifier/app/whir/structs.go @@ -0,0 +1,88 @@ +package whir + +import ( + "github.com/consensys/gnark/frontend" +) + +type ParsedCommitment struct { + Root frontend.Variable + OodPoints []frontend.Variable + OodAnswers []frontend.Variable // flat: out_domain_samples * num_vectors +} + +type Statement struct { + Constraints []MLConstraint + NVars int +} + +type WHIRParams struct { + ParamNRounds int + FoldingFactorArray []int + RoundParametersOODSamples []int + RoundParametersNumOfQueries []int + PowBits []int + FinalQueries int + FinalPowBits int + FinalFoldingPowBits int + StartingDomainBackingDomainGenerator frontend.Variable + DomainSize int + CommitmentOODSamples int + FinalSumcheckRounds int + MVParamsNumberOfVariables int + BatchSize int + InitialInDomainSamples int + // OmegaFull is the generator of the full NTT domain (order = DomainSize). + // Used to compute gamma points for batched_h_claims verification. + OmegaFull frontend.Variable + // Zeta = OmegaFull^(DomainSize/interleavingDepth), the interleaving coset generator. + Zeta frontend.Variable +} + +type InitialSumcheckData struct { + InitialOODQueries []frontend.Variable + InitialCombinationRandomness []frontend.Variable +} + +type MLConstraint struct { + Point []frontend.Variable + Evaluation frontend.Variable +} + +type MainRoundData struct { + OODPoints [][]frontend.Variable + StirChallengesPoints [][]frontend.Variable + CombinationRandomness [][]frontend.Variable +} + +// FinalClaimCircuit mirrors the Rust FinalClaim for the gnark circuit. +// The caller must verify: LinearFormRLC == Σ(RLCCoefficients[i] * weight_i.mle_evaluate(EvaluationPoint)) +type FinalClaimCircuit struct { + EvaluationPoint []frontend.Variable + RLCCoefficients []frontend.Variable + LinearFormRLC frontend.Variable +} + +// VerifyResult bundles all outputs from VerifyWhir. +type VerifyResult struct { + TotalFoldingRandomness []frontend.Variable + FinalClaim FinalClaimCircuit +} + +// RoundMerkleEntry holds the Merkle proof data for one round of WHIR opening. +type RoundMerkleEntry struct { + // Leaf values (submatrix): [query_idx][fold_element_idx] + Leaves [][]frontend.Variable + // Leaf sibling hashes for the Merkle proof: [query_idx] + SiblingHashes []frontend.Variable + // Auth path hashes for the Merkle proof: [query_idx][level] + AuthPaths [][]frontend.Variable + // Leaf indexes in the folded domain: [query_idx] + LeafIndexes []frontend.Variable +} + +// WhirMerkleData holds all Merkle proof data for a single VerifyWhir call. +// Rounds[0..nRounds-1] correspond to main round openings; +// Rounds[nRounds] is the final round opening. +type WhirMerkleData struct { + Rounds []RoundMerkleEntry +} diff --git a/recursive-verifier/app/whir/utils.go b/recursive-verifier/app/whir/utils.go new file mode 100644 index 000000000..3fa337510 --- /dev/null +++ b/recursive-verifier/app/whir/utils.go @@ -0,0 +1,171 @@ +package whir + +import ( + "github.com/consensys/gnark/frontend" + gnarkNimue "github.com/reilabs/gnark-nimue" + skyscraper "github.com/reilabs/gnark-skyscraper" +) + +// geometricChallenge mirrors Rust's geometric_challenge. +// Returns [1] for count <= 1 (no entropy sourced), or [1, x, x^2, ..., x^{count-1}] +// for count > 1 where x is squeezed from the transcript. +func geometricChallenge(api frontend.API, nimue gnarkNimue.Nimue, count int) ([]frontend.Variable, error) { + switch count { + case 0: + return nil, nil + case 1: + return []frontend.Variable{frontend.Variable(1)}, nil + default: + x := make([]frontend.Variable, 1) + if err := nimue.FillChallengeScalars(x); err != nil { + return nil, err + } + return ExpandRandomness(api, x[0], count), nil + } +} + +// Given some randomness r, return a vector r^0, r^1,...r^{len-1} +func ExpandRandomness(api frontend.API, base frontend.Variable, len int) []frontend.Variable { + res := make([]frontend.Variable, len) + acc := frontend.Variable(1) + for i := range len { + res[i] = acc + acc = api.Mul(acc, base) + } + return res +} + +func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) frontend.Variable { + var acc = frontend.Variable(0) + for i := range a { + acc = api.Add(acc, api.Mul(a[i], b[i])) + } + return acc +} + +// TensorProduct computes the tensor (Kronecker) product of two vectors: +// result[i*len(b) + j] = a[i] * b[j] +func TensorProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) []frontend.Variable { + result := make([]frontend.Variable, len(a)*len(b)) + for i, x := range a { + for j, y := range b { + result[i*len(b)+j] = api.Mul(x, y) + } + } + return result +} + +func MultivarPoly(coefs []frontend.Variable, vars []frontend.Variable, api frontend.API) frontend.Variable { + if len(vars) == 0 { + return coefs[0] + } + deg_zero := MultivarPoly(coefs[:len(coefs)/2], vars[:len(vars)-1], api) + deg_one := api.Mul(vars[len(vars)-1], MultivarPoly(coefs[len(coefs)/2:], vars[:len(vars)-1], api)) + return api.Add(deg_zero, deg_one) +} + +func UnivarPoly(api frontend.API, coefficients []frontend.Variable, points []frontend.Variable) []frontend.Variable { + if len(points) == 0 { + return coefficients + } + + results := make([]frontend.Variable, len(points)) + for j := range points { + ans := frontend.Variable(0) + for i := range coefficients { + ans = api.Add(api.Mul(ans, points[j]), coefficients[len(coefficients)-1-i]) + } + results[j] = ans + } + return results +} + +// computeEqWeights computes eq(point, p) for all binary points p on the hypercube. +// Mirrors Rust MultilinearPoint::eq_weights / eq_poly. +// For point = [r_0, ..., r_{n-1}], returns 2^n values where +// result[p] = ∏_i (bit_i(p) ? r_{n-1-i} : (1 - r_{n-1-i})) +// matching Rust's reverse-iteration convention in eq_poly. +func computeEqWeights(api frontend.API, point []frontend.Variable) []frontend.Variable { + n := len(point) + size := 1 << n + result := make([]frontend.Variable, size) + result[0] = frontend.Variable(1) + cur := 1 + for i := 0; i < n; i++ { + for j := cur - 1; j >= 0; j-- { + lo := api.Mul(result[j], api.Sub(frontend.Variable(1), point[i])) + hi := api.Sub(result[j], lo) // result[j]*point[i] = result[j] - lo + result[2*j] = lo + result[2*j+1] = hi + } + cur *= 2 + } + return result +} + +// UnivarMleEvaluate computes the multilinear extension of the univariate +// evaluation linear form (1, x, x^2, ..., x^{2^n - 1}) at a given point. +// Mirrors Rust UnivariateEvaluation::mle_evaluate: +// +// Π_i ((1 - r_i) + r_i · x^{2^{n-1-i}}) +// +// This is NOT the same as EqPolyOutside(ExpandFromUnivariate(x, n), r) +// which computes the eq polynomial between expanded coordinates and r. +func UnivarMleEvaluate(api frontend.API, univarPoint frontend.Variable, point []frontend.Variable) frontend.Variable { + n := len(point) + result := frontend.Variable(1) + x2i := univarPoint + for i := n - 1; i >= 0; i-- { + factor := api.Add(api.Sub(frontend.Variable(1), point[i]), api.Mul(point[i], x2i)) + result = api.Mul(result, factor) + x2i = api.Mul(x2i, x2i) + } + return result +} + +// MultilinearEvalCircuit evaluates the multilinear extension of `values` at +// `point`: MLE(point) = Σ_i values[i] * eq(i, point). +// len(values) must equal 2^len(point). +func MultilinearEvalCircuit(api frontend.API, point []frontend.Variable, values []frontend.Variable) frontend.Variable { + eqW := computeEqWeights(api, point) + return DotProduct(api, eqW, values) +} + +// verifyMerkleProofs verifies Merkle membership proofs using Skyscraper CompressV2. +// Each leaf is hashed to a single field element, then the auth path is traversed +// up to the root. +func verifyMerkleProofs( + api frontend.API, + sc *skyscraper.Skyscraper, + leaves [][]frontend.Variable, + leafIndexes []frontend.Variable, + siblingHashes []frontend.Variable, + authPaths [][]frontend.Variable, + rootHash frontend.Variable, +) { + for i := range leaves { + treeHeight := len(authPaths[i]) + 1 + leafIndexBits := api.ToBinary(leafIndexes[i], treeHeight) + + // Hash the leaf elements into a single commitment. + claimedLeafHash := sc.CompressV2(leaves[i][0], leaves[i][1]) + for x := 2; x < len(leaves[i]); x++ { + claimedLeafHash = sc.CompressV2(claimedLeafHash, leaves[i][x]) + } + + // Level 0: combine with sibling. + dir := leafIndexBits[0] + left := api.Select(dir, siblingHashes[i], claimedLeafHash) + right := api.Select(dir, claimedLeafHash, siblingHashes[i]) + currentHash := sc.CompressV2(left, right) + + // Remaining levels. + for level := 1; level < treeHeight; level++ { + indexBit := api.And(leafIndexBits[level], 1) + left = api.Select(indexBit, authPaths[i][level-1], currentHash) + right = api.Select(indexBit, currentHash, authPaths[i][level-1]) + currentHash = sc.CompressV2(left, right) + } + api.AssertIsEqual(currentHash, rootHash) + } +} diff --git a/recursive-verifier/app/whir/whir_verifier.go b/recursive-verifier/app/whir/whir_verifier.go new file mode 100644 index 000000000..7ec6cfbfb --- /dev/null +++ b/recursive-verifier/app/whir/whir_verifier.go @@ -0,0 +1,365 @@ +package whir + +import ( + "fmt" + "math/bits" + + "github.com/consensys/gnark/frontend" + gnarkNimue "github.com/reilabs/gnark-nimue" + skyscraper "github.com/reilabs/gnark-skyscraper" +) + +func VerifyWhir( + api frontend.API, + sc *skyscraper.Skyscraper, + nimue gnarkNimue.Nimue, + commitment ParsedCommitment, + statements []Statement, + params WHIRParams, + merkleData *WhirMerkleData, // nil to skip Merkle verification +) (result *VerifyResult, err error) { + var totalFoldingRandomness []frontend.Variable + + numVectors := params.BatchSize + + // Complete the constraint and evaluation matrix with OODs and their cross-terms. + numOODConstraints := 0 + var oodMatrix []frontend.Variable + vectorOffset := 0 + committedOODRows := commitment.OodAnswers + numVectorsPerCommitment := params.BatchSize + for i := 0; i < params.CommitmentOODSamples; i++ { + for j := 0; j < numVectors; j++ { + if j >= vectorOffset && j < numVectorsPerCommitment+vectorOffset { + oodMatrix = append(oodMatrix, committedOODRows[i*numVectorsPerCommitment+(j-vectorOffset)]) + } else { + // Cross-term: read from transcript (absorb into sponge). + crossTerm := make([]frontend.Variable, 1) + if err = nimue.FillNextScalars(crossTerm); err != nil { + return nil, fmt.Errorf("ood cross-term: %w", err) + } + oodMatrix = append(oodMatrix, crossTerm[0]) + } + } + numOODConstraints++ + } + + // Random linear combination of the vectors. + vectorRlcCoeffs, err := geometricChallenge(api, nimue, numVectors) + if err != nil { + return nil, fmt.Errorf("vector_rlc: %w", err) + } + + // Random linear combination of the constraints. + numLinearForms := len(statements) + // Rust orders constraints as [linear_forms..., oods...]. + constraintRlcCoeffs, err := geometricChallenge(api, nimue, numLinearForms+numOODConstraints) + if err != nil { + return nil, fmt.Errorf("constraint_rlc: %w", err) + } + initialFormRlcCoeffs := constraintRlcCoeffs[:numLinearForms] + oodsRlcCoeffs := constraintRlcCoeffs[numLinearForms:] + + // Compute "the sum" (mirrors Rust whir::verifier lines 110-118) + // Each statement has one evaluation per vector. For numVectors=1 (typical), + // each statement contributes rlc[i] * eval. For numVectors>1, the evaluations + // are combined via the vector RLC. + theSum := frontend.Variable(0) + for i, rlcCoeff := range initialFormRlcCoeffs { + nConstraints := len(statements[i].Constraints) + evaluationRow := make([]frontend.Variable, nConstraints) + for j := range nConstraints { + evaluationRow[j] = statements[i].Constraints[j].Evaluation + } + // Pad or truncate to numVectors for the dot product + row := make([]frontend.Variable, numVectors) + for j := range numVectors { + if j < nConstraints { + row[j] = evaluationRow[j] + } else { + row[j] = frontend.Variable(0) + } + } + theSum = api.Add(theSum, api.Mul(rlcCoeff, DotProduct(api, vectorRlcCoeffs, row))) + } + for i, rlcCoeff := range oodsRlcCoeffs { + oodsRow := oodMatrix[i*numVectors : (i+1)*numVectors] + theSum = api.Add(theSum, api.Mul(rlcCoeff, DotProduct(api, vectorRlcCoeffs, oodsRow))) + } + + // Perform the initial sumcheck + initialSumcheckData, theSum, initialSumcheckFoldingRandomness, err := initialSumcheck(api, nimue, theSum, commitment.OodPoints, oodsRlcCoeffs, initialFormRlcCoeffs, params) + if err != nil { + return nil, err + } + + mainRoundData := generateEmptyMainRoundData(params) + expDomainGenerator := ExponentVar(api, params.StartingDomainBackingDomainGenerator, frontend.Variable(1< 0 { + if err = nimue.FillChallengeScalars(roundOODPoints); err != nil { + return nil, fmt.Errorf("round %d ood points: %w", r, err) + } + if err = nimue.FillNextScalars(roundOODAnswers); err != nil { + return nil, fmt.Errorf("round %d ood answers: %w", r, err) + } + } + mainRoundData.OODPoints[r] = roundOODPoints + + if err = RunPoW(api, sc, nimue, params.PowBits[r]); err != nil { + return nil, fmt.Errorf("round %d pow: %w", r, err) + } + + // Generate STIR challenge indices from sponge. + // The number of queries and folding factor depend on whether we are + // opening the initial commitment or a previous round commitment. + var numQueries, foldingFactorPower int + if r == 0 { + numQueries = params.InitialInDomainSamples + foldingFactorPower = 1 << params.FoldingFactorArray[r] + } else { + numQueries = params.RoundParametersNumOfQueries[r-1] + foldingFactorPower = 1 << params.FoldingFactorArray[r-1] + } + stirIndexes, err2 := getStirChallenges(api, nimue, numQueries, domainSize, foldingFactorPower) + if err2 != nil { + err = err2 + return nil, fmt.Errorf("round %d stir: %w", r, err) + } + + // Verify Merkle proofs: each round opens the previous commitment. + if merkleData != nil && r < len(merkleData.Rounds) { + rd := merkleData.Rounds[r] + for q := range stirIndexes { + if q < len(rd.LeafIndexes) { + api.AssertIsEqual(stirIndexes[q], rd.LeafIndexes[q]) + } + } + verifyMerkleProofs(api, sc, rd.Leaves, rd.LeafIndexes, rd.SiblingHashes, rd.AuthPaths, prevRootHash) + } + + prevRootHash = rootHash[0] + + // Compute domain evaluation points from indices. + // The provekit NTT uses bit-reversed evaluation order (RSFr), so + // evaluation_points(idx) = generator^(bit_reverse(idx, log2(foldedDomainSize))). + foldedDomainSize := domainSize / foldingFactorPower + numBitsForReversal := bits.Len(uint(foldedDomainSize)) - 1 + mainRoundData.StirChallengesPoints[r] = make([]frontend.Variable, len(stirIndexes)) + for index, idx := range stirIndexes { + mainRoundData.StirChallengesPoints[r][index] = BitReversedExponentVar(api, expDomainGenerator, idx, numBitsForReversal) + } + + // Constraint values = OOD values + in-domain values from Merkle-verified leaves. + numInDomainQueries := len(stirIndexes) + constraintValues := make([]frontend.Variable, 0, len(roundOODAnswers)+numInDomainQueries) + constraintValues = append(constraintValues, roundOODAnswers...) + + if merkleData != nil && r < len(merkleData.Rounds) { + // Compute in-domain constraint values from verified leaf data. + // For the initial round (r==0), weights = tensor_product(polyRLC, eqWeights) + // where polyRLC = vectorRlcCoeffs. For subsequent rounds, polyRLC = [1]. + lastFoldRand := totalFoldingRandomness[len(totalFoldingRandomness)-params.FoldingFactorArray[r]:] + eqW := computeEqWeights(api, lastFoldRand) + var inDomainWeights []frontend.Variable + if r == 0 { + inDomainWeights = TensorProduct(api, vectorRlcCoeffs, eqW) + } else { + inDomainWeights = eqW + } + rd := merkleData.Rounds[r] + for q := range numInDomainQueries { + if q < len(rd.Leaves) { + constraintValues = append(constraintValues, DotProduct(api, inDomainWeights, rd.Leaves[q])) + } else { + constraintValues = append(constraintValues, frontend.Variable(0)) + } + } + } else { + for range numInDomainQueries { + constraintValues = append(constraintValues, frontend.Variable(0)) + } + } + + // Combination randomness + roundCombRlcCoeffs, err2 := geometricChallenge(api, nimue, len(constraintValues)) + if err2 != nil { + return nil, fmt.Errorf("round %d comb: %w", r, err2) + } + mainRoundData.CombinationRandomness[r] = roundCombRlcCoeffs + + constraintDot := DotProduct(api, roundCombRlcCoeffs, constraintValues) + theSum = api.Add(theSum, constraintDot) + + // Sumcheck round + var roundFoldingRandomness []frontend.Variable + roundFoldingRandomness, theSum, err = runWhirSumcheckRounds(api, theSum, nimue, params.FoldingFactorArray[r+1]) + if err != nil { + return nil, fmt.Errorf("round %d sumcheck: %w", r, err) + } + + totalFoldingRandomness = append(totalFoldingRandomness, roundFoldingRandomness...) + + domainSize /= 2 + numSquarings := 1 + params.FoldingFactorArray[r+1] - params.FoldingFactorArray[r] + for k := 0; k < numSquarings; k++ { + expDomainGenerator = api.Mul(expDomainGenerator, expDomainGenerator) + } + } + + // Read the final polynomial coefficients from the transcript. + finalVector := make([]frontend.Variable, 1< 0 { + if err = RunPoW(api, sc, nimue, params.FinalFoldingPowBits); err != nil { + return nil, fmt.Errorf("final folding pow: %w", err) + } + } + + // Deferred evaluation check + // + // Mirrors Rust whir verifier.rs lines 246-268: + // poly_eval = MLE(finalSumcheckRandomness, finalVector) + // linear_form_rlc = the_sum / poly_eval + // for each round's internal constraints, subtract: + // rlc_coeff * UnivariateEvaluation{point, size}.mle_evaluate(evaluationPoint) + // --------------------------------------------------------------- + + evaluationPoint := totalFoldingRandomness + + polyEval := MultilinearEvalCircuit(api, finalSumcheckRandomness, finalVector) + linearFormRLC := api.Div(theSum, polyEval) + + // Subtract initial round OOD evaluator contributions. + // Each OOD evaluator is UnivariateEvaluation{point, size} with + // size = domainSize / (1 << rate) = 2^MVParamsNumberOfVariables. + numInitialVars := params.MVParamsNumberOfVariables + initialSubPoint := evaluationPoint[len(evaluationPoint)-numInitialVars:] + numOODInitial := len(initialSumcheckData.InitialOODQueries) + for i := 0; i < numOODInitial; i++ { + oodIdx := numLinearForms + i // OOD coeffs come after linear form coeffs + mleVal := UnivarMleEvaluate(api, initialSumcheckData.InitialOODQueries[i], initialSubPoint) + linearFormRLC = api.Sub(linearFormRLC, api.Mul(initialSumcheckData.InitialCombinationRandomness[oodIdx], mleVal)) + } + + // Subtract main round constraint contributions (OOD + in-domain STIR evaluators). + numVarsForRound := numInitialVars + for r := range params.ParamNRounds { + numVarsForRound -= params.FoldingFactorArray[r] + subPoint := evaluationPoint[len(evaluationPoint)-numVarsForRound:] + + roundOODCount := params.RoundParametersOODSamples[r] + roundCombRLC := mainRoundData.CombinationRandomness[r] + + // OOD evaluators for this round + for i := 0; i < roundOODCount; i++ { + mleVal := UnivarMleEvaluate(api, mainRoundData.OODPoints[r][i], subPoint) + linearFormRLC = api.Sub(linearFormRLC, api.Mul(roundCombRLC[i], mleVal)) + } + + // In-domain STIR evaluators for this round + stirPoints := mainRoundData.StirChallengesPoints[r] + for i, stirPt := range stirPoints { + mleVal := UnivarMleEvaluate(api, stirPt, subPoint) + linearFormRLC = api.Sub(linearFormRLC, api.Mul(roundCombRLC[roundOODCount+i], mleVal)) + } + } + + return &VerifyResult{ + TotalFoldingRandomness: totalFoldingRandomness, + FinalClaim: FinalClaimCircuit{ + EvaluationPoint: evaluationPoint, + RLCCoefficients: initialFormRlcCoeffs, + LinearFormRLC: linearFormRLC, + }, + }, nil +} + +// VerifyClaim verifies that the WHIR-committed polynomial is consistent with +// the provided weight MLE evaluations. It checks: +// +// LinearFormRLC == Σ(RLCCoefficients[i] * weightMLEEvals[i]) +// +// The caller is responsible for computing the weight MLE evaluations +// (e.g. public input weight, A/B/C matrix covectors, blinding covector) +// and passing them in the correct order matching the RLC coefficients. +func (fc *FinalClaimCircuit) VerifyClaim(api frontend.API, weightMLEEvals []frontend.Variable) { + expectedRLC := frontend.Variable(0) + for i, mleVal := range weightMLEEvals { + expectedRLC = api.Add(expectedRLC, api.Mul(fc.RLCCoefficients[i], mleVal)) + } + api.AssertIsEqual(fc.LinearFormRLC, expectedRLC) +} + +// ExpandFromUnivariate converts a univariate evaluation point into a multilinear one. +// +// It maps a single point 'y' to a vector of coordinates: +// [y^(2^(n-1)), ..., y^4, y^2, y] +// +// This corresponds to the Big-Endian binary decomposition mapping used in +// protocols like Sumcheck or Spartan. +func ExpandFromUnivariate(api frontend.API, point frontend.Variable, numVariables int) []frontend.Variable { + res := make([]frontend.Variable, numVariables) + current := point + + for i := 0; i < numVariables; i++ { + res[numVariables-1-i] = current + current = api.Mul(current, current) + } + + return res +} diff --git a/recursive-verifier/go.mod b/recursive-verifier/go.mod index da78e237e..826dd63a0 100644 --- a/recursive-verifier/go.mod +++ b/recursive-verifier/go.mod @@ -6,7 +6,7 @@ require ( github.com/consensys/gnark v0.13.0 github.com/consensys/gnark-crypto v0.18.0 github.com/gofiber/fiber/v2 v2.52.9 - github.com/reilabs/gnark-nimue v0.0.7-0.20250819071945-7382324c8642 + github.com/reilabs/gnark-nimue v0.1.1 github.com/reilabs/gnark-skyscraper v0.0.0-20250819020215-db52e4ee2949 github.com/reilabs/go-ark-serialize v0.0.0-20241120151746-4148c0ca17e3 github.com/urfave/cli/v2 v2.27.7 diff --git a/recursive-verifier/go.sum b/recursive-verifier/go.sum index 22e5a3716..4b972efe6 100644 --- a/recursive-verifier/go.sum +++ b/recursive-verifier/go.sum @@ -46,8 +46,8 @@ github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/reilabs/gnark-nimue v0.0.7-0.20250819071945-7382324c8642 h1:yszb3+OVg17bugNU8L7oXAyvJGj0LsZt82TXAyi9muw= -github.com/reilabs/gnark-nimue v0.0.7-0.20250819071945-7382324c8642/go.mod h1:HZvEohNWtV3PHogGBe2HziYXX0CetYPxfLapf5NW6ss= +github.com/reilabs/gnark-nimue v0.1.1 h1:UpjMdqcKbAsP/2yBqCQG0XELgJrTuap76+lO8VsTgXY= +github.com/reilabs/gnark-nimue v0.1.1/go.mod h1:7zB2GNbqvVrCVyWZbjhOvYyFC3PbiBVV2RrhGAfwuh4= github.com/reilabs/gnark-skyscraper v0.0.0-20250819020215-db52e4ee2949 h1:ywiOSRWCIaOpSg0exeNRG0+rz4c62J6Z+IrJpKTso+c= github.com/reilabs/gnark-skyscraper v0.0.0-20250819020215-db52e4ee2949/go.mod h1:kUPBp0nHa5TpefcqeK4Otfpz8WHgpXpQ8f+NgfEF5Ks= github.com/reilabs/go-ark-serialize v0.0.0-20241120151746-4148c0ca17e3 h1:EZA/mA0ju0eAsvcBADuKRPYSL1UYoeGCAM/vNEWeCoA= diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs index d07614fe5..4d7bf7571 100644 --- a/tooling/cli/src/cmd/generate_gnark_inputs.rs +++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs @@ -53,7 +53,9 @@ impl Command for Args { .context("verifier is missing whir_for_witness config")?; write_gnark_parameters_to_file( + &verifier.whir_for_witness.clone().unwrap(), &wfw.whir_witness.blinded_commitment, + &wfw.whir_witness.blinding_commitment, &proof.whir_r1cs_proof, wfw.m_0, wfw.m, diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 4b2ece92b..99d7e36b5 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -1,6 +1,6 @@ use { ark_poly::{EvaluationDomain, GeneralEvaluationDomain}, - provekit_common::{FieldElement, PublicInputs, WhirConfig, WhirR1CSProof}, + provekit_common::{FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, serde::{Deserialize, Serialize}, std::{fs::File, io::Write}, tracing::instrument, @@ -8,45 +8,64 @@ use { #[derive(Debug, Serialize, Deserialize)] pub struct GnarkConfig { - pub whir_config_witness: WHIRConfigGnark, + pub blinded_commitment_whir_config: WHIRConfigGnark, + pub blinding_commitment_whir_config: WHIRConfigGnark, pub log_num_constraints: usize, - pub log_num_variables: usize, - pub log_a_num_terms: usize, - pub narg_string: Vec, - pub narg_string_len: usize, - pub hints: Vec, - pub hints_len: usize, - pub num_challenges: usize, - pub w1_size: usize, - pub public_inputs: PublicInputs, + pub log_num_variables: usize, + pub log_a_num_terms: usize, + pub narg_string: Vec, + pub narg_string_len: usize, + pub hints: Vec, + pub hints_len: usize, + pub protocol_id: Vec, + pub num_challenges: usize, + pub challenge_offsets: Vec, + pub w1_size: usize, + pub public_inputs: PublicInputs, } #[derive(Debug, Serialize, Deserialize)] pub struct WHIRConfigGnark { /// Number of WHIR rounds. - pub n_rounds: usize, + pub n_rounds: usize, /// Reed-Solomon rate (log₂ of inverse rate). - pub rate: usize, + pub rate: usize, /// Number of variables in the multilinear polynomial. - pub n_vars: usize, + pub n_vars: usize, /// Folding factor per round. - pub folding_factor: Vec, + pub folding_factor: Vec, /// Out-of-domain samples per round. - pub ood_samples: Vec, + pub ood_samples: Vec, /// Number of queries per round. - pub num_queries: Vec, - /// Proof-of-work bits per round. - pub pow_bits: Vec, + pub num_queries: Vec, + /// Proof-of-work bits per round (truncated integer, kept for backwards + /// compat). + pub pow_bits: Vec, + /// Proof-of-work u64 thresholds per WHIR round (exact values from Rust). + pub pow_thresholds: Vec, + /// Sumcheck round PoW thresholds per WHIR round. + pub sumcheck_pow_thresholds: Vec, + /// Initial sumcheck round PoW threshold. + pub initial_sumcheck_pow_threshold: u64, + /// Initial skip PoW threshold (used when initial sumcheck is skipped). + pub initial_skip_pow_threshold: u64, /// Final round query count. - pub final_queries: usize, - /// Final round proof-of-work bits. - pub final_pow_bits: i32, - /// Final folding proof-of-work bits. + pub final_queries: usize, + /// Final round proof-of-work bits (truncated integer). + pub final_pow_bits: i32, + /// Final round proof-of-work threshold (exact u64). + pub final_pow_threshold: u64, + /// Final folding proof-of-work bits (truncated integer). pub final_folding_pow_bits: i32, + /// Final folding proof-of-work threshold (exact u64). + pub final_folding_pow_threshold: u64, /// Domain generator as a string. - pub domain_generator: String, + pub domain_generator: String, /// Batch size (number of polynomials committed together). - pub batch_size: usize, + pub batch_size: usize, + /// Initial committer in-domain samples (query count for zkWHIR in-domain + /// verification). + pub initial_in_domain_samples: usize, } impl WHIRConfigGnark { @@ -85,6 +104,18 @@ impl WHIRConfigGnark { f64::from(whir::protocols::proof_of_work::difficulty(rc.pow.threshold)) as i32 }) .collect(); + let pow_thresholds: Vec = whir_params + .round_configs + .iter() + .map(|rc| rc.pow.threshold) + .collect(); + let sumcheck_pow_thresholds: Vec = whir_params + .round_configs + .iter() + .map(|rc| rc.sumcheck.round_pow.threshold) + .collect(); + let initial_sumcheck_pow_threshold = whir_params.initial_sumcheck.round_pow.threshold; + let initial_skip_pow_threshold = whir_params.initial_skip_pow.threshold; // If there are no folding rounds, fall back to the initial commitment's // in-domain samples. @@ -107,6 +138,7 @@ impl WHIRConfigGnark { let domain_generator = format!("{}", domain.group_gen()); let batch_size = whir_params.initial_committer.num_vectors; + let initial_in_domain_samples = whir_params.initial_committer.in_domain_samples; WHIRConfigGnark { n_rounds, @@ -116,18 +148,27 @@ impl WHIRConfigGnark { ood_samples, num_queries, pow_bits, + pow_thresholds, + sumcheck_pow_thresholds, + initial_sumcheck_pow_threshold, + initial_skip_pow_threshold, final_queries, final_pow_bits, + final_pow_threshold: whir_params.final_pow.threshold, final_folding_pow_bits, + final_folding_pow_threshold: whir_params.final_sumcheck.round_pow.threshold, domain_generator, batch_size, + initial_in_domain_samples, } } } #[instrument(skip_all)] pub fn gnark_parameters( - whir_params_witness: &WhirConfig, + scheme: &WhirR1CSScheme, + blinded_commitment: &WhirConfig, + blinding_commitment: &WhirConfig, proof: &WhirR1CSProof, m_0: usize, m: usize, @@ -136,8 +177,11 @@ pub fn gnark_parameters( w1_size: usize, public_inputs: &PublicInputs, ) -> GnarkConfig { + let ds = scheme.create_domain_separator(); + let protocol_id: Vec = ds.protocol_id.to_vec(); GnarkConfig { - whir_config_witness: WHIRConfigGnark::new(whir_params_witness), + blinded_commitment_whir_config: WHIRConfigGnark::new(blinded_commitment), + blinding_commitment_whir_config: WHIRConfigGnark::new(blinding_commitment), log_num_constraints: m_0, log_num_variables: m, log_a_num_terms: a_num_terms, @@ -145,7 +189,9 @@ pub fn gnark_parameters( narg_string_len: proof.narg_string.len(), hints: proof.hints.clone(), hints_len: proof.hints.len(), + protocol_id, num_challenges, + challenge_offsets: scheme.challenge_offsets.clone(), w1_size, public_inputs: public_inputs.clone(), } @@ -153,7 +199,9 @@ pub fn gnark_parameters( #[instrument(skip_all)] pub fn write_gnark_parameters_to_file( - whir_params_witness: &WhirConfig, + scheme: &WhirR1CSScheme, + blinded_commitment: &WhirConfig, + blinding_commitment: &WhirConfig, proof: &WhirR1CSProof, m_0: usize, m: usize, @@ -164,7 +212,9 @@ pub fn write_gnark_parameters_to_file( file_path: &str, ) { let gnark_config = gnark_parameters( - whir_params_witness, + scheme, + blinded_commitment, + blinding_commitment, proof, m_0, m, diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs index 312f64034..53cf59079 100644 --- a/tooling/verifier-server/src/services/verification.rs +++ b/tooling/verifier-server/src/services/verification.rs @@ -88,7 +88,9 @@ impl VerificationService { .ok_or_else(|| AppError::Internal("WHIR scheme not found in verifier".to_string()))?; write_gnark_parameters_to_file( + whir_scheme, &whir_scheme.whir_witness.blinded_commitment, + &whir_scheme.whir_witness.blinding_commitment, &proof.whir_r1cs_proof, whir_scheme.m_0, whir_scheme.m,